"""Select prompts for completion human eval.

For every candidate ``*.mid`` file (the FIGARO continuation) this script looks
for a file of the same name under sibling ``prompt/`` and ``ground/``
sub-directories, keeps the triples that pass the length checks in
``select_prompt``, and writes the first ``--count`` of them to the output
directory together with an ``index.csv`` mapping clip index to source file.
"""
import csv
import os
from argparse import ArgumentParser
from glob import glob

import numpy as np
from tqdm import tqdm

from anticipation import ops
from anticipation.visuals import visualize
from anticipation.convert import midi_to_events, events_to_midi


def select_prompt(filenames, clip_length, verbose=False):
    """Yield usable (name, prompt, ground truth, FIGARO) clip tuples.

    Candidates are examined in the order given. A candidate is skipped when
    any of its three MIDI files cannot be loaded by ``midi_to_events``, when
    the FIGARO or ground-truth continuation is shorter than ``clip_length``,
    or when the prompt's ``ops.max_time`` falls outside 4 to 6 inclusive.

    Args:
        filenames: paths of the FIGARO continuation MIDI files. The prompt and
            ground truth are read from ``<dir>/prompt/<name>`` and
            ``<dir>/ground/<name>`` next to each file.
        clip_length: required continuation length, compared against
            ``ops.max_time`` (reported as seconds in the messages below).
        verbose: when true, print the reason each candidate is rejected.

    Yields:
        ``(basename, prompt, ground, figaro)`` where ``ground`` and ``figaro``
        have been passed through
        ``ops.clip(events, 0, clip_length, clip_duration=True)``.
    """
    for figaro_filename in filenames:
        # Loading is deliberately best-effort: an unreadable file just
        # disqualifies this candidate, but say so when verbose.
        try:
            figaro = midi_to_events(figaro_filename)
        except Exception as e:
            if verbose:
                print(f' rejected: could not load FIGARO continuation {figaro_filename} ({e})')
            continue

        max_time = ops.max_time(figaro)
        if max_time < clip_length:
            if verbose:
                print(f' rejected: FIGARO continuation is too short ({max_time} seconds)')
            continue

        figaro = ops.clip(figaro, 0, clip_length, clip_duration=True)

        # The prompt and ground truth live in sub-directories next to the
        # FIGARO file and share its file name.
        head, tail = os.path.split(figaro_filename)

        prompt_filename = os.path.join(head, 'prompt', tail)
        try:
            prompt = midi_to_events(prompt_filename)
        except Exception as e:
            if verbose:
                print(f' rejected: could not load prompt {prompt_filename} ({e})')
            continue

        max_time = ops.max_time(prompt)
        if max_time < 4:
            if verbose:
                print(f' rejected: prompt is too short ({max_time} seconds)')
            continue

        if max_time > 6:
            if verbose:
                print(f' rejected: prompt is too long ({max_time} seconds)')
            continue

        ground_filename = os.path.join(head, 'ground', tail)
        try:
            ground = midi_to_events(ground_filename)
        except Exception as e:
            if verbose:
                print(f' rejected: could not load ground truth {ground_filename} ({e})')
            continue

        max_time = ops.max_time(ground)
        if max_time < clip_length:
            if verbose:
                print(f' rejected: ground truth continuation is too short ({max_time} seconds)')
            continue

        ground = ops.clip(ground, 0, clip_length, clip_duration=True)

        yield os.path.basename(figaro_filename), prompt, ground, figaro


def _save_clip(events, stem, plot):
    """Write ``events`` to ``<stem>.mid`` and, if ``plot``, to ``<stem>.png``."""
    events_to_midi(events).save(f'{stem}.mid')
    if plot:
        visualize(events, f'{stem}.png')


def main(args):
    """Sample ``args.count`` clips from ``args.dir`` into ``args.output``.

    The candidate list is sorted and then shuffled with a NumPy RNG seeded
    from ``args.seed``, so a given seed and directory listing produce the same
    selection order.

    Args:
        args: parsed command-line namespace with ``dir``, ``output``, ``seed``,
            ``count``, ``clip_length``, ``visualize`` and ``verbose``.

    Raises:
        SystemExit: if fewer than ``args.count`` candidates pass the filters.
            The clips and index rows written before that point are kept.
    """
    np.random.seed(args.seed)

    print(f'Selecting random clips for prompting from: {args.dir}')
    if os.path.isdir(args.dir):
        # Joining (rather than concatenating) makes 'data' and 'data/' behave
        # the same; plain concatenation would turn 'data' into 'data*.mid'.
        pattern = os.path.join(args.dir, '*.mid')
    else:
        # Not a directory: keep the historical prefix-style pattern.
        pattern = args.dir + '*.mid'
    filenames = sorted(glob(pattern))
    np.random.shuffle(filenames)

    print(f'Saving clips to: {args.output}')
    ground_dir = os.path.join(args.output, 'groundtruth')
    figaro_dir = os.path.join(args.output, 'figaro')
    for directory in (args.output, ground_dir, figaro_dir):
        os.makedirs(directory, exist_ok=True)

    written = 0
    with open(os.path.join(args.output, 'index.csv'), 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'prompt'])

        data = select_prompt(filenames, args.clip_length, args.verbose)
        for i in tqdm(range(args.count)):
            # A default avoids an unexplained StopIteration traceback when the
            # candidates run out before `count` clips have been found.
            selection = next(data, None)
            if selection is None:
                break

            filename, prompt, ground, figaro = selection
            writer.writerow([i, filename, f'{i}-prompt.mid'])

            _save_clip(prompt, os.path.join(args.output, f'{i}-prompt'), args.visualize)
            _save_clip(ground, os.path.join(ground_dir, f'{i}-clip'), args.visualize)
            _save_clip(figaro, os.path.join(figaro_dir, f'{i}-clip'), args.visualize)
            written += 1

    if written < args.count:
        # Still a failure (non-zero exit), but with an actionable message.
        raise SystemExit(
            f'error: only {written} of {args.count} requested clips could be '
            f'selected from {len(filenames)} candidate files'
        )


if __name__ == '__main__':
    parser = ArgumentParser(description='select prompts for completion human eval')
    parser.add_argument('dir', help='directory containing MIDI files to sample')
    parser.add_argument('-o', '--output', type=str, default='prompt',
            help='output directory')
    parser.add_argument('-s', '--seed', type=int, default=0,
            help='random seed for prompt selection')
    parser.add_argument('-c', '--count', type=int, default=10,
            help='number of clips to sample')
    parser.add_argument('-l', '--clip_length', type=int, default=20,
            help='length of the full clip (in seconds)')
    parser.add_argument('-v', '--visualize', action='store_true',
            help='plot visualizations')
    parser.add_argument('--verbose', action='store_true',
            help='verbose output')

    main(parser.parse_args())