File size: 3,560 Bytes
151b875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os, csv, time

from argparse import ArgumentParser

import numpy as np

import torch

from transformers import AutoModelForCausalLM

from anticipation import ops
from anticipation.visuals import visualize
from anticipation.convert import midi_to_interarrival, interarrival_to_midi
from anticipation.convert import midi_to_events, events_to_midi
from anticipation.vocab import MIDI_SEPARATOR,MIDI_START_OFFSET,MIDI_END_OFFSET


def _trim_to_last_onset(tokens):
    """Return *tokens* truncated just after its last note-onset token.

    Anything after the final onset (what the original script called "trailing
    offsets") is dropped so the model continues directly from a note onset.
    If no onset token is present, only the first token is kept (an empty
    input yields an empty list).

    NOTE(review): onsets are taken to be tokens in
    ``[MIDI_START_OFFSET, MIDI_END_OFFSET)``, i.e. ``MIDI_END_OFFSET`` is
    assumed to be an absolute token id (the first note-off token), not a
    width. The previous bound ``MIDI_START_OFFSET + MIDI_END_OFFSET`` treated
    it as a width. Confirm against ``anticipation.vocab``.
    """
    last = 0
    for i, token in enumerate(tokens):
        if MIDI_START_OFFSET <= token < MIDI_END_OFFSET:
            last = i
    return tokens[:last + 1]


def main(args):
    """Sample model continuations for every prompt listed in ``{args.dir}/index.csv``.

    For each index row:
      1. The MIDI file named in the ``prompt`` column (relative to ``args.dir``)
         is converted to interarrival tokens.
      2. The tokens are trimmed to end at the last note onset.
      3. The result is fed to the model ``args.multiplicity`` times.

    Each completion is clipped with ``ops.clip(events, 0, args.clip_length)``
    (seconds, per the CLI help). It is saved as
    ``{args.dir}/{args.output}/{idx}-clip-v{j}.mid``, plus a ``.png`` with the
    same stem when ``args.visualize`` is set.

    Requires a CUDA device: the model and inputs are moved with ``.cuda()``.

    NOTE(review): ``args.count`` is accepted by the CLI but not read here;
    every row of the index is processed.
    """
    # model.generate(do_sample=True) draws from torch's RNG, so seeding numpy
    # alone does not make the completions reproducible; seed both.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print(f'Prompting using model checkpoint: {args.model}')
    t0 = time.time()
    model = AutoModelForCausalLM.from_pretrained(args.model).cuda()
    print(f'Loaded model ({time.time()-t0} seconds)')

    out_dir = f'{args.dir}/{args.output}'
    print(f'Writing outputs to {out_dir}')
    os.makedirs(out_dir, exist_ok=True)

    index_path = f'{args.dir}/index.csv'
    print(f'Prompting with tracks in index : {index_path}')
    with open(index_path, newline='') as f:
        reader = csv.reader(f)
        header = next(reader)
        # Resolve column positions once instead of on every row.
        prompt_col = header.index('prompt')
        idx_col = header.index('idx')
        for row in reader:
            prompt_midi = row[prompt_col]
            idx = int(row[idx_col])

            tokens = midi_to_interarrival(os.path.join(args.dir, prompt_midi))
            prompt = _trim_to_last_onset(tokens)
            if not prompt:
                # torch.tensor([[]]) would be a float tensor and break generate().
                print(f'Skipping {prompt_midi}: prompt contains no tokens')
                continue

            # The prompt tensor is the same for every sample of this row.
            input_ids = torch.tensor([prompt]).cuda()
            for j in range(args.multiplicity):
                t0 = time.time()

                output = model.generate(input_ids, do_sample=True, max_length=1024,
                                        top_p=0.95, pad_token_id=MIDI_SEPARATOR)
                output = output[0].cpu().tolist()

                # ops.clip works on the event representation, so round-trip
                # interarrival tokens -> MIDI -> events -> MIDI.
                mid = interarrival_to_midi(output)
                events = ops.clip(midi_to_events(mid), 0, args.clip_length)
                mid = events_to_midi(events)

                stem = f'{out_dir}/{idx}-clip-v{j}'
                mid.save(f'{stem}.mid')
                if args.visualize:
                    visualize(events, f'{stem}.png')

                print(f'Generated completion. Sampling time: {time.time()-t0} seconds')


if __name__ == '__main__':
    # Command-line interface: two positional paths, then the optional
    # output/sampling flags, registered in a fixed order from a table.
    cli = ArgumentParser(description='generate prompted completions with an interarrival-time model')
    cli.add_argument('dir', help='directory containing an index of MIDI files')
    cli.add_argument('model', help='directory containing an interarrival model checkpoint')

    optional_flags = [
        (('-o', '--output'),
         dict(type=str, default='model',
              help='model description (the name of the output subdirectory)')),
        (('-s', '--seed'),
         dict(type=int, default=0, help='random seed')),
        # NOTE(review): --count is parsed, but main() never reads args.count.
        (('-c', '--count'),
         dict(type=int, default=10, help='number of clips to sample')),
        (('-m', '--multiplicity'),
         dict(type=int, default=1, help='number of generations per clip')),
        (('-l', '--clip_length'),
         dict(type=int, default=20, help='length of the full clip (in seconds)')),
        (('-v', '--visualize'),
         dict(action='store_true', help='plot visualizations')),
    ]
    for flags, options in optional_flags:
        cli.add_argument(*flags, **options)

    main(cli.parse_args())