# beatalignment / humaneval / prompt-interarrival.py
# Uploaded by william590y using huggingface_hub (commit 151b875, verified).
import os, csv, time
from argparse import ArgumentParser
import numpy as np
import torch
from transformers import AutoModelForCausalLM
from anticipation import ops
from anticipation.visuals import visualize
from anticipation.convert import midi_to_interarrival, interarrival_to_midi
from anticipation.convert import midi_to_events, events_to_midi
from anticipation.vocab import MIDI_SEPARATOR,MIDI_START_OFFSET,MIDI_END_OFFSET
def main(args):
    """Sample model continuations for every prompt listed in ``{args.dir}/index.csv``.

    For each row of the index, the MIDI file named in the ``prompt`` column is
    tokenized with ``midi_to_interarrival`` and truncated just after its last
    note-onset token. It is then used to prompt the interarrival-time model
    ``args.multiplicity`` times. Each sampled token sequence is decoded to MIDI,
    converted to events, passed through ``ops.clip(events, 0, args.clip_length)``
    and saved as ``{args.dir}/{args.output}/{idx}-clip-v{j}.mid``. When
    ``args.visualize`` is set, a ``.png`` with the same stem is also written.

    Args:
        args: parsed command-line namespace. Reads ``seed``, ``model``, ``dir``,
            ``output``, ``multiplicity``, ``clip_length`` and ``visualize``.
    """
    # model.generate(do_sample=True) draws from torch's RNG, so seeding numpy
    # alone does not make sampling reproducible; seed both.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print(f'Prompting using model checkpoint: {args.model}')
    t0 = time.time()
    model = AutoModelForCausalLM.from_pretrained(args.model).cuda()
    print(f'Loaded model ({time.time()-t0} seconds)')

    out_dir = f'{args.dir}/{args.output}'
    print(f'Writing outputs to {out_dir}')
    os.makedirs(out_dir, exist_ok=True)

    print(f'Prompting with tracks in index : {args.dir}/index.csv')
    with open(f'{args.dir}/index.csv', newline='') as f:
        reader = csv.reader(f)
        header = next(reader)
        # column positions are fixed by the header; look them up once
        prompt_col = header.index('prompt')
        idx_col = header.index('idx')
        for row in reader:
            prompt_midi = row[prompt_col]
            idx = int(row[idx_col])
            prompt = midi_to_interarrival(os.path.join(args.dir, prompt_midi))

            # Search for the last note onset and strip the trailing offsets after it.
            # Onset tokens are assumed to occupy [MIDI_START_OFFSET, MIDI_END_OFFSET),
            # with note offsets starting at MIDI_END_OFFSET (confirm against
            # anticipation.vocab). The previous upper bound,
            # MIDI_START_OFFSET + MIDI_END_OFFSET, also matched offset tokens.
            max_idx = 0
            for i, token in enumerate(prompt):
                if MIDI_START_OFFSET <= token < MIDI_END_OFFSET:
                    max_idx = i
            prompt = prompt[:max_idx+1]

            for j in range(args.multiplicity):
                t0 = time.time()
                input_ids = torch.tensor([prompt]).cuda()
                output = model.generate(input_ids, do_sample=True, max_length=1024, top_p=0.95, pad_token_id=MIDI_SEPARATOR)
                output = output[0].cpu().tolist()

                # most convenient way to operate on this stuff is to round-trip through events
                mid = interarrival_to_midi(output)
                events = midi_to_events(mid)
                clipped = ops.clip(events, 0, args.clip_length)
                mid = events_to_midi(clipped)
                mid.save(f'{out_dir}/{idx}-clip-v{j}.mid')
                if args.visualize:
                    visualize(clipped, f'{out_dir}/{idx}-clip-v{j}.png')

                print(f'Generated completion. Sampling time: {time.time()-t0} seconds')
if __name__ == '__main__':
    # Command-line entry point: two positional paths followed by optional knobs.
    # The optional flags are registered in the original order so `--help` output
    # and parsed attribute names are unchanged.
    cli = ArgumentParser(description='generate prompted completions with an interarrival-time model')
    cli.add_argument('dir', help='directory containing an index of MIDI files')
    cli.add_argument('model', help='directory containing an interarrival model checkpoint')

    optional_flags = [
        (('-o', '--output'),
         dict(type=str, default='model', help='model description (the name of the output subdirectory)')),
        (('-s', '--seed'),
         dict(type=int, default=0, help='random seed')),
        (('-c', '--count'),
         dict(type=int, default=10, help='number of clips to sample')),
        (('-m', '--multiplicity'),
         dict(type=int, default=1, help='number of generations per clip')),
        (('-l', '--clip_length'),
         dict(type=int, default=20, help='length of the full clip (in seconds)')),
        (('-v', '--visualize'),
         dict(action='store_true', help='plot visualizations')),
    ]
    for flags, options in optional_flags:
        cli.add_argument(*flags, **options)

    # NOTE(review): --count is parsed, but main() as written never reads args.count;
    # every row of index.csv is processed regardless.
    main(cli.parse_args())