File size: 4,442 Bytes
151b875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import datetime
import time
import pickle

from argparse import ArgumentParser
from tqdm import tqdm

import torch
from transformers import AutoModelForCausalLM

from anticipation.sample import generate
from anticipation import ops

def main(args):
    """Benchmark interval-by-interval generation latency of a causal LM.

    For each of ``args.sequences`` sequences, the sequence is extended
    ``args.numIntervals`` times, ``args.interval`` seconds at a time, by
    ``anticipation.sample.generate``. The duration of every ``generate`` call
    is recorded together with the number of instruments reported by
    ``ops.get_instruments`` on the returned tokens.

    Results are written to a fresh ``<args.output>/<timestamp>/`` directory:

    * ``log.txt`` -- run header, then per sample a timing summary grouped by
      instrument count and the final token sequence.
    * ``stats.pickle`` -- a list with one dict per sequence, mapping
      ``(num_instruments, interval)`` to the list of measured durations in
      seconds.

    Args:
        args: parsed command-line namespace providing ``model``, ``output``,
            ``seed``, ``interval``, ``numIntervals`` and ``sequences``.
    """
    # initialize the model
    model = AutoModelForCausalLM.from_pretrained(args.model)

    # set the device to use
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # set the seed for reproducibility
    torch.manual_seed(args.seed)

    interval = args.interval
    numIntervals = args.numIntervals
    numSequences = args.sequences

    stats = []

    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    folder_path = os.path.join(args.output, timestamp)
    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(folder_path, exist_ok=True)

    log_file_path = os.path.join(folder_path, "log.txt")

    with open(log_file_path, 'a', encoding='utf-8') as f:
        f.write(f'Model: {args.model}\n')
        f.write(f'Timestamp: {timestamp}\n')
        f.write(f'Number of sequences to generate: {numSequences}\n')
        f.write(f'Generation interval length: {interval} sec.\n')
        f.write(f'Number of intervals to generate per sequence: {numIntervals}\n')
        f.write('\n')

    # generate sequences
    for s in range(numSequences):
        # (num_instruments, interval) -> durations (seconds) of the generate()
        # calls after which the tokens contained that many instruments
        times = {}
        tokens = []
        start_time = 0
        end_time = interval
        for _ in range(numIntervals):
            # perf_counter is monotonic and high-resolution; time.time() can
            # jump if the system clock is adjusted during the benchmark
            start_clock = time.perf_counter()
            tokens = generate(model, inputs=tokens, start_time=start_time, end_time=end_time, top_p=.98)
            elapsed = time.perf_counter() - start_clock

            # NOTE(review): the returned tokens are fed back as `inputs`, so this
            # presumably counts instruments in the whole sequence so far, not
            # just in the interval just generated -- confirm against generate()
            num_instr = len(ops.get_instruments(tokens))

            times.setdefault((num_instr, interval), []).append(elapsed)

            start_time += interval
            end_time += interval

        stats.append(times)

        with open(log_file_path, 'a', encoding='utf-8') as f:
            f.write(f'Sample {s+1} of {numSequences}. Generation interval length: {interval} sec.\n')

            f.write('\n')
            # len(times) is the number of *distinct* instrument counts observed
            f.write(f'In this sample, intervals were generated with the following number of instruments: {len(times)}\n')
            f.write('\n')

            f.write('Summary by number of instruments generated:\n')
            # use a separate name so the outer `interval` is not rebound
            for (bucket_instr, bucket_interval), durations in times.items():
                avg_time = sum(durations) / len(durations)
                f.write(f"{len(durations)} generation interval(s) contained {bucket_instr} instruments. Average Time: {avg_time}, Average Time/Interval: {avg_time/bucket_interval}\n")

            f.write('\n')

            # Weight every generate() call equally; averaging the per-bucket
            # averages would over-weight rarely seen instrument counts.
            all_durations = [d for durations in times.values() for d in durations]
            if all_durations:
                overall_avg_time = sum(all_durations) / len(all_durations)
            else:
                # numIntervals == 0: nothing was timed (previously ZeroDivisionError)
                overall_avg_time = 'n/a'
            f.write(f"Average generation time for {interval} sec intervals across entire sequence: {overall_avg_time}\n")

            f.write('\n')
            f.write('Tokens:\n')
            f.write('\n')

            f.write(' '.join(map(str, tokens)) + '\n')
            f.write('\n')

    # NOTE: pickle output should only ever be loaded from trusted locations
    stats_dump_path = os.path.join(folder_path, "stats.pickle")
    with open(stats_dump_path, 'wb') as f:
        pickle.dump(stats, f)

def build_parser():
    """Build the command-line parser for the benchmark.

    Returns:
        An ``ArgumentParser`` whose parsed namespace is consumed by ``main``.
    """
    parser = ArgumentParser(description='benchmark a tripletmidi anticipatory music transformer')
    # main() cannot run without these: from_pretrained(None) and
    # os.path.join(None, ...) would fail later with an opaque traceback.
    parser.add_argument('-m', '--model', required=True,
        help='checkpoint for the model to evaluate')
    parser.add_argument('-o', '--output', required=True,
        help='output directory for samples and logs')
    parser.add_argument('-N', '--sequences', type=int, default=100,
        help='number of sequences to generate')
    parser.add_argument('-s', '--seed', type=int, default=42,
        help='rng seed for sampling')
    parser.add_argument('-I', '--interval', type=int, default=1,
        help='generation interval in seconds')
    parser.add_argument('-n', '--numIntervals', type=int, default=25,
        help='number of intervals to generate per sequence')
    parser.add_argument('--debug', action='store_true', help='verbose debugging outputs')
    return parser


if __name__ == '__main__':
    main(build_parser().parse_args())