|
|
import os,csv
|
|
|
|
|
|
from argparse import ArgumentParser
|
|
|
from glob import glob
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
from anticipation import ops
|
|
|
from anticipation.visuals import visualize
|
|
|
from anticipation.tokenize import extract_instruments
|
|
|
from anticipation.convert import midi_to_events, events_to_midi
|
|
|
from anticipation.config import TIME_RESOLUTION, EVENT_SIZE
|
|
|
from anticipation.vocab import TIME_OFFSET, NOTE_OFFSET
|
|
|
|
|
|
def select_sample(filenames, prompt_length, clip_length, verbose=False):
    """Randomly pick a MIDI file and cut a clip and a prompt out of it.

    Files are drawn at random (with replacement) from ``filenames`` using
    NumPy's global random state until one passes every filter below. A
    candidate is rejected, and another one drawn, when:

    * ``midi_to_events`` raises while loading it;
    * the track is shorter than ``clip_length``;
    * the sampled clip contains more than 15 instruments;
    * the prompt contains fewer than 10 events.

    Args:
        filenames: sequence of paths to candidate MIDI files.
        prompt_length: length of the prompt taken from the start of the clip
            (in seconds, per the command-line help).
        clip_length: length of the full clip (in seconds, per the
            command-line help).
        verbose: when true, print the sampled index and the reason for every
            rejection.

    Returns:
        A tuple ``(basename, clip, prompt)``: the base name of the selected
        file, the clip translated so that it starts at time zero, and the
        prompt clipped from the beginning of that clip.

    Raises:
        ValueError: if ``filenames`` is empty.

    Note:
        The loop has no retry limit; if no file can ever pass the filters
        this function does not return.
    """
    if len(filenames) == 0:
        # Fail with a readable message instead of NumPy's "high <= 0" error.
        raise ValueError('no MIDI files to sample from')

    while True:
        idx = np.random.randint(len(filenames))
        if verbose:
            print('Loading index: ', idx)

        try:
            events = midi_to_events(filenames[idx])
        except Exception as exc:
            # Unreadable files are skipped on purpose (best effort), but the
            # reason is reported in verbose mode rather than hidden.
            if verbose:
                print(f' rejected: could not load file ({exc})')
            continue

        track_length = ops.max_time(events)

        # Latest start time that still leaves room for a full-length clip.
        latest_start = track_length - clip_length
        if latest_start < 0:
            if verbose:
                print(f' rejected: track is too short (length {track_length} < {clip_length})')
            continue

        start_time = latest_start*np.random.rand(1)[0]
        clip = ops.clip(events, start_time, start_time+clip_length, clip_duration=True)
        # Shift the clip so that it begins at time zero.
        clip = ops.translate(clip, -int(TIME_RESOLUTION*start_time))

        instruments = ops.get_instruments(clip).keys()
        if len(instruments) > 15:
            if verbose:
                print(f' rejected: track instrument count out of bounds: {len(instruments)}')
            continue

        prompt = ops.clip(clip, 0, prompt_length, clip_duration=False)

        # The event sequence is flat: EVENT_SIZE entries per event.
        if len(prompt) < EVENT_SIZE*10:
            if verbose:
                print(f' rejected: track has {len(prompt)//EVENT_SIZE} < 10 events in the prompt')
            continue

        return os.path.basename(filenames[idx]), clip, prompt
|
|
|
|
|
|
|
|
|
def main(args):
    """Sample clips and prompts from a MIDI directory and write them to disk.

    Seeds NumPy's global random state with ``args.seed``, collects every
    ``*.mid`` and ``*.midi`` file found recursively under ``args.dir``, and
    samples ``args.count`` clips with ``select_sample``. For sample ``i`` the
    following files are written:

    * ``{args.output}/groundtruth/{i}-clip.mid`` -- the full clip;
    * ``{args.output}/{i}-conditional.mid`` -- the prompt;
    * matching ``.png`` visualizations when ``args.visualize`` is set.

    ``{args.output}/index.csv`` gets one row per sample with the columns
    ``idx``, ``original``, ``prompt`` and ``parts`` (the number of
    instruments in the clip).

    Args:
        args: parsed command-line arguments providing ``seed``, ``dir``,
            ``output``, ``count``, ``prompt_length``, ``clip_length`` and
            ``visualize``.
    """
    np.random.seed(args.seed)

    print(f'Selecting clips for accompaniment from: {args.dir}')
    # Sort so that a given seed always indexes the same files, independent of
    # the order in which glob happens to list them.
    filenames = sorted(
        glob(args.dir + '/**/*.mid', recursive=True)
        + glob(args.dir + '/**/*.midi', recursive=True)
    )

    print(f'Saving clips to: {args.output}')
    groundtruth_dir = f'{args.output}/groundtruth'
    # makedirs also creates the parent output directory; exist_ok lets the
    # script be re-run against an existing output tree.
    os.makedirs(groundtruth_dir, exist_ok=True)

    with open(f'{args.output}/index.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'prompt', 'parts'])

        for i in tqdm(range(args.count)):
            filename, clip, prompt = select_sample(filenames, args.prompt_length, args.clip_length)
            num_parts = len(ops.get_instruments(clip))
            writer.writerow([i, filename, f'{i}-conditional.mid', num_parts])

            # Full clip (ground truth).
            events_to_midi(clip).save(f'{groundtruth_dir}/{i}-clip.mid')
            if args.visualize:
                visualize(clip, f'{groundtruth_dir}/{i}-clip.png')

            # Prompt only.
            events_to_midi(prompt).save(f'{args.output}/{i}-conditional.mid')
            if args.visualize:
                visualize(prompt, f'{args.output}/{i}-conditional.png')
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: build the argument parser and run main().
    cli = ArgumentParser(description='select prompts for infilling completion human eval')

    # Positional: root directory that is searched for MIDI files.
    cli.add_argument('dir', help='directory containing MIDI files to sample')

    # Output location and sampling controls.
    cli.add_argument('-o', '--output', type=str, default='output', help='output directory')
    cli.add_argument('-s', '--seed', type=int, default=0, help='random seed for sampling')
    cli.add_argument('-c', '--count', type=int, default=10, help='number of clips to sample')

    # Clip geometry.
    cli.add_argument('-p', '--prompt_length', type=int, default=5, help='length of the prompt (in seconds)')
    cli.add_argument('-l', '--clip_length', type=int, default=20, help='length of the full clip (in seconds)')

    # Optional plots written next to the MIDI files.
    cli.add_argument('-v', '--visualize', action='store_true', help='plot visualizations')

    parsed_args = cli.parse_args()
    main(parsed_args)
|
|
|
|