|
|
import os
|
|
|
import datetime
|
|
|
import time
|
|
|
import pickle
|
|
|
|
|
|
from argparse import ArgumentParser
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
import torch
|
|
|
from transformers import AutoModelForCausalLM
|
|
|
|
|
|
from anticipation.sample import generate
|
|
|
from anticipation import ops
|
|
|
|
|
|
def _time_sequence(model, interval, num_intervals):
    """Generate one sequence window by window, timing each ``generate`` call.

    Args:
        model: the causal LM passed straight through to ``generate``.
        interval: length in seconds of each generation window.
        num_intervals: number of consecutive windows to generate.

    Returns:
        ``(times, tokens)``. ``times`` is a plain dict mapping
        ``(num_instr, interval)`` to a list of wall-clock durations (seconds)
        of the ``generate`` calls after which the accumulated token sequence
        contained ``num_instr`` instruments. ``tokens`` is the final sequence.
    """
    times = {}
    tokens = []
    start_time, end_time = 0, interval
    for _ in range(num_intervals):
        # perf_counter is monotonic and high-resolution, so it is the right
        # clock for measuring elapsed time (time.time() can jump).
        tick = time.perf_counter()
        tokens = generate(model, inputs=tokens, start_time=start_time, end_time=end_time, top_p=.98)
        elapsed = time.perf_counter() - tick

        # Instruments are counted over the whole sequence generated so far,
        # not just the window produced by this call.
        num_instr = len(ops.get_instruments(tokens))

        # setdefault keeps ``times`` a plain dict (pickled as such below).
        times.setdefault((num_instr, interval), []).append(elapsed)

        start_time += interval
        end_time += interval
    return times, tokens


def _write_sample_log(f, sample_idx, num_sequences, interval, times, tokens):
    """Append the timing summary and generated tokens of one sample to ``f``."""
    f.write(f'Sample {sample_idx+1} of {num_sequences}. Generation interval length: {interval} sec.\n')
    f.write('\n')

    instr_counts = sorted(num_instr for num_instr, _ in times)
    f.write(f'In this sample, intervals were generated with the following number of instruments: {instr_counts}\n')
    f.write('\n')

    f.write('Summary by number of instruments generated:\n')
    for (num_instr, key_interval), durations in times.items():
        avg_time = sum(durations) / len(durations)
        f.write(f"{len(durations)} generation interval(s) contained {num_instr} instruments. Average Time: {avg_time}, Average Time/Interval: {avg_time/key_interval}\n")
    f.write('\n')

    # Weight every generated window equally; averaging the per-group means
    # would over-weight instrument counts that occurred only rarely.
    all_durations = [d for durations in times.values() for d in durations]
    overall_avg_time = sum(all_durations) / len(all_durations) if all_durations else float('nan')
    f.write(f"Average generation time for {interval} sec intervals across entire sequence: {overall_avg_time}\n")
    f.write('\n')

    f.write('Tokens:\n')
    f.write('\n')
    f.write(' '.join(str(tok) for tok in tokens) + '\n')
    f.write('\n')


def main(args):
    """Benchmark window-by-window generation speed of a model.

    For each of ``args.sequences`` samples, generates ``args.numIntervals``
    consecutive windows of ``args.interval`` seconds and times every
    ``generate`` call. Results go into a new timestamped subdirectory of
    ``args.output``:

    * ``log.txt`` -- human-readable header plus a per-sample summary.
    * ``stats.pickle`` -- list with one ``times`` dict per sample, mapping
      ``(num_instr, interval)`` to a list of durations in seconds.
    """
    model = AutoModelForCausalLM.from_pretrained(args.model)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    torch.manual_seed(args.seed)

    interval = args.interval
    num_intervals = args.numIntervals
    num_sequences = args.sequences

    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    folder_path = os.path.join(args.output, timestamp)
    os.makedirs(folder_path, exist_ok=True)
    log_file_path = os.path.join(folder_path, "log.txt")

    with open(log_file_path, 'a') as f:
        f.write(f'Model: {args.model}\n')
        f.write(f'Timestamp: {timestamp}\n')
        f.write(f'Number of sequences to generate: {num_sequences}\n')
        f.write(f'Generation interval length: {interval} sec.\n')
        f.write(f'Number of intervals to generate per sequence: {num_intervals}\n')
        f.write('\n')

    stats = []
    for s in range(num_sequences):
        times, tokens = _time_sequence(model, interval, num_intervals)
        stats.append(times)
        # Reopen per sample so completed samples are on disk even if a later
        # sample crashes or the run is interrupted.
        with open(log_file_path, 'a') as f:
            _write_sample_log(f, s, num_sequences, interval, times, tokens)

    stats_dump_path = os.path.join(folder_path, "stats.pickle")
    with open(stats_dump_path, 'wb') as f:
        pickle.dump(stats, f)
|
|
|
|
|
|
if __name__ == '__main__':
    parser = ArgumentParser(description='benchmark a tripletmidi anticipatory music transformer')
    # --model and --output have no usable default: main() passes them straight
    # to from_pretrained() and os.path.join(), so require them up front.
    parser.add_argument('-m', '--model', required=True,
                        help='checkpoint for the model to evaluate')
    parser.add_argument('-o', '--output', required=True,
                        help='output directory for samples and logs')
    parser.add_argument('-N', '--sequences', type=int, default=100,
                        help='number of sequences to generate')
    parser.add_argument('-s', '--seed', type=int, default=42,
                        help='rng seed for sampling')
    parser.add_argument('-I', '--interval', type=int, default=1,
                        help='generation interval in seconds')
    parser.add_argument('-n', '--numIntervals', type=int, default=25,
                        help='number of intervals to generate per sequence')
    # NOTE(review): --debug is parsed but not read by main() in this file.
    parser.add_argument('--debug', action='store_true', help='verbose debugging outputs')
    args = parser.parse_args()

    # The interval is used as a divisor in the timing summary and as the
    # window step; a non-positive value would make no progress.
    if args.interval <= 0:
        parser.error('--interval must be a positive number of seconds')

    main(args)
|
|
|
|