import os

import numpy as np
import pandas as pd
from multiprocessing import Pool, RLock, Manager
from glob import glob
from tqdm import tqdm

from anticipation.config import *
from anticipation.tokenize import tokenize2, tokenize3

PREPROC_WORKERS = 32


def main():
    """Tokenize the ASAP dataset in parallel and merge the worker outputs.

    Reads the dataset's metadata CSV, builds (performance, score,
    performance_annotations, score_annotations) file tuples, shards them
    across PREPROC_WORKERS processes running ``tokenize3``, concatenates the
    per-worker output files into ``./data/output.txt``, and prints aggregate
    tokenization statistics.
    """
    print('Tokenization parameters:')
    print(f' anticipation interval = {DELTA}s')
    print(f' max track length = {MAX_TRACK_TIME_IN_SECONDS}s')
    print(f' min track length = {MIN_TRACK_TIME_IN_SECONDS}s')
    print(f' min track events = {MIN_TRACK_EVENTS}')

    BASE = "./asap-dataset-master/"
    # derive the metadata path from BASE instead of duplicating the
    # hard-coded dataset directory
    df = pd.read_csv(os.path.join(BASE, 'metadata.csv'))

    # collect file tuples with a progress bar
    datafiles = []
    for _, row in tqdm(df.iterrows(), total=len(df),
                       desc='Reading metadata', unit='file'):
        datafiles.append((
            BASE + row['midi_performance'],
            BASE + row['midi_score'],
            BASE + row['performance_annotations'],
            BASE + row['midi_score_annotations'],
        ))

    # shuffle so the per-worker chunks are statistically balanced
    np.random.shuffle(datafiles)

    print(f'Parallel tokenizing data with {PREPROC_WORKERS} workers')

    # ensure the output directory exists before workers try to write into it
    os.makedirs('./data', exist_ok=True)

    # split datafiles into one chunk per worker
    chunks = np.array_split(datafiles, PREPROC_WORKERS)
    outputs = [f'./data/output_{i}.txt' for i in range(PREPROC_WORKERS)]

    # (chunk, output_path, True) argument triples for tokenize3
    # NOTE(review): the trailing True flag's meaning depends on tokenize3's
    # signature — confirm against anticipation.tokenize
    tasks = [(list(chunk), out, True) for chunk, out in zip(chunks, outputs)]

    # share one RLock so the workers' tqdm bars don't garble each other
    with Pool(processes=PREPROC_WORKERS, initargs=(RLock(),),
              initializer=tqdm.set_lock) as pool:
        results = []

        # overall progress bar, one tick per finished chunk
        total_pbar = tqdm(total=len(tasks), desc='Overall tokenization',
                          position=0, leave=True)

        def _on_done(res):
            results.append(res)
            total_pbar.update(1)

        def _on_error(exc):
            # surface worker failures instead of silently dropping the chunk
            total_pbar.write(f'Worker failed: {exc!r}')
            total_pbar.update(1)

        for task in tasks:
            pool.apply_async(tokenize3, args=task,
                             callback=_on_done, error_callback=_on_error)
        pool.close()
        pool.join()
        total_pbar.close()

    # merge per-worker outputs into a single file, removing each shard
    merger = './data/output.txt'
    with open(merger, 'w') as wf:
        for out in outputs:
            with open(out, 'r') as rf:
                wf.write(rf.read())
            os.remove(out)

    # aggregate the per-worker statistic tuples element-wise
    (seq_count, rest_count, too_short, too_long,
     too_manyinstr, discarded_seqs, truncations) = map(sum, zip(*results))
    rest_ratio = round(100*float(rest_count)/(seq_count*M), 2)
    trunc_type = 'duration'
    trunc_ratio = round(100*float(truncations)/(seq_count*M), 2)

    print('Tokenization complete.')
    print(f' => Processed {seq_count} training sequences')
    print(f' => Inserted {rest_count} REST tokens ({rest_ratio}% of events)')
    # NOTE(review): too_manyinstr is itemized below but not included in this
    # total — confirm whether the sum should be too_short+too_long+too_manyinstr
    print(f' => Discarded {too_short+too_long} event sequences')
    print(f' - {too_short} too short')
    print(f' - {too_long} too long')
    print(f' - {too_manyinstr} too many instruments')
    print(f' => Discarded {discarded_seqs} training sequences')
    print(f' => Truncated {truncations} {trunc_type} times ({trunc_ratio}% of {trunc_type}s)')
    print('Remember to shuffle the training split!')


if __name__ == '__main__':
    main()