"""Generate prompted completions for the MIDI prompts listed in an index CSV.

For each selected row of ``<dir>/index.csv`` the prompt MIDI file is converted
to events, one or more completions are sampled from a causal language model,
and each result is written as a MIDI file (optionally with a PNG
visualization) under ``<dir>/<output>/``.
"""

import os
import csv
import time

from argparse import ArgumentParser

import numpy as np

from transformers import AutoModelForCausalLM

from anticipation import ops
from anticipation.visuals import visualize
from anticipation.sample import generate
from anticipation.convert import midi_to_events, events_to_midi

np.random.seed(0)


def main(args):
    """Sample completions for a window of prompts from the index.

    Args:
        args: parsed command-line namespace providing ``dir``, ``model``,
            ``output``, ``count``, ``offset``, ``multiplicity``,
            ``clip_length`` and ``visualize`` (see the parser below).

    Rows whose ``idx`` is below ``args.offset`` are skipped; iteration stops
    at the first row whose ``idx`` reaches ``args.offset + args.count``.
    """
    print(f'Prompting using model checkpoint: {args.model}')

    load_start = time.time()
    # NOTE: .cuda() moves the model to a GPU, so a CUDA device is required.
    model = AutoModelForCausalLM.from_pretrained(args.model).cuda()
    print(f'Loaded model ({time.time()-load_start} seconds)')

    # Build the output directory path once; it is reused for every file.
    out_dir = f'{args.dir}/{args.output}'
    print(f'Writing outputs to {out_dir}')
    # exist_ok=True replaces the former try/except FileExistsError: pass.
    os.makedirs(out_dir, exist_ok=True)

    print(f'Prompting with tracks in index : {args.dir}/index.csv')
    with open(f'{args.dir}/index.csv', newline='') as f:
        reader = csv.reader(f)
        header = next(reader)

        # Column positions do not change between rows; resolve them once.
        prompt_col = header.index('prompt')
        idx_col = header.index('idx')

        first_idx = args.offset
        end_idx = args.offset + args.count

        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; nothing to process.
                continue

            prompt_midi = row[prompt_col]
            idx = int(row[idx_col])

            if idx < first_idx:
                continue

            # NOTE(review): stopping here presumably relies on the index
            # being sorted by idx -- confirm against how index.csv is built.
            if idx >= end_idx:
                break

            prompt = midi_to_events(os.path.join(args.dir, prompt_midi))
            # Presumably the end time of the prompt; sampling starts there.
            start_time = ops.max_time(prompt)

            for j in range(args.multiplicity):
                sample_start = time.time()

                generated_tokens = generate(
                    model, start_time, args.clip_length, prompt,
                    controls=[], top_p=0.98)
                # Presumably trims the events to [0, clip_length] seconds
                # -- TODO confirm against anticipation.ops.clip.
                output = ops.clip(generated_tokens, 0, args.clip_length)

                mid = events_to_midi(output)
                mid.save(f'{out_dir}/{idx}-clip-v{j}.mid')

                if args.visualize:
                    visualize(output, f'{out_dir}/{idx}-clip-v{j}.png')

                print(f'Generated completion of idx {idx}. Sampling time: {time.time()-sample_start} seconds')


if __name__ == '__main__':
    parser = ArgumentParser(description='generate prompted completions')
    parser.add_argument('dir', help='directory containing an index of MIDI files')
    parser.add_argument('model', help='directory containing an model checkpoint')
    parser.add_argument('-o', '--output', type=str, default='model',
            help='model description (the name of the output subdirectory)')
    parser.add_argument('-c', '--count', type=int, default=10,
            help='number of clips to sample')
    parser.add_argument('-f', '--offset', type=int, default=0,
            help='offset for sampling (manual hack for parallel workers)')
    parser.add_argument('-m', '--multiplicity', type=int, default=1,
            help='number of generations per clip')
    parser.add_argument('-l', '--clip_length', type=int, default=20,
            help='length of the full clip (in seconds)')
    parser.add_argument('-v', '--visualize', action='store_true',
            help='plot visualizations')

    main(parser.parse_args())