File size: 5,900 Bytes
151b875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os,csv

from argparse import ArgumentParser
from glob import glob

import numpy as np

from tqdm import tqdm

from anticipation import ops
from anticipation.visuals import visualize
from anticipation.tokenize import extract_instruments
from anticipation.convert import midi_to_events, events_to_midi
from anticipation.config import TIME_RESOLUTION
from anticipation.vocab import TIME_OFFSET, NOTE_OFFSET

def select_sample(filenames, prompt_length, clip_length, verbose=False):
    """Randomly pick a MIDI file and a clip from it suitable for an accompaniment prompt.

    Files are drawn uniformly at random *with replacement* from ``filenames``
    until one yields a clip that passes every filter below. A candidate is
    rejected when:

    * the file cannot be converted by ``midi_to_events``;
    * the track is shorter than ``clip_length``;
    * the clip has fewer than 4 or more than 10 instruments, not counting
      instrument 128 (drums);
    * no instrument in the clip is eligible to be the melody (see below);
    * the melody has fewer than 20 notes in the clip;
    * the melody first appears after ``prompt_length`` seconds;
    * the melody ends more than 2 seconds before ``clip_length``.

    The melody is the instrument with the highest mean pitch in the clip,
    skipping instruments 0, 9 and 112-128.

    Args:
        filenames: sequence of candidate MIDI file paths.
        prompt_length: length of the prompt, in seconds.
        clip_length: length of the clip, in seconds.
        verbose: print the sampled index and the reason for each rejection.

    Returns:
        ``(basename, clip, melody)``: the basename of the selected file, the
        clip's events translated so the clip starts at time 0, and the melody
        instrument.

    Note:
        Loops forever if no file satisfies the filters, and numpy raises
        ``ValueError`` if ``filenames`` is empty.
    """
    # instruments never chosen as the melody
    excluded_from_melody = {0, 9, *range(112, 129)}

    while True:
        # sampling with replacement
        idx = np.random.randint(len(filenames))
        if verbose:
            print('Loading index: ', idx)

        try:
            events = midi_to_events(filenames[idx])
        except Exception as e:
            # best-effort: skip unreadable/malformed files, but report them when verbose
            if verbose:
                print(f'  rejected: failed to load MIDI ({e})')
            continue

        track_length = ops.max_time(events)
        max_time = track_length - clip_length

        # don't sample tracks with length shorter than clip_length
        if max_time < 0:
            if verbose:
                print(f'  rejected: track is too short (length {track_length} < {clip_length})')
            continue

        start_time = max_time*np.random.rand(1)[0]
        clip = ops.clip(events, start_time, start_time+clip_length, clip_duration=True)
        # shift the clip so it starts at time 0
        clip = ops.translate(clip, -int(TIME_RESOLUTION*start_time))

        # per-instrument summary of the clip, computed once and reused below
        clip_instruments = ops.get_instruments(clip)

        # find an ensemble with a healthy (non-drum / effect) instrument collection
        instruments = [instr for instr in clip_instruments if instr != 128]
        if len(instruments) < 4 or len(instruments) > 10:
            if verbose:
                print(f'  rejected: track instrument count out of bounds: {len(instruments)}')
            continue

        # collect the pitches played by each instrument; events are
        # (time, duration, note) triples and the note token packs
        # instrument*128 + pitch after removing NOTE_OFFSET
        pitches = {instr: [] for instr in clip_instruments}
        for note in clip[2::3]:
            instr, pitch = divmod(note - NOTE_OFFSET, 2**7)
            pitches[instr].append(pitch)

        # define melody as the eligible instrument part with the highest mean pitch
        melody = None
        high = 0
        for instr in clip_instruments:
            if instr in excluded_from_melody:
                continue

            avg = np.mean(pitches[instr])
            if avg > high:
                melody = instr
                high = avg

        # no eligible melodic instrument: reject the candidate instead of crashing
        if melody is None:
            if verbose:
                print('  rejected: no eligible melody instrument')
            continue

        # get clips with at least 20 notes of melody
        if clip_instruments[melody] < 20:
            if verbose:
                print('  rejected: too few melodic notes')
            continue

        # prompt should contain the melody line
        if ops.min_time(clip, seconds=True, instr=melody) > prompt_length:
            if verbose:
                print('  rejected: prompt does not contain the melody')
            continue

        # melody shouldn't end early
        if ops.max_time(clip, seconds=True, instr=melody) < (clip_length-2):
            if verbose:
                print('  rejected: melody ends before the end of the clip')
            continue

        # found one
        return os.path.basename(filenames[idx]), clip, melody


def main(args):
    """Sample ``args.count`` accompaniment clips from ``args.dir`` into ``args.output``.

    For each sample ``i`` this writes:

    * ``groundtruth/{i}-clip.mid``: the full selected clip, plus a ``.png``
      plot when ``args.visualize`` is set;
    * ``{i}-conditional.mid``: the conditional input. ``extract_instruments``
      splits the clip on the melody instrument into events and controls
      (presumably the melody becomes the controls; confirm against
      ``anticipation.tokenize``). The events are clipped to the first
      ``args.prompt_length`` seconds, combined with the controls, and clipped
      to ``args.clip_length``. A ``.png`` plot is also written when
      ``args.visualize`` is set;
    * one row of ``index.csv``: the sample index, the source file basename,
      the conditional file name, the number of instrument parts and the
      melody instrument.

    Args:
        args: parsed command-line namespace with ``dir``, ``output``,
            ``seed``, ``count``, ``prompt_length``, ``clip_length`` and
            ``visualize``.

    Raises:
        SystemExit: if no ``.mid``/``.midi`` files are found under ``args.dir``.
    """
    np.random.seed(args.seed)

    print(f'Selecting clips for accompaniment from: {args.dir}')
    filenames = glob(args.dir + '/**/*.mid', recursive=True) \
            + glob(args.dir + '/**/*.midi', recursive=True)
    filenames = sorted(filenames)

    # fail clearly up front; otherwise select_sample dies inside numpy
    # with an opaque "high <= 0" error
    if not filenames:
        raise SystemExit(f'No .mid/.midi files found under: {args.dir}')

    print(f'Saving clips to: {args.output}')
    os.makedirs(args.output, exist_ok=True)
    os.makedirs(f'{args.output}/groundtruth', exist_ok=True)

    with open(f'{args.output}/index.csv', 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'conditional', 'parts', 'melody'])

        for i in tqdm(range(args.count)):
            filename, clip, melody = select_sample(filenames, args.prompt_length, args.clip_length)
            parts = ops.get_instruments(clip)
            writer.writerow([i, filename, f'{i}-conditional.mid', len(parts), melody])

            mid = events_to_midi(clip)
            mid.save(f'{args.output}/groundtruth/{i}-clip.mid')
            if args.visualize:
                visualize(clip, f'{args.output}/groundtruth/{i}-clip.png')

            events, controls = extract_instruments(clip, [melody])
            prompt = ops.clip(events, 0, args.prompt_length, clip_duration=False)

            conditional_events = ops.clip(ops.combine(prompt, controls), 0, args.clip_length)
            mid = events_to_midi(conditional_events)
            mid.save(f'{args.output}/{i}-conditional.mid')
            if args.visualize:
                visualize(conditional_events, f'{args.output}/{i}-conditional.png')


if __name__ == '__main__':
    # command-line entry point; the description matches what main() actually
    # does (it selects accompaniment clips, not infilling prompts)
    parser = ArgumentParser(description='select prompts for accompaniment human eval')
    parser.add_argument('dir', help='directory containing MIDI files to sample')
    parser.add_argument('-o', '--output', type=str, default='output',
            help='output directory')
    parser.add_argument('-s', '--seed', type=int, default=0,
            help='random seed for sampling')
    parser.add_argument('-c', '--count', type=int, default=10,
            help='number of clips to sample')
    parser.add_argument('-p', '--prompt_length', type=int, default=5,
            help='length of the prompt (in seconds)')
    parser.add_argument('-l', '--clip_length', type=int, default=20,
            help='length of the full clip (in seconds)')
    parser.add_argument('-v', '--visualize', action='store_true',
            help='plot visualizations')
    main(parser.parse_args())