# Source path: beatalignment/humaneval/accompany.py
# Hugging Face page residue: "william590y's picture" /
# "Upload folder using huggingface_hub" / commit 151b875 (verified)
import os, csv, time
from argparse import ArgumentParser
import numpy as np
from transformers import AutoModelForCausalLM
from anticipation import ops
from anticipation.visuals import visualize
from anticipation.sample import generate, generate_ar
from anticipation.tokenize import extract_instruments
from anticipation.convert import midi_to_events, events_to_midi
from anticipation.config import TIME_RESOLUTION
# Seed NumPy's global RNG at import time so that the np.random.rand draw
# used below to pick a retrieval start offset is repeatable across runs.
np.random.seed(0)
def main(args):
    """Produce accompaniment clips for every track listed in an index file.

    Reads ``{args.dir}/index.csv`` (first row is a header that must contain
    the columns ``original``, ``conditional``, ``melody`` and ``idx``) and,
    for each data row, writes ``args.multiplicity`` clips per enabled mode:

    * ``args.anticipatory`` -- calls ``generate`` and writes to
      ``{args.dir}/anticipatory``.
    * ``args.baseline`` -- calls ``generate_ar`` and writes to
      ``{args.dir}/autoregressive``.
    * ``args.retrieve`` -- clips a randomly positioned window out of the
      reference MIDI found under ``args.midis`` and writes to
      ``{args.dir}/retrieved``.

    Each clip is saved as ``{idx}-clip-v{j}.mid``; a ``.png`` with the same
    stem is written as well when ``args.visualize`` is set.

    Args:
        args: parsed command-line namespace. Attributes read here: ``dir``,
            ``model``, ``midis``, ``multiplicity``, ``prompt_length``,
            ``clip_length``, ``anticipatory``, ``baseline``, ``retrieve``
            and ``visualize``.

    Raises:
        StopIteration: if the index file is empty (no header row).
        ValueError: if a required column is missing from the header, or a
            ``melody``/``idx`` cell is not an integer.
        FileExistsError: if an output path exists but is not a directory.
    """
    model = None
    if args.anticipatory or args.baseline:
        # Both sampling modes share one checkpoint; load it once up front.
        print(f'Accompaniment using model checkpoint: {args.model}')
        t0 = time.time()
        model = AutoModelForCausalLM.from_pretrained(args.model).cuda()
        print(f'Loaded model ({time.time()-t0} seconds)')

    enabled_outputs = (
        ('anticipatory', args.anticipatory),
        ('autoregressive', args.baseline),
        ('retrieved', args.retrieve),
    )
    for subdir, enabled in enabled_outputs:
        if enabled:
            _prepare_output_dir(f'{args.dir}/{subdir}')

    print(f'Accompanying tracks in index : {args.dir}/index.csv')
    with open(f'{args.dir}/index.csv', newline='') as f:
        reader = csv.reader(f)
        header = next(reader)

        # Resolve column positions once instead of once per row; this also
        # surfaces a malformed header before any generation work starts.
        col_original = header.index('original')
        col_conditional = header.index('conditional')
        col_melody = header.index('melody')
        col_idx = header.index('idx')

        for row in reader:
            original = os.path.join(args.midis, row[col_original])
            conditional_midi = row[col_conditional]
            melody = int(row[col_melody])
            idx = int(row[col_idx])

            # Split the conditional MIDI on the melody instrument: the
            # second return value is used as the controls for generation,
            # and the first is clipped to form the prompt.
            events = midi_to_events(os.path.join(args.dir, conditional_midi))
            events, controls = extract_instruments(events, [melody])
            prompt = ops.clip(events, 0, args.prompt_length, clip_duration=False)

            for j in range(args.multiplicity):
                t0 = time.time()
                clip_name = f'{idx}-clip-v{j}'

                if args.anticipatory:
                    generated_tokens = generate(model, args.prompt_length, args.clip_length, prompt, controls, top_p=0.95)
                    output = ops.clip(ops.combine(generated_tokens, controls), 0, args.clip_length)
                    _write_clip(output, f'{args.dir}/anticipatory/{clip_name}', args.visualize)

                if args.baseline:
                    generated_tokens = generate_ar(model, args.prompt_length, args.clip_length, prompt, controls, top_p=0.95)
                    output = ops.clip(generated_tokens, 0, args.clip_length)
                    print(len(generated_tokens), len(output))
                    _write_clip(output, f'{args.dir}/autoregressive/{clip_name}', args.visualize)

                if args.retrieve:
                    original_events = midi_to_events(original)
                    # NOTE(review): if the reference is shorter than
                    # clip_length, max_time is negative and so is start_time;
                    # confirm how ops.clip/ops.translate should treat that.
                    max_time = ops.max_time(original_events) - args.clip_length
                    start_time = max_time*np.random.rand(1)[0] # get a different random clip
                    retrieved = ops.clip(original_events, start_time, start_time+args.clip_length, clip_duration=True)
                    retrieved = ops.translate(retrieved, -int(TIME_RESOLUTION*start_time))
                    # Use a distinct name so the per-row `events` is not
                    # clobbered between samples.
                    retrieved_events, _ = extract_instruments(retrieved, [melody])
                    generated = prompt + ops.clip(retrieved_events, args.prompt_length, args.clip_length)
                    output = ops.combine(generated, controls)
                    _write_clip(output, f'{args.dir}/retrieved/{clip_name}', args.visualize)

                print(f'Accompanied with instrument {melody}. Sampling time: {time.time()-t0} seconds')


def _prepare_output_dir(path):
    """Announce ``path`` as an output location and create it if missing."""
    print(f'Writing outputs to {path}')
    # exist_ok tolerates an existing directory but still raises if the path
    # is occupied by something that is not a directory.
    os.makedirs(path, exist_ok=True)


def _write_clip(output, stem, plot):
    """Save ``output`` as ``{stem}.mid`` and, if ``plot``, as ``{stem}.png``."""
    mid = events_to_midi(output)
    mid.save(f'{stem}.mid')
    if plot:
        visualize(output, f'{stem}.png')
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, and hand
    # the resulting namespace to main().
    cli = ArgumentParser(description='generate infilling completions')
    cli.add_argument('dir', help='directory containing an index of MIDI files')

    # (flags, keyword options) pairs, registered in this order so the
    # generated --help listing keeps the same layout.
    option_specs = [
        (('--model',),
         dict(type=str, default='',
              help='directory containing an anticipatory model checkpoint')),
        (('-c', '--count'),
         dict(type=int, default=10,
              help='number of clips to sample')),
        (('-m', '--multiplicity'),
         dict(type=int, default=1,
              help='number of generations per clip')),
        (('-p', '--prompt_length'),
         dict(type=int, default=5,
              help='length of the prompt (in seconds)')),
        (('-l', '--clip_length'),
         dict(type=int, default=20,
              help='length of the full clip (in seconds)')),
        (('-a', '--anticipatory'),
         dict(action='store_true',
              help='generate anticipatory results')),
        (('-b', '--baseline'),
         dict(action='store_true',
              help='generate autoregressive (baseline) results')),
        (('-r', '--retrieve'),
         dict(action='store_true',
              help='generate the retrieval baseline')),
        (('-d', '--midis'),
         dict(type=str, default='',
              help='directory containing the reference MIDI files (for retrieval)')),
        (('-v', '--visualize'),
         dict(action='store_true',
              help='plot visualizations')),
    ]
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    main(cli.parse_args())