File size: 5,900 Bytes
151b875 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os,csv
from argparse import ArgumentParser
from glob import glob
import numpy as np
from tqdm import tqdm
from anticipation import ops
from anticipation.visuals import visualize
from anticipation.tokenize import extract_instruments
from anticipation.convert import midi_to_events, events_to_midi
from anticipation.config import TIME_RESOLUTION
from anticipation.vocab import TIME_OFFSET, NOTE_OFFSET
def select_sample(filenames, prompt_length, clip_length, verbose=False):
    """Sample (with replacement) a random clip suitable for melody-accompaniment eval.

    Repeatedly draws a random file from `filenames`, cuts a random
    `clip_length`-second window out of it, and accepts the clip only if it
    passes a series of heuristics: a healthy ensemble size, and a usable
    melody line that starts within the prompt and lasts to the end.

    Args:
        filenames: list of MIDI file paths to sample from.
        prompt_length: length of the conditioning prompt, in seconds.
        clip_length: length of the full clip, in seconds.
        verbose: if True, print the reason each candidate clip is rejected.

    Returns:
        (basename, clip, melody): the source file's basename, the clip's
        event tokens, and the instrument code chosen as the melody part.
    """
    while True:
        # sampling with replacement
        idx = np.random.randint(len(filenames))
        if verbose:
            print('Loading index: ', idx)

        try:
            events = midi_to_events(filenames[idx])
        except Exception:
            # unparseable MIDI file: draw again
            continue

        max_time = ops.max_time(events) - clip_length
        # don't sample tracks with length shorter than clip_length
        if max_time < 0:
            if verbose:
                print(f' rejected: track is too short (length {ops.max_time(events)} < {clip_length})')
            continue

        # uniform-random clip start, shifted back to time zero
        start_time = max_time*np.random.rand()
        clip = ops.clip(events, start_time, start_time+clip_length, clip_duration=True)
        clip = ops.translate(clip, -int(TIME_RESOLUTION*start_time))

        # find an ensemble with a healthy (non-drum / effect) instrument collection
        # (128 is the drum code — TODO confirm against anticipation.config)
        instruments = [instr for instr in ops.get_instruments(clip).keys() if instr != 128]
        if len(instruments) < 4 or len(instruments) > 10:
            if verbose:
                print(f' rejected: track instrument count out of bounds: {len(instruments)}')
            continue

        # define melody as the instrument part with the highest average
        # (non-drum, non-piano) pitch; collect per-instrument pitch lists
        # by decoding the note tokens (note = 128*instrument + pitch)
        pitches = {instr: [] for instr in ops.get_instruments(clip).keys()}
        for note in clip[2::3]:
            note -= NOTE_OFFSET
            instr = note // 2**7
            pitches[instr].append(note - (2**7)*instr)

        melody = None
        high = 0
        for instr in ops.get_instruments(clip).keys():
            # skip programs 0 and 9 and the percussive/effect range (112+)
            if instr in [0, 9] + list(range(112, 129)):
                continue
            avg = np.mean(pitches[instr])
            if avg > high:
                melody = instr
                high = avg

        # no eligible melody instrument in this ensemble: reject and resample
        # (previously an assert, which aborted the entire run on such clips)
        if melody is None:
            if verbose:
                print(' rejected: no eligible melody instrument')
            continue

        # get clips with at least 20 notes of melody
        if ops.get_instruments(clip)[melody] < 20:
            if verbose:
                print(' rejected: too few melodic notes')
            continue

        # prompt should contain the melody line
        if ops.min_time(clip, seconds=True, instr=melody) > prompt_length:
            if verbose:
                print(' rejected: prompt does not contain the melody')
            continue

        # melody shouldn't end early
        if ops.max_time(clip, seconds=True, instr=melody) < (clip_length-2):
            if verbose:
                print(' rejected: melody ends before the end of the clip')
            continue

        break  # found one

    return os.path.basename(filenames[idx]), clip, melody
def main(args):
    """Select random clips and write ground-truth + conditional MIDI for human eval.

    For each of `args.count` iterations: sample a clip via `select_sample`,
    save the full clip under `<output>/groundtruth/`, and save a conditional
    version (the melody line plus the first `prompt_length` seconds of the
    accompaniment) under `<output>/`, recording metadata in `index.csv`.

    Args:
        args: parsed command-line namespace (dir, output, seed, count,
            prompt_length, clip_length, visualize).
    """
    np.random.seed(args.seed)

    print(f'Selecting clips for accompaniment from: {args.dir}')
    filenames = glob(args.dir + '/**/*.mid', recursive=True) \
            + glob(args.dir + '/**/*.midi', recursive=True)
    # sort for a deterministic ordering, so the seeded sampler is reproducible
    filenames = sorted(filenames)

    print(f'Saving clips to: {args.output}')
    os.makedirs(args.output, exist_ok=True)
    os.makedirs(f'{args.output}/groundtruth', exist_ok=True)

    with open(f'{args.output}/index.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'conditional', 'parts', 'melody'])
        for i in tqdm(range(args.count)):
            filename, clip, melody = select_sample(filenames, args.prompt_length, args.clip_length)
            parts = ops.get_instruments(clip).keys()
            writer.writerow([i, filename, f'{i}-conditional.mid', len(parts), melody])

            # ground-truth: the full sampled clip
            mid = events_to_midi(clip)
            mid.save(f'{args.output}/groundtruth/{i}-clip.mid')
            if args.visualize:
                visualize(clip, f'{args.output}/groundtruth/{i}-clip.png')

            # conditional input: full melody line (as controls) combined with
            # only the first prompt_length seconds of the other parts
            events, controls = extract_instruments(clip, [melody])
            prompt = ops.clip(events, 0, args.prompt_length, clip_duration=False)
            conditional_events = ops.clip(ops.combine(prompt, controls), 0, args.clip_length)
            mid = events_to_midi(conditional_events)
            mid.save(f'{args.output}/{i}-conditional.mid')
            if args.visualize:
                visualize(conditional_events, f'{args.output}/{i}-conditional.png')
if __name__ == '__main__':
    # Command-line entry point: build the argument parser and hand off to main().
    arg_parser = ArgumentParser(description='select prompts for infilling completion human eval')
    arg_parser.add_argument('dir', help='directory containing MIDI files to sample')
    arg_parser.add_argument('-o', '--output', type=str, default='output', help='output directory')
    arg_parser.add_argument('-s', '--seed', type=int, default=0, help='random seed for sampling')
    arg_parser.add_argument('-c', '--count', type=int, default=10, help='number of clips to sample')
    arg_parser.add_argument('-p', '--prompt_length', type=int, default=5, help='length of the prompt (in seconds)')
    arg_parser.add_argument('-l', '--clip_length', type=int, default=20, help='length of the full clip (in seconds)')
    arg_parser.add_argument('-v', '--visualize', action='store_true', help='plot visualizations')
    main(arg_parser.parse_args())
|