# Source: beatalignment / humaneval / melody-select.py
# Uploaded by william590y ("Upload folder using huggingface_hub"), commit 151b875 (verified)
import os,csv
from argparse import ArgumentParser
from glob import glob
import numpy as np
from tqdm import tqdm
from anticipation import ops
from anticipation.visuals import visualize
from anticipation.tokenize import extract_instruments
from anticipation.convert import midi_to_events, events_to_midi
from anticipation.config import TIME_RESOLUTION
from anticipation.vocab import TIME_OFFSET, NOTE_OFFSET
# Instrument ids never chosen as the melody: 0 and 9, plus 112-128
# (128 is also the id excluded as drums in select_sample).
# NOTE(review): the original comment says "non-drum, non-piano". If instrument
# ids are 0-indexed General MIDI programs, 0 is piano but 9 is glockenspiel
# (melodic), so 9 may have been meant as the MIDI drum *channel*. Kept as-is
# to preserve sampling behavior -- confirm intent.
_NON_MELODY_INSTRUMENTS = frozenset([0, 9, *range(112, 129)])


def _select_melody(clip, instrument_counts):
    """Return the eligible instrument with the highest average pitch, or None.

    Args:
        clip: flat event list of triples whose third element is a note token
            (``NOTE_OFFSET + 128 * instrument + pitch``).
        instrument_counts: mapping keyed by every instrument id present in ``clip``.

    Returns:
        The instrument id with the highest mean pitch among instruments not in
        ``_NON_MELODY_INSTRUMENTS``, or None if no instrument qualifies (or
        every eligible mean pitch is 0).
    """
    pitches = {instr: [] for instr in instrument_counts}
    for note in clip[2::3]:
        instr, pitch = divmod(note - NOTE_OFFSET, 2**7)
        pitches[instr].append(pitch)

    melody = None
    high = 0
    for instr in instrument_counts:
        if instr in _NON_MELODY_INSTRUMENTS:
            continue
        avg = np.mean(pitches[instr])
        # strict '>' keeps the first instrument on ties
        if avg > high:
            melody, high = instr, avg
    return melody


def select_sample(filenames, prompt_length, clip_length, verbose=False):
    """Randomly pick a MIDI clip suitable for a melody-conditioned eval.

    Files are drawn uniformly at random *with replacement* from ``filenames``.
    Each draw takes a ``clip_length``-second window at a random offset. Draws
    that fail any check below are discarded and retried. This loops until an
    acceptable clip is found.

    A clip is accepted when:
      * the file parses with ``midi_to_events``;
      * the track is at least ``clip_length`` seconds long;
      * it has 4 to 10 (inclusive) instruments other than 128 (drums);
      * a melody instrument exists: the eligible instrument with the highest
        average pitch;
      * the melody has at least 20 notes in the clip;
      * the melody starts within the first ``prompt_length`` seconds;
      * the melody lasts until at least 2 seconds before the clip's end.

    Args:
        filenames: sequence of MIDI file paths to sample from.
        prompt_length: length of the prompt, in seconds.
        clip_length: length of the full clip, in seconds.
        verbose: if True, print each sampled index and why draws are rejected.

    Returns:
        ``(basename, clip, melody)``: the chosen file's basename, the clip's
        event list (translated to start at time 0), and the melody instrument id.

    Raises:
        ValueError: if ``filenames`` is empty, since nothing could be sampled.
    """
    if not filenames:
        raise ValueError('select_sample: no filenames to sample from')

    while True:
        # sampling with replacement
        idx = np.random.randint(len(filenames))
        if verbose:
            print('Loading index: ', idx)

        try:
            events = midi_to_events(filenames[idx])
        except Exception as exc:
            # best effort: unreadable/corrupt files are skipped, not fatal
            if verbose:
                print(f' rejected: failed to parse MIDI ({exc!r})')
            continue

        track_length = ops.max_time(events)
        max_time = track_length - clip_length
        # don't sample tracks with length shorter than clip_length
        if max_time < 0:
            if verbose:
                print(f' rejected: track is too short (length {track_length} < {clip_length})')
            continue

        start_time = max_time * np.random.rand(1)[0]
        clip = ops.clip(events, start_time, start_time + clip_length, clip_duration=True)
        # shift the clip to start at 0; start_time (seconds) is scaled by
        # TIME_RESOLUTION into the time units ops.translate works in
        clip = ops.translate(clip, -int(TIME_RESOLUTION * start_time))

        # computed once; the value is reused by every check below
        instrument_counts = ops.get_instruments(clip)

        # find an ensemble with a healthy (non-drum / effect) instrument collection
        instruments = [instr for instr in instrument_counts if instr != 128]
        if len(instruments) < 4 or len(instruments) > 10:
            if verbose:
                print(f' rejected: track instrument count out of bounds: {len(instruments)}')
            continue

        # define melody as the instrument part with the highest average pitch
        # (previously an assert, which aborted the whole run on such clips)
        melody = _select_melody(clip, instrument_counts)
        if melody is None:
            if verbose:
                print(' rejected: no eligible melody instrument')
            continue

        # get clips with at least 20 notes of melody
        if instrument_counts[melody] < 20:
            if verbose:
                print(' rejected: too few melodic notes')
            continue

        # prompt should contain the melody line
        if ops.min_time(clip, seconds=True, instr=melody) > prompt_length:
            if verbose:
                print(' rejected: prompt does not contain the melody')
            continue

        # melody shouldn't end early
        if ops.max_time(clip, seconds=True, instr=melody) < (clip_length - 2):
            if verbose:
                print(' rejected: melody ends before the end of the clip')
            continue

        return os.path.basename(filenames[idx]), clip, melody
def main(args):
    """Sample ``args.count`` clips and write the accompaniment human-eval set.

    For each sampled clip ``i`` this writes:
      * ``<output>/groundtruth/<i>-clip.mid``: the full sampled clip;
      * ``<output>/<i>-conditional.mid``: the non-melody events from the first
        ``prompt_length`` seconds, combined with the ``controls`` returned by
        ``extract_instruments(clip, [melody])``. This presumably holds the
        melody part across the whole clip; confirm against
        ``anticipation.tokenize``.
      * matching ``.png`` visualizations when ``args.visualize`` is set.

    ``<output>/index.csv`` gets one row per clip.

    Args:
        args: parsed command-line namespace with ``dir``, ``output``, ``seed``,
            ``count``, ``prompt_length``, ``clip_length`` and ``visualize``.

    Raises:
        SystemExit: if no ``.mid``/``.midi`` files are found under ``args.dir``.
            Nothing is written to ``args.output`` in that case.
    """
    np.random.seed(args.seed)

    print(f'Selecting clips for accompaniment from: {args.dir}')
    # sorted so the seeded sampling is reproducible across filesystems
    filenames = sorted(
        glob(os.path.join(args.dir, '**', '*.mid'), recursive=True)
        + glob(os.path.join(args.dir, '**', '*.midi'), recursive=True)
    )
    if not filenames:
        raise SystemExit(f'No .mid/.midi files found under: {args.dir}')

    print(f'Saving clips to: {args.output}')
    groundtruth_dir = os.path.join(args.output, 'groundtruth')
    # also creates args.output; exist_ok makes re-runs into the same directory safe
    os.makedirs(groundtruth_dir, exist_ok=True)

    # explicit utf-8: the 'original' column holds arbitrary (possibly non-ASCII) file names
    with open(os.path.join(args.output, 'index.csv'), 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'conditional', 'parts', 'melody'])
        for i in tqdm(range(args.count)):
            filename, clip, melody = select_sample(filenames, args.prompt_length, args.clip_length)
            parts = ops.get_instruments(clip).keys()
            writer.writerow([i, filename, f'{i}-conditional.mid', len(parts), melody])

            # ground truth: the full sampled clip
            clip_stem = os.path.join(groundtruth_dir, f'{i}-clip')
            events_to_midi(clip).save(f'{clip_stem}.mid')
            if args.visualize:
                visualize(clip, f'{clip_stem}.png')

            # conditional: accompaniment prompt plus the extracted melody controls
            events, controls = extract_instruments(clip, [melody])
            prompt = ops.clip(events, 0, args.prompt_length, clip_duration=False)
            conditional_events = ops.clip(ops.combine(prompt, controls), 0, args.clip_length)
            conditional_stem = os.path.join(args.output, f'{i}-conditional')
            events_to_midi(conditional_events).save(f'{conditional_stem}.mid')
            if args.visualize:
                visualize(conditional_events, f'{conditional_stem}.png')
def _build_arg_parser():
    """Return the command-line parser for this clip-selection script."""
    # description previously said "infilling"; this script selects
    # melody-conditioned accompaniment prompts
    parser = ArgumentParser(description='select prompts for accompaniment completion human eval')
    parser.add_argument('dir', help='directory containing MIDI files to sample')
    parser.add_argument('-o', '--output', type=str, default='output',
                        help='output directory')
    parser.add_argument('-s', '--seed', type=int, default=0,
                        help='random seed for sampling')
    parser.add_argument('-c', '--count', type=int, default=10,
                        help='number of clips to sample')
    parser.add_argument('-p', '--prompt_length', type=int, default=5,
                        help='length of the prompt (in seconds)')
    parser.add_argument('-l', '--clip_length', type=int, default=20,
                        help='length of the full clip (in seconds)')
    parser.add_argument('-v', '--visualize', action='store_true',
                        help='plot visualizations')
    return parser


if __name__ == '__main__':
    main(_build_arg_parser().parse_args())