|
|
import os, csv, time
|
|
|
|
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from transformers import AutoModelForCausalLM
|
|
|
|
|
|
from anticipation import ops
|
|
|
from anticipation.visuals import visualize
|
|
|
from anticipation.sample import generate, generate_ar
|
|
|
from anticipation.tokenize import extract_instruments
|
|
|
from anticipation.convert import midi_to_events, events_to_midi
|
|
|
from anticipation.config import TIME_RESOLUTION
|
|
|
|
|
|
# Fix NumPy's global RNG seed so the random start offsets drawn with
# np.random.rand for the retrieval baseline repeat from run to run.
np.random.seed(0)
|
|
|
|
|
|
def main(args):
    """Write accompaniment clips for every track listed in ``{args.dir}/index.csv``.

    Each index row names a conditional MIDI file (relative to ``args.dir``),
    a melody instrument id, a clip index and the original MIDI file (relative
    to ``args.midis``). For every row, ``args.multiplicity`` clips are written
    for each enabled mode:

    * ``args.anticipatory`` -> ``{args.dir}/anticipatory`` (``generate``)
    * ``args.baseline``     -> ``{args.dir}/autoregressive`` (``generate_ar``)
    * ``args.retrieve``     -> ``{args.dir}/retrieved`` (random window cut
      from the original MIDI file)

    Outputs are named ``{idx}-clip-v{j}.mid``; a ``.png`` with the same stem
    is written alongside when ``args.visualize`` is set.

    Args:
        args: parsed command-line namespace. Reads ``dir``, ``model``,
            ``midis``, ``multiplicity``, ``prompt_length``, ``clip_length``,
            ``anticipatory``, ``baseline``, ``retrieve`` and ``visualize``.

    Raises:
        ValueError: if the index header lacks one of the columns
            ``original``, ``conditional``, ``melody`` or ``idx``.
        StopIteration: if the index file is empty (no header row).
    """
    # The model is only needed by the two sampling modes; retrieval alone
    # runs without loading a checkpoint.
    if args.anticipatory or args.baseline:
        print(f'Accompaniment using model checkpoint: {args.model}')
        t0 = time.time()
        model = AutoModelForCausalLM.from_pretrained(args.model).cuda()
        print(f'Loaded model ({time.time()-t0} seconds)')

    # One output sub-directory per enabled mode. exist_ok=True makes reruns
    # into an existing directory a no-op.
    output_modes = (
        (args.anticipatory, 'anticipatory'),
        (args.baseline, 'autoregressive'),
        (args.retrieve, 'retrieved'),
    )
    for enabled, subdir in output_modes:
        if enabled:
            print(f'Writing outputs to {args.dir}/{subdir}')
            os.makedirs(f'{args.dir}/{subdir}', exist_ok=True)

    print(f'Accompanying tracks in index : {args.dir}/index.csv')
    with open(f'{args.dir}/index.csv', newline='') as f:
        reader = csv.reader(f)
        header = next(reader)

        # Resolve column positions once instead of re-scanning the header
        # for every row.
        original_col = header.index('original')
        conditional_col = header.index('conditional')
        melody_col = header.index('melody')
        idx_col = header.index('idx')

        for row in reader:
            original = os.path.join(args.midis, row[original_col])
            conditional_midi = row[conditional_col]
            melody = int(row[melody_col])
            idx = int(row[idx_col])

            events = midi_to_events(os.path.join(args.dir, conditional_midi))

            # Split the clip around the melody instrument.
            # NOTE(review): presumably `controls` holds the melody track and
            # `events` the remaining instruments -- confirm against
            # anticipation.tokenize.extract_instruments.
            events, controls = extract_instruments(events, [melody])
            prompt = ops.clip(events, 0, args.prompt_length, clip_duration=False)

            for j in range(args.multiplicity):
                t0 = time.time()

                if args.anticipatory:
                    generated_tokens = generate(model, args.prompt_length, args.clip_length, prompt, controls, top_p=0.95)
                    output = ops.clip(ops.combine(generated_tokens, controls), 0, args.clip_length)
                    mid = events_to_midi(output)
                    mid.save(f'{args.dir}/anticipatory/{idx}-clip-v{j}.mid')
                    if args.visualize:
                        visualize(output, f'{args.dir}/anticipatory/{idx}-clip-v{j}.png')

                if args.baseline:
                    generated_tokens = generate_ar(model, args.prompt_length, args.clip_length, prompt, controls, top_p=0.95)
                    output = ops.clip(generated_tokens, 0, args.clip_length)
                    print(len(generated_tokens), len(output))
                    mid = events_to_midi(output)
                    mid.save(f'{args.dir}/autoregressive/{idx}-clip-v{j}.mid')
                    if args.visualize:
                        visualize(output, f'{args.dir}/autoregressive/{idx}-clip-v{j}.png')

                if args.retrieve:
                    original_events = midi_to_events(original)

                    # Latest admissible window start. Clamp at zero so a
                    # reference track shorter than the clip does not yield a
                    # negative start time (which would shift the retrieved
                    # events forward in the translate step below).
                    max_time = max(ops.max_time(original_events) - args.clip_length, 0)
                    start_time = max_time*np.random.rand(1)[0]

                    retrieved = ops.clip(original_events, start_time, start_time+args.clip_length, clip_duration=True)
                    # Shift the window back so that it starts at time zero.
                    retrieved = ops.translate(retrieved, -int(TIME_RESOLUTION*start_time))

                    # Kept under its own name so the row-level `events` is
                    # not overwritten.
                    retrieved_events, _ = extract_instruments(retrieved, [melody])

                    # Keep the real prompt, then continue with retrieved material.
                    generated = prompt + ops.clip(retrieved_events, args.prompt_length, args.clip_length)
                    output = ops.combine(generated, controls)
                    mid = events_to_midi(output)
                    mid.save(f'{args.dir}/retrieved/{idx}-clip-v{j}.mid')
                    if args.visualize:
                        visualize(output, f'{args.dir}/retrieved/{idx}-clip-v{j}.png')

                print(f'Accompanied with instrument {melody}. Sampling time: {time.time()-t0} seconds')
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: declare the command-line interface, then pass the
    # parsed namespace straight to main().
    cli = ArgumentParser(description='generate infilling completions')

    # Sole positional argument.
    cli.add_argument('dir', help='directory containing an index of MIDI files')

    # Optional arguments as (flags, keyword settings) pairs, registered in
    # this order so the generated --help listing keeps its layout.
    # NOTE(review): --count is accepted here, but main() as visible in this
    # file never reads args.count.
    option_table = [
        (['--model'],
         {'type': str, 'default': '',
          'help': 'directory containing an anticipatory model checkpoint'}),
        (['-c', '--count'],
         {'type': int, 'default': 10,
          'help': 'number of clips to sample'}),
        (['-m', '--multiplicity'],
         {'type': int, 'default': 1,
          'help': 'number of generations per clip'}),
        (['-p', '--prompt_length'],
         {'type': int, 'default': 5,
          'help': 'length of the prompt (in seconds)'}),
        (['-l', '--clip_length'],
         {'type': int, 'default': 20,
          'help': 'length of the full clip (in seconds)'}),
        (['-a', '--anticipatory'],
         {'action': 'store_true',
          'help': 'generate anticipatory results'}),
        (['-b', '--baseline'],
         {'action': 'store_true',
          'help': 'generate autoregressive (baseline) results'}),
        (['-r', '--retrieve'],
         {'action': 'store_true',
          'help': 'generate the retrieval baseline'}),
        (['-d', '--midis'],
         {'type': str, 'default': '',
          'help': 'directory containing the reference MIDI files (for retrieval)'}),
        (['-v', '--visualize'],
         {'action': 'store_true',
          'help': 'plot visualizations'}),
    ]
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    main(cli.parse_args())
|
|
|
|