File size: 4,349 Bytes
151b875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os,csv

from argparse import ArgumentParser
from glob import glob

import numpy as np
from tqdm import tqdm

from anticipation import ops
from anticipation.visuals import visualize
from anticipation.convert import midi_to_events, events_to_midi

def select_prompt(filenames, clip_length, verbose=False,
                  min_prompt_length=4, max_prompt_length=6):
    """Yield prompt/continuation triples that pass all length checks.

    Each entry of ``filenames`` is treated as a FIGARO continuation. The
    matching prompt and ground-truth continuation are read from the
    ``prompt`` and ``ground`` sub-directories next to it, using the same
    file name. Candidates that cannot be loaded or that fail a length check
    are skipped.

    Args:
        filenames: paths of FIGARO continuation MIDI files, tried in order.
        clip_length: minimum ``ops.max_time`` required of both continuations;
            also the end time both are clipped to.
        verbose: when true, print the reason every skipped candidate was
            rejected (including files that could not be loaded).
        min_prompt_length: smallest accepted ``ops.max_time`` of the prompt.
        max_prompt_length: largest accepted ``ops.max_time`` of the prompt.

    Yields:
        ``(basename, prompt, ground, figaro)`` where ``basename`` is the file
        name of the FIGARO file and the other three are event sequences as
        returned by ``midi_to_events`` (both continuations clipped).
    """
    def load(path, label):
        # Skipping unreadable files is deliberate best-effort behaviour, but
        # the failure is surfaced in verbose mode instead of vanishing.
        try:
            return midi_to_events(path)
        except Exception as err:
            if verbose:
                print(f'  rejected: could not load {label} {path!r} ({err})')
            return None

    for figaro_filename in filenames:
        figaro = load(figaro_filename, 'FIGARO continuation')
        if figaro is None:
            continue

        max_time = ops.max_time(figaro)
        if max_time < clip_length:
            if verbose:
                print(f'  rejected: FIGARO continuation is too short ({max_time} seconds)')
            continue

        figaro = ops.clip(figaro, 0, clip_length, clip_duration=True)

        # The prompt and ground truth share the FIGARO file's name and live
        # in sibling sub-directories, so one split serves both lookups.
        head, tail = os.path.split(figaro_filename)

        prompt = load(os.path.join(head, 'prompt', tail), 'prompt')
        if prompt is None:
            continue

        max_time = ops.max_time(prompt)

        if max_time < min_prompt_length:
            if verbose:
                print(f'  rejected: prompt is too short ({max_time} seconds)')
            continue

        if max_time > max_prompt_length:
            if verbose:
                print(f'  rejected: prompt is too long ({max_time} seconds)')
            continue

        ground = load(os.path.join(head, 'ground', tail), 'ground truth')
        if ground is None:
            continue

        max_time = ops.max_time(ground)
        if max_time < clip_length:
            if verbose:
                print(f'  rejected: ground truth continuation is too short ({max_time} seconds)')
            continue

        ground = ops.clip(ground, 0, clip_length, clip_duration=True)

        yield os.path.basename(figaro_filename), prompt, ground, figaro


def main(args):
    """Select random prompts and write them with both continuations to disk.

    Reads ``*.mid`` files from ``args.dir``, shuffles them with the seed
    ``args.seed`` and takes the first ``args.count`` candidates accepted by
    ``select_prompt``. For every accepted candidate ``i`` it writes:

    * ``<output>/<i>-prompt.mid``
    * ``<output>/groundtruth/<i>-clip.mid``
    * ``<output>/figaro/<i>-clip.mid``

    plus a matching ``.png`` for each when ``args.visualize`` is set, and one
    row in ``<output>/index.csv`` mapping the index to the original file name.

    Args:
        args: parsed command-line namespace providing ``dir``, ``output``,
            ``seed``, ``count``, ``clip_length``, ``visualize`` and
            ``verbose``.
    """
    np.random.seed(args.seed)

    print(f'Selecting random clips for prompting from: {args.dir}')
    # os.path.join makes the pattern work whether or not the directory was
    # given with a trailing separator; plain concatenation would turn
    # 'data' into the pattern 'data*.mid' and match nothing inside it.
    filenames = sorted(glob(os.path.join(args.dir, '*.mid')))
    np.random.shuffle(filenames)

    print(f'Saving clips to: {args.output}')
    ground_dir = os.path.join(args.output, 'groundtruth')
    figaro_dir = os.path.join(args.output, 'figaro')
    for directory in (args.output, ground_dir, figaro_dir):
        os.makedirs(directory, exist_ok=True)

    written = 0
    with open(os.path.join(args.output, 'index.csv'), 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['idx', 'original', 'prompt'])

        data = select_prompt(filenames, args.clip_length, args.verbose)
        # zip ends as soon as either side is exhausted, so running out of
        # eligible files stops the loop cleanly instead of raising
        # StopIteration out of a bare next() call.
        samples = zip(range(args.count), data)
        for i, (filename, prompt, ground, figaro) in tqdm(samples, total=args.count):
            writer.writerow([i, filename, f'{i}-prompt.mid'])

            outputs = (
                (prompt, os.path.join(args.output, f'{i}-prompt')),
                (ground, os.path.join(ground_dir, f'{i}-clip')),
                (figaro, os.path.join(figaro_dir, f'{i}-clip')),
            )
            for events, stem in outputs:
                events_to_midi(events).save(f'{stem}.mid')
                if args.visualize:
                    visualize(events, f'{stem}.png')

            written += 1

    if written < args.count:
        print(f'Warning: only {written} of {args.count} requested clips could be selected')


if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, run main().
    cli = ArgumentParser(description='select prompts for completion human eval')
    cli.add_argument('dir', help='directory containing MIDI files to sample')

    # (flags, settings) for each optional argument, in the order they
    # should appear in the generated help text.
    optional_arguments = (
        (('-o', '--output'),
         dict(type=str, default='prompt', help='output directory')),
        (('-s', '--seed'),
         dict(type=int, default=0, help='random seed for prompt selection')),
        (('-c', '--count'),
         dict(type=int, default=10, help='number of clips to sample')),
        (('-l', '--clip_length'),
         dict(type=int, default=20, help='length of the full clip (in seconds)')),
        (('-v', '--visualize'),
         dict(action='store_true', help='plot visualizations')),
        (('--verbose',),
         dict(action='store_true', help='verbose output')),
    )
    for flags, settings in optional_arguments:
        cli.add_argument(*flags, **settings)

    options = cli.parse_args()
    main(options)