File size: 4,288 Bytes
151b875 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os,csv
from argparse import ArgumentParser
from glob import glob
import numpy as np
from tqdm import tqdm
from anticipation import ops
from anticipation.visuals import visualize
from anticipation.tokenize import extract_instruments
from anticipation.convert import midi_to_events, events_to_midi
from anticipation.config import TIME_RESOLUTION, EVENT_SIZE
from anticipation.vocab import TIME_OFFSET, NOTE_OFFSET
def select_sample(filenames, prompt_length, clip_length, verbose=False):
    """Randomly pick a MIDI file and cut a clip and a prompt out of it.

    Files are drawn uniformly, with replacement, from ``filenames`` until a
    candidate passes every filter: it loads via ``midi_to_events``, its
    ``ops.max_time`` is at least ``clip_length``, the clip contains at most
    15 instruments, and the prompt holds at least 10 events.

    Args:
        filenames: sequence of MIDI file paths to sample from.
        prompt_length: end of the prompt window handed to ``ops.clip``
            (the command-line help describes it as seconds).
        clip_length: length of the clip window handed to ``ops.clip``
            (the command-line help describes it as seconds).
        verbose: when true, print the sampled index and every rejection reason.

    Returns:
        A tuple ``(basename, clip, prompt)``: the base name of the chosen
        file, the clip translated so that it starts at time zero, and the
        prompt cut from the beginning of that clip.

    Raises:
        ValueError: if ``filenames`` is empty.
        RuntimeError: if every file either fails to load or is shorter than
            ``clip_length``, so that no sample can ever be produced.
    """
    if len(filenames) == 0:
        raise ValueError('select_sample: filenames is empty; no MIDI files to sample from')

    # Indices that can never yield a sample (unloadable or too short). These
    # rejections do not depend on the random start time, so retrying them is
    # pointless; once every index is in this set the loop could never end.
    unusable = set()

    while True:
        # sampling with replacement
        idx = np.random.randint(len(filenames))
        if verbose:
            print('Loading index: ', idx)

        if idx in unusable:
            if verbose:
                print(' rejected: file was already found to be unusable')
            continue

        try:
            events = midi_to_events(filenames[idx])
        except Exception as err:
            # best effort: skip files that cannot be parsed, but say so
            if verbose:
                print(f' rejected: failed to load {filenames[idx]} ({err})')
            unusable.add(idx)
            if len(unusable) == len(filenames):
                raise RuntimeError(
                    'select_sample: no usable file; every candidate failed to load '
                    f'or is shorter than the clip length ({clip_length})')
            continue

        track_length = ops.max_time(events)
        max_time = track_length - clip_length

        # don't sample tracks with length shorter than clip_length
        if max_time < 0:
            if verbose:
                print(f' rejected: track is too short (length {track_length} < {clip_length})')
            unusable.add(idx)
            if len(unusable) == len(filenames):
                raise RuntimeError(
                    'select_sample: no usable file; every candidate failed to load '
                    f'or is shorter than the clip length ({clip_length})')
            continue

        # uniform start time such that the whole clip fits inside the track
        start_time = max_time*np.random.rand(1)[0]
        clip = ops.clip(events, start_time, start_time+clip_length, clip_duration=True)
        # shift the clip so that it begins at time zero
        clip = ops.translate(clip, -int(TIME_RESOLUTION*start_time))

        instruments = ops.get_instruments(clip).keys()
        if len(instruments) > 15:
            if verbose:
                print(f' rejected: track instrument count out of bounds: {len(instruments)}')
            continue

        prompt = ops.clip(clip, 0, prompt_length, clip_duration=False)

        # get clips with at least 10 events in the prompt
        # (each event occupies EVENT_SIZE entries of the sequence)
        if len(prompt) < EVENT_SIZE*10:
            if verbose:
                print(f' rejected: track has {len(prompt)//EVENT_SIZE} < 10 events in the prompt')
            continue

        break  # found one

    return os.path.basename(filenames[idx]), clip, prompt
def main(args):
    """Sample clips from a MIDI directory and write them out with an index.

    Seeds numpy's global RNG with ``args.seed``, collects every ``.mid`` and
    ``.midi`` file below ``args.dir`` (recursively), then draws ``args.count``
    samples with ``select_sample``. For sample ``i`` it writes:

    * ``<output>/groundtruth/<i>-clip.mid``: the full clip,
    * ``<output>/<i>-conditional.mid``: the prompt,
    * matching ``.png`` files next to each when ``args.visualize`` is set,
    * one row in ``<output>/index.csv`` with columns
      ``idx, original, prompt, parts``.

    Args:
        args: parsed command-line namespace providing ``seed``, ``dir``,
            ``output``, ``count``, ``prompt_length``, ``clip_length`` and
            ``visualize``.

    Raises:
        ValueError: if samples were requested but no MIDI file was found.
    """
    np.random.seed(args.seed)

    print(f'Selecting clips for accompaniment from: {args.dir}')
    filenames = glob(args.dir + '/**/*.mid', recursive=True) \
        + glob(args.dir + '/**/*.midi', recursive=True)
    # sort so that the same seed indexes the same files on every run
    filenames = sorted(filenames)

    # fail before touching the output directory instead of crashing inside
    # the sampling loop with a half-written index
    if args.count > 0 and not filenames:
        raise ValueError(f'no .mid/.midi files found under: {args.dir}')

    print(f'Saving clips to: {args.output}')
    # creates the output directory as well as its groundtruth subdirectory
    os.makedirs(f'{args.output}/groundtruth', exist_ok=True)

    with open(f'{args.output}/index.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'prompt', 'parts'])

        for i in tqdm(range(args.count)):
            filename, clip, prompt = select_sample(filenames, args.prompt_length, args.clip_length)
            # number of distinct instruments present in the clip
            parts = len(ops.get_instruments(clip))
            writer.writerow([i, filename, f'{i}-conditional.mid', parts])

            mid = events_to_midi(clip)
            mid.save(f'{args.output}/groundtruth/{i}-clip.mid')
            if args.visualize:
                visualize(clip, f'{args.output}/groundtruth/{i}-clip.png')

            mid = events_to_midi(prompt)
            mid.save(f'{args.output}/{i}-conditional.mid')
            if args.visualize:
                visualize(prompt, f'{args.output}/{i}-conditional.png')
if __name__ == '__main__':
    # Command-line entry point: build the option parser and hand the parsed
    # namespace to main().
    cli = ArgumentParser(description='select prompts for infilling completion human eval')
    cli.add_argument('dir', help='directory containing MIDI files to sample')
    cli.add_argument('-o', '--output', type=str, default='output', help='output directory')

    # integer-valued options: (short flag, long flag, default, help text),
    # registered in this order so the generated help lists them as before
    int_options = (
        ('-s', '--seed', 0, 'random seed for sampling'),
        ('-c', '--count', 10, 'number of clips to sample'),
        ('-p', '--prompt_length', 5, 'length of the prompt (in seconds)'),
        ('-l', '--clip_length', 20, 'length of the full clip (in seconds)'),
    )
    for short_flag, long_flag, fallback, description in int_options:
        cli.add_argument(short_flag, long_flag, type=int, default=fallback, help=description)

    cli.add_argument('-v', '--visualize', action='store_true', help='plot visualizations')

    main(cli.parse_args())
|