File size: 2,648 Bytes
151b875 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from argparse import ArgumentParser
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from anticipation.vocab import MIDI_TIME_OFFSET, MIDI_START_OFFSET, TIME_RESOLUTION, SEPARATOR
from anticipation.ops import max_time
# Global matplotlib typography for every figure this script produces:
# serif Computer Modern at 16pt (paper-style plots).
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Computer Modern'],
    'font.size': 16,
})
def loghist(filename, data, title, xlabel):
    """Save a kernel-density plot of `data` with a log-scaled x-axis.

    Parameters
    ----------
    filename : str
        Output path for the rendered PNG (saved at 300 dpi).
    data : sequence of float
        Samples whose density is estimated (seaborn KDE, bw_adjust=1.0).
    title : str
        Plot title — currently unused (the title line is commented out);
        kept in the signature for caller compatibility.
    xlabel : str
        Label for the (log-scaled) x-axis.
    """
    sns.set_style('whitegrid')
    plt.clf()
    # Capture the new figure directly instead of recovering it from the
    # axes afterwards (density.get_figure()).
    fig = plt.figure(figsize=(10, 4))
    #plt.title(title)
    plt.xscale('log')
    plt.xlabel(xlabel)
    plt.ylabel('Density')
    plt.grid(True, which='both', linestyle='-', linewidth=0.5)
    sns.kdeplot(data, bw_adjust=1.0)
    plt.tight_layout()
    fig.savefig(filename, dpi=300)
    # Bug fix: close the figure so repeated calls don't accumulate open
    # figures (matplotlib keeps every un-closed figure alive — memory leak).
    plt.close(fig)
def main():
    """Scan a tokenized MIDI dataset and report token/time statistics.

    Subsamples every 10th line of the input, accumulates per-sequence token
    counts and time lengths (arrival- or interarrival-time encoding), prints
    aggregate statistics, and saves a density plot of tokens-per-second.
    """
    parser = ArgumentParser(description='calculate statistics of a tokenized MIDI dataset')
    parser.add_argument('-f', '--filename',
            required=True,  # fail with a clear argparse error instead of open(None)
            help='file containing a tokenized MIDI dataset')
    parser.add_argument('-i', '--interarrival',
            action='store_true',
            help='request interarrival-time encoding (default to arrival-time encoding)')
    args = parser.parse_args()

    print(f'Calculating statistics for {args.filename}')
    time_lengths = []
    token_counts = []
    with open(args.filename, 'r') as f:
        # list(...) materializes the file so tqdm can show a total; fine for
        # datasets that fit in memory (one token sequence per line).
        for i, line in tqdm(list(enumerate(f))):
            if i % 10 != 0:
                continue  # subsample every 10th sequence to keep the scan cheap
            tokens = [int(token) for token in line.split()]
            if args.interarrival:
                # interarrival encoding: total time is the sum of the time
                # tokens (values below MIDI_START_OFFSET, offset-corrected)
                time_lengths.append(sum(t - MIDI_TIME_OFFSET for t in tokens if t < MIDI_START_OFFSET))
                token_counts.append(len(tokens))
            else:
                if SEPARATOR in tokens:
                    continue  # counts are weird; just skip these
                # tokens[1:] drops the leading control token before measuring
                time_lengths.append(max_time(tokens[1:], seconds=False))
                token_counts.append(len(tokens[1:]))

    total_time = sum(time_lengths)
    if not token_counts or total_time == 0:
        # Robustness: an empty/degenerate dataset previously crashed with
        # ZeroDivisionError in the statistics below.
        print('No usable sequences found; nothing to report.')
        return

    # Skip zero-length sequences so a single empty sequence cannot crash
    # the per-sequence rate computation.
    tokens_per_second = [TIME_RESOLUTION*count/float(time)
                         for count, time in zip(token_counts, time_lengths)
                         if time > 0]
    print('Total tokens:', sum(token_counts))
    print(f'Total time: {float(total_time)/(3600*TIME_RESOLUTION)} hours')
    print('Mean tokens-per-second:', TIME_RESOLUTION*sum(token_counts)/float(total_time))
    print('Std tokens-per-second:', np.std(tokens_per_second))
    print(np.mean(tokens_per_second))
    loghist('output/tokens_per_second.png',
            tokens_per_second,
            'Distribution of Tokens per Second',
            'Tokens per Second (log10 scale)')


if __name__ == '__main__':
    main()
|