File size: 2,648 Bytes
151b875
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
from argparse import ArgumentParser

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from anticipation.vocab import MIDI_TIME_OFFSET, MIDI_START_OFFSET, TIME_RESOLUTION, SEPARATOR
from anticipation.ops import max_time

# Global matplotlib text settings applied to every figure this script draws.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Computer Modern'],
    'font.size': 16,
})

def loghist(filename, data, title, xlabel):
    """Save a kernel density plot of `data` drawn on a log-scaled x-axis.

    Args:
        filename: destination handed to `Figure.savefig`.
        data: values handed to `seaborn.kdeplot`.
        title: accepted for interface compatibility but not drawn on the plot.
        xlabel: label for the x-axis.

    The figure is written at 300 dpi and then closed, so repeated calls do
    not accumulate open figures in pyplot's global state.
    """
    sns.set_style('whitegrid')

    # Own the figure explicitly instead of relying on pyplot's "current
    # figure", so it can be reliably released once it has been saved.
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        # The log scale is set before plotting so the density is drawn
        # against the already-log-scaled axis.
        ax.set_xscale('log')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Density')

        # Draw both major and minor grid lines.
        ax.grid(True, which='both', linestyle='-', linewidth=0.5)

        sns.kdeplot(data, bw_adjust=1.0, ax=ax)

        fig.tight_layout()
        fig.savefig(filename, dpi=300)
    finally:
        plt.close(fig)


if __name__ == '__main__':
    # Script entry point: sample a tokenized MIDI dataset (one sequence per
    # line), print aggregate token/time statistics, and plot the distribution
    # of per-sequence token rates.
    parser = ArgumentParser(description='calculate statistics of a tokenized MIDI dataset')
    parser.add_argument('-f', '--filename', required=True,
                        help='file containing a tokenized MIDI dataset')
    parser.add_argument('-i', '--interarrival',
            action='store_true',
            help='request interarrival-time encoding (default to arrival-time encoding)')
    args = parser.parse_args()

    print(f'Calculating statistics for {args.filename}')
    time_lengths = []
    token_counts = []
    with open(args.filename, 'r') as f:
        # Stream the file rather than loading every line into memory; only
        # one line in ten is actually used.
        for i, line in enumerate(tqdm(f)):
            if i % 10 != 0:
                continue  # subsample: keep every 10th line
            tokens = [int(token) for token in line.split()]
            if not tokens:
                continue  # blank line: nothing to measure

            if args.interarrival:
                # Duration is the sum of all tokens below MIDI_START_OFFSET,
                # each shifted down by MIDI_TIME_OFFSET.
                time_lengths.append(sum(t - MIDI_TIME_OFFSET for t in tokens if t < MIDI_START_OFFSET))
                token_counts.append(len(tokens))
            else:
                if SEPARATOR in tokens:
                    continue  # counts are weird; just skip these
                # The first token of each line is excluded from both the
                # duration and the count.
                # NOTE(review): presumably a control/prefix token -- confirm.
                events = tokens[1:]
                time_lengths.append(max_time(events, seconds=False))
                token_counts.append(len(events))

    total_tokens = sum(token_counts)
    total_time = sum(time_lengths)

    # Per-sequence rates. Durations are divided by TIME_RESOLUTION to get
    # seconds. Zero-duration sequences are left out because their rate is
    # undefined (it would be a division by zero).
    tokens_per_second = [
        TIME_RESOLUTION * count / float(duration)
        for count, duration in zip(token_counts, time_lengths)
        if duration > 0
    ]
    if not tokens_per_second or total_time <= 0:
        raise SystemExit(f'No sequences with positive duration found in {args.filename}')

    print('Total tokens:', total_tokens)
    print(f'Total time: {float(total_time)/(3600*TIME_RESOLUTION)} hours')
    # Aggregate rate: all sampled tokens over all sampled time.
    print('Mean tokens-per-second:', TIME_RESOLUTION*total_tokens/float(total_time))
    print('Std tokens-per-second:', np.std(tokens_per_second))
    # Unweighted mean of the per-sequence rates (differs from the aggregate
    # rate above, which weights each sequence by its duration).
    print(np.mean(tokens_per_second))

    output_path = 'output/tokens_per_second.png'
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    loghist(output_path,
            tokens_per_second,
            'Distribution of Tokens per Second',
            'Tokens per Second (log10 scale)')