File size: 5,792 Bytes
151b875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os, csv, time

from argparse import ArgumentParser

import numpy as np

from transformers import AutoModelForCausalLM

from anticipation import ops
from anticipation.visuals import visualize
from anticipation.sample import generate, generate_ar
from anticipation.tokenize import extract_instruments
from anticipation.convert import midi_to_events, events_to_midi
from anticipation.config import TIME_RESOLUTION

# Seed NumPy's global RNG so the random clip offsets drawn with
# np.random.rand in main() (retrieval baseline) are the same on every run.
np.random.seed(0)

def main(args):
    """Write accompaniment MIDI clips for every row of ``{args.dir}/index.csv``.

    The index must have a header row containing the columns ``original``,
    ``conditional``, ``melody`` and ``idx``.  For each row the conditional
    MIDI (relative to ``args.dir``) is converted to events, split with
    ``extract_instruments(events, [melody])`` into ``events`` and
    ``controls``, and the first ``args.prompt_length`` seconds of ``events``
    become the prompt.  Then, ``args.multiplicity`` times per row, each
    enabled mode writes ``{idx}-clip-v{j}.mid`` (plus a ``.png`` when
    ``args.visualize`` is set) into its own sub-directory of ``args.dir``:

    * ``args.anticipatory`` -> ``anticipatory/`` via ``generate``
    * ``args.baseline``     -> ``autoregressive/`` via ``generate_ar``
    * ``args.retrieve``     -> ``retrieved/``: a random window of the
      reference MIDI (``args.midis`` joined with the ``original`` column)

    Args:
        args: parsed command-line namespace.  Attributes read here: ``dir``,
            ``model``, ``midis``, ``prompt_length``, ``clip_length``,
            ``multiplicity``, ``anticipatory``, ``baseline``, ``retrieve``
            and ``visualize``.

    Raises:
        ValueError: if a required column is missing from the index header,
            or ``melody``/``idx`` is not an integer.
    """
    # The model is only needed by the two sampling modes; retrieval alone
    # runs without loading a checkpoint.
    if args.anticipatory or args.baseline:
        print(f'Accompaniment using model checkpoint: {args.model}')
        t0 = time.time()
        model = AutoModelForCausalLM.from_pretrained(args.model).cuda()
        print(f'Loaded model ({time.time()-t0} seconds)')

    # One output sub-directory per enabled mode.  exist_ok=True tolerates a
    # directory left over from a previous run.
    output_modes = (
        (args.anticipatory, 'anticipatory'),
        (args.baseline, 'autoregressive'),
        (args.retrieve, 'retrieved'),
    )
    for enabled, subdir in output_modes:
        if enabled:
            print(f'Writing outputs to {args.dir}/{subdir}')
            os.makedirs(f'{args.dir}/{subdir}', exist_ok=True)

    print(f'Accompanying tracks in index : {args.dir}/index.csv')
    with open(f'{args.dir}/index.csv', newline='') as f:
        reader = csv.reader(f)
        header = next(reader)

        # Resolve the column positions once instead of once per row.
        original_col = header.index('original')
        conditional_col = header.index('conditional')
        melody_col = header.index('melody')
        idx_col = header.index('idx')

        for row in reader:
            original = os.path.join(args.midis, row[original_col])
            conditional_midi = row[conditional_col]
            melody = int(row[melody_col])
            idx = int(row[idx_col])

            events = midi_to_events(os.path.join(args.dir, conditional_midi))

            events, controls = extract_instruments(events, [melody])
            prompt = ops.clip(events, 0, args.prompt_length, clip_duration=False)

            if args.retrieve:
                # The reference piece does not change between samples, so
                # parse it once per row rather than once per sample.
                original_events = midi_to_events(original)
                # Latest admissible window start (seconds).  Clamp at 0 so a
                # piece shorter than clip_length yields start 0 instead of a
                # negative start time.
                max_start = max(ops.max_time(original_events) - args.clip_length, 0)

            for j in range(args.multiplicity):
                t0 = time.time()

                if args.anticipatory:
                    generated_tokens = generate(model, args.prompt_length, args.clip_length, prompt, controls, top_p=0.95)
                    output = ops.clip(ops.combine(generated_tokens, controls), 0, args.clip_length)
                    mid = events_to_midi(output)
                    mid.save(f'{args.dir}/anticipatory/{idx}-clip-v{j}.mid')
                    if args.visualize:
                        visualize(output, f'{args.dir}/anticipatory/{idx}-clip-v{j}.png')

                if args.baseline:
                    generated_tokens = generate_ar(model, args.prompt_length, args.clip_length, prompt, controls, top_p=0.95)
                    output = ops.clip(generated_tokens, 0, args.clip_length)
                    print(len(generated_tokens), len(output))
                    mid = events_to_midi(output)
                    mid.save(f'{args.dir}/autoregressive/{idx}-clip-v{j}.mid')
                    if args.visualize:
                        visualize(output, f'{args.dir}/autoregressive/{idx}-clip-v{j}.png')

                if args.retrieve:
                    start_time = max_start*np.random.rand(1)[0] # get a different random clip
                    retrieved = ops.clip(original_events, start_time, start_time+args.clip_length, clip_duration=True)
                    # Shift the window so that it starts at time zero.
                    retrieved = ops.translate(retrieved, -int(TIME_RESOLUTION*start_time))
                    # Separate name so the prompt's source `events` is not clobbered.
                    retrieved_events, _ = extract_instruments(retrieved, [melody])
                    # Keep the real prompt and take only the continuation
                    # (prompt_length..clip_length) from the retrieved window.
                    generated = prompt + ops.clip(retrieved_events, args.prompt_length, args.clip_length)
                    output = ops.combine(generated, controls)
                    mid = events_to_midi(output)
                    mid.save(f'{args.dir}/retrieved/{idx}-clip-v{j}.mid')
                    if args.visualize:
                        visualize(output, f'{args.dir}/retrieved/{idx}-clip-v{j}.png')


                print(f'Accompanied with instrument {melody}. Sampling time: {time.time()-t0} seconds')


if __name__ == '__main__':
    # Command-line entry point: build the argument parser, validate the
    # combination of options, and hand the parsed namespace to main().
    parser = ArgumentParser(description='generate infilling completions')
    parser.add_argument('dir', help='directory containing an index of MIDI files')
    parser.add_argument('--model', type=str, default='',
            help='directory containing an anticipatory model checkpoint')
    # NOTE(review): --count is accepted but main() in this file never reads
    # args.count; every row of the index is processed.
    parser.add_argument('-c', '--count', type=int, default=10,
            help='number of clips to sample')
    parser.add_argument('-m', '--multiplicity', type=int, default=1,
            help='number of generations per clip')
    parser.add_argument('-p', '--prompt_length', type=int, default=5,
            help='length of the prompt (in seconds)')
    parser.add_argument('-l', '--clip_length', type=int, default=20,
            help='length of the full clip (in seconds)')
    parser.add_argument('-a', '--anticipatory', action='store_true',
            help='generate anticipatory results')
    parser.add_argument('-b', '--baseline', action='store_true',
            help='generate autoregressive (baseline) results')
    parser.add_argument('-r', '--retrieve', action='store_true',
            help='generate the retrieval baseline')
    parser.add_argument('-d', '--midis', type=str, default='',
            help='directory containing the reference MIDI files (for retrieval)')
    parser.add_argument('-v', '--visualize', action='store_true',
            help='plot visualizations')
    args = parser.parse_args()

    # Both sampling modes load a checkpoint from --model, whose default is the
    # empty string.  Reject that here with a usage message rather than letting
    # the model loader fail on an empty path.
    if (args.anticipatory or args.baseline) and not args.model:
        parser.error('--model is required with -a/--anticipatory or -b/--baseline')

    main(args)