# Source: beatalignment/humaneval/clip-select.py
# (Hugging Face upload by william590y via huggingface_hub, commit 151b875)
import os,csv
from argparse import ArgumentParser
from glob import glob
import numpy as np
from tqdm import tqdm
from anticipation import ops
from anticipation.visuals import visualize
from anticipation.tokenize import extract_instruments
from anticipation.convert import midi_to_events, events_to_midi
from anticipation.config import TIME_RESOLUTION, EVENT_SIZE
from anticipation.vocab import TIME_OFFSET, NOTE_OFFSET
def select_sample(filenames, prompt_length, clip_length, verbose=False):
    """Randomly draw a MIDI file and cut a clip and a prompt from it.

    Files are sampled with replacement. A candidate is rejected, and another
    one drawn, when the file cannot be loaded, when the track is shorter than
    ``clip_length``, when the clip uses more than 15 instruments, or when the
    prompt holds fewer than 10 events.

    Args:
        filenames: Sequence of MIDI file paths to sample from.
        prompt_length: Length of the prompt taken from the start of the clip
            (seconds, per the command-line help).
        clip_length: Length of the full clip (seconds, per the command-line
            help).
        verbose: When true, print the sampled index and every rejection
            reason.

    Returns:
        A tuple ``(basename, clip, prompt)``: the base name of the selected
        file, the clip's events shifted so the clip starts at time zero, and
        the events of the first ``prompt_length`` of that clip.

    Raises:
        ValueError: If ``filenames`` is empty.

    Note:
        There is no retry bound: if no file can pass every filter, this
        function never returns.
    """
    # Fail with a clear message instead of numpy's opaque "high <= 0" error.
    if len(filenames) == 0:
        raise ValueError('no MIDI files to sample from')

    while True:
        # sampling with replacement
        idx = np.random.randint(len(filenames))
        if verbose:
            print('Loading index: ', idx)

        try:
            events = midi_to_events(filenames[idx])
        except Exception as err:
            # Best effort: unreadable files are skipped, but the reason is
            # reported when verbose so failures are not completely hidden.
            if verbose:
                print(f' rejected: could not load {filenames[idx]} ({err})')
            continue

        track_length = ops.max_time(events)
        max_time = track_length - clip_length

        # don't sample tracks with length shorter than clip_length
        if max_time < 0:
            if verbose:
                print(f' rejected: track is too short (length {track_length} < {clip_length})')
            continue

        # Pick a start uniformly in [0, max_time) so the whole clip fits.
        start_time = max_time*np.random.rand(1)[0]
        clip = ops.clip(events, start_time, start_time+clip_length, clip_duration=True)
        # Shift the clip so that it begins at time zero.
        clip = ops.translate(clip, -int(TIME_RESOLUTION*start_time))

        instruments = ops.get_instruments(clip).keys()
        if len(instruments) > 15:
            if verbose:
                print(f' rejected: track instrument count out of bounds: {len(instruments)}')
            continue

        prompt = ops.clip(clip, 0, prompt_length, clip_duration=False)

        # get clips with at least 10 events in the prompt
        if len(prompt) < EVENT_SIZE*10:
            if verbose:
                print(f' rejected: track has {len(prompt)//EVENT_SIZE} < 10 events in the prompt')
            continue

        break  # found one

    return os.path.basename(filenames[idx]), clip, prompt
def main(args):
    """Sample ``args.count`` clips and write them, with their prompts, to disk.

    Reads ``seed``, ``dir``, ``output``, ``count``, ``prompt_length``,
    ``clip_length`` and ``visualize`` from ``args``. MIDI files (``*.mid`` and
    ``*.midi``) are collected recursively from ``args.dir``. For each sample
    ``i`` the following are written under ``args.output``:

    * ``groundtruth/{i}-clip.mid`` -- the full clip,
    * ``{i}-conditional.mid`` -- the prompt,
    * matching ``.png`` visualizations when ``args.visualize`` is set,
    * one row in ``index.csv`` (columns ``idx``, ``original``, ``prompt``,
      ``parts``), where ``parts`` is the clip's instrument count.

    Raises:
        ValueError: If clips were requested but no MIDI files were found.
    """
    np.random.seed(args.seed)

    print(f'Selecting clips for accompaniment from: {args.dir}')
    filenames = glob(args.dir + '/**/*.mid', recursive=True) \
        + glob(args.dir + '/**/*.midi', recursive=True)
    # Sort so that a sampled index maps to the same file on every run with
    # the same seed, independent of the order glob returns.
    filenames = sorted(filenames)

    # Fail fast, before creating any output, when there is nothing to sample.
    if args.count > 0 and not filenames:
        raise ValueError(f'No MIDI files (*.mid, *.midi) found under: {args.dir}')

    print(f'Saving clips to: {args.output}')
    # Creates args.output as well, since it is the parent directory.
    os.makedirs(f'{args.output}/groundtruth', exist_ok=True)

    with open(f'{args.output}/index.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'prompt', 'parts'])

        for i in tqdm(range(args.count)):
            filename, clip, prompt = select_sample(filenames, args.prompt_length, args.clip_length)
            parts = ops.get_instruments(clip).keys()
            writer.writerow([i, filename, f'{i}-conditional.mid', len(parts)])

            # Full clip goes to the ground-truth folder.
            mid = events_to_midi(clip)
            mid.save(f'{args.output}/groundtruth/{i}-clip.mid')
            if args.visualize:
                visualize(clip, f'{args.output}/groundtruth/{i}-clip.png')

            # Prompt goes to the top-level output folder.
            mid = events_to_midi(prompt)
            mid.save(f'{args.output}/{i}-conditional.mid')
            if args.visualize:
                visualize(prompt, f'{args.output}/{i}-conditional.png')
if __name__ == '__main__':
    # Script entry point: build the command-line interface, then hand the
    # parsed arguments to main().
    cli = ArgumentParser(description='select prompts for infilling completion human eval')
    cli.add_argument('dir', help='directory containing MIDI files to sample')

    # (short flag, long flag, type, default, help text) for every option
    # that takes a value; registered in this order.
    valued_options = (
        ('-o', '--output', str, 'output', 'output directory'),
        ('-s', '--seed', int, 0, 'random seed for sampling'),
        ('-c', '--count', int, 10, 'number of clips to sample'),
        ('-p', '--prompt_length', int, 5, 'length of the prompt (in seconds)'),
        ('-l', '--clip_length', int, 20, 'length of the full clip (in seconds)'),
    )
    for short_flag, long_flag, value_type, fallback, help_text in valued_options:
        cli.add_argument(short_flag, long_flag, type=value_type,
                         default=fallback, help=help_text)

    # Boolean switch: off unless given on the command line.
    cli.add_argument('-v', '--visualize', action='store_true',
                     help='plot visualizations')

    main(cli.parse_args())