File size: 4,349 Bytes
151b875 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import os,csv
from argparse import ArgumentParser
from glob import glob
import numpy as np
from tqdm import tqdm
from anticipation import ops
from anticipation.visuals import visualize
from anticipation.convert import midi_to_events, events_to_midi
def _load_events(path, label, verbose):
    """Parse the MIDI file at `path`, returning None when it cannot be loaded."""
    try:
        return midi_to_events(path)
    except Exception as err:
        # The converter's failure modes are not visible from here, so the
        # catch stays broad; the file is skipped, but no longer silently.
        if verbose:
            print(f' rejected: could not load {label} ({path}): {err}')
        return None


def select_prompt(filenames, clip_length, verbose=False):
    """Yield evaluation clips for each usable FIGARO continuation file.

    For every path in `filenames`, the companion prompt and ground-truth
    files are looked up under the `prompt/` and `ground/` subdirectories
    next to it, using the same base name. A candidate is skipped when:

    * any of the three MIDI files cannot be loaded,
    * the FIGARO or ground-truth continuation has `ops.max_time` below
      `clip_length`, or
    * the prompt has `ops.max_time` outside the range [4, 6].

    Args:
        filenames: iterable of paths to FIGARO continuation MIDI files.
        clip_length: minimum required length of both continuations, and the
            end bound passed to `ops.clip` (reported as seconds in messages).
        verbose: when true, print the reason for every rejected candidate.

    Yields:
        Tuples `(basename, prompt, ground, figaro)`, where `basename` is the
        file name of the FIGARO file, `prompt` is the unclipped prompt, and
        `ground` / `figaro` are the continuations after
        `ops.clip(..., 0, clip_length, clip_duration=True)`.
    """
    for figaro_filename in filenames:
        figaro = _load_events(figaro_filename, 'FIGARO continuation', verbose)
        if figaro is None:
            continue

        max_time = ops.max_time(figaro)
        if max_time < clip_length:
            if verbose:
                print(f' rejected: FIGARO continuation is too short ({max_time} seconds)')
            continue

        figaro = ops.clip(figaro, 0, clip_length, clip_duration=True)

        # Companion files share the base name and live in sibling
        # subdirectories of the FIGARO file's directory.
        head, tail = os.path.split(figaro_filename)

        prompt = _load_events(os.path.join(head, 'prompt', tail), 'prompt', verbose)
        if prompt is None:
            continue

        max_time = ops.max_time(prompt)
        if max_time < 4:
            if verbose:
                print(f' rejected: prompt is too short ({max_time} seconds)')
            continue

        if max_time > 6:
            if verbose:
                print(f' rejected: prompt is too long ({max_time} seconds)')
            continue

        ground = _load_events(os.path.join(head, 'ground', tail),
                              'ground truth continuation', verbose)
        if ground is None:
            continue

        max_time = ops.max_time(ground)
        if max_time < clip_length:
            if verbose:
                print(f' rejected: ground truth continuation is too short ({max_time} seconds)')
            continue

        ground = ops.clip(ground, 0, clip_length, clip_duration=True)

        yield os.path.basename(figaro_filename), prompt, ground, figaro
def main(args):
    """Select random evaluation clips and write them under `args.output`.

    Reads `*.mid` files from `args.dir`, shuffles them with a NumPy RNG
    seeded by `args.seed`, and takes up to `args.count` clips from
    `select_prompt`. For clip number `i` it writes:

    * `<output>/<i>-prompt.mid`
    * `<output>/groundtruth/<i>-clip.mid`
    * `<output>/figaro/<i>-clip.mid`

    plus a matching `.png` for each when `args.visualize` is set, and a row
    `(idx, original, prompt)` in `<output>/index.csv`.

    Args:
        args: parsed command-line namespace providing `dir`, `output`,
            `seed`, `count`, `clip_length`, `visualize` and `verbose`.
    """
    np.random.seed(args.seed)

    print(f'Selecting random clips for prompting from: {args.dir}')
    # os.path.join keeps the pattern correct whether or not args.dir ends
    # with a path separator ('data' + '*.mid' would match 'data*.mid').
    filenames = sorted(glob(os.path.join(args.dir, '*.mid')))
    np.random.shuffle(filenames)

    print(f'Saving clips to: {args.output}')
    for directory in (args.output,
                      f'{args.output}/groundtruth',
                      f'{args.output}/figaro'):
        os.makedirs(directory, exist_ok=True)

    with open(f'{args.output}/index.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'prompt'])

        data = select_prompt(filenames, args.clip_length, args.verbose)
        for i in tqdm(range(args.count)):
            try:
                filename, prompt, ground, figaro = next(data)
            except StopIteration:
                # Not enough eligible candidates: stop cleanly so the index
                # matches the files written, instead of leaking
                # StopIteration out of main().
                print(f'Only {i} of the requested {args.count} clips could be '
                      f'selected from {len(filenames)} candidate files')
                break

            writer.writerow([i, filename, f'{i}-prompt.mid'])

            # (events, output path without extension), written in this order.
            outputs = (
                (prompt, f'{args.output}/{i}-prompt'),
                (ground, f'{args.output}/groundtruth/{i}-clip'),
                (figaro, f'{args.output}/figaro/{i}-clip'),
            )
            for events, stem in outputs:
                events_to_midi(events).save(f'{stem}.mid')
                if args.visualize:
                    visualize(events, f'{stem}.png')
if __name__ == '__main__':
    # Script entry point: build the command-line interface and hand the
    # parsed namespace to main().
    cli = ArgumentParser(description='select prompts for completion human eval')
    cli.add_argument('dir', help='directory containing MIDI files to sample')

    # Options that take a value: (flags, type, default, help text).
    valued_options = (
        (('-o', '--output'), str, 'prompt', 'output directory'),
        (('-s', '--seed'), int, 0, 'random seed for prompt selection'),
        (('-c', '--count'), int, 10, 'number of clips to sample'),
        (('-l', '--clip_length'), int, 20, 'length of the full clip (in seconds)'),
    )
    for flags, value_type, fallback, description in valued_options:
        cli.add_argument(*flags, type=value_type, default=fallback, help=description)

    # Boolean switches: (flags, help text).
    switches = (
        (('-v', '--visualize'), 'plot visualizations'),
        (('--verbose',), 'verbose output'),
    )
    for flags, description in switches:
        cli.add_argument(*flags, action='store_true', help=description)

    main(cli.parse_args())
|