# beatalignment/tokenize-asap.py
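
"""Parallel tokenization of the ASAP dataset with the anticipation library.

Reads the ASAP metadata, tokenizes each (performance, score, annotations)
file tuple with tokenize3 across PREPROC_WORKERS processes, and merges the
per-worker output shards into ./data/output.txt.
"""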
import os

import numpy as np
import pandas as pd

from multiprocessing import Pool, RLock
from tqdm import tqdm

from anticipation.config import *
from anticipation.tokenize import tokenize3
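
# number of parallel worker processes; each writes its own output shard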
PREPROC_WORKERS = 32


def main():
    print('Tokenization parameters:')
    print(f'  anticipation interval = {DELTA}s')
    print(f'  max track length = {MAX_TRACK_TIME_IN_SECONDS}s')
    print(f'  min track length = {MIN_TRACK_TIME_IN_SECONDS}s')
    print(f'  min track events = {MIN_TRACK_EVENTS}')
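
    # each row of the ASAP metadata pairs a MIDI performance with its
    # aligned score and the annotation files for both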
    BASE = "./asap-dataset-master/"
    df = pd.read_csv(BASE + 'metadata.csv')

    datafiles = []
    # collect file tuples with progress bar
    for _, row in tqdm(df.iterrows(), total=len(df), desc='Reading metadata', unit='file'):
        file1 = BASE + row['midi_performance']
        file2 = BASE + row['midi_score']
        file3 = BASE + row['performance_annotations']
        file4 = BASE + row['midi_score_annotations']
        datafiles.append((file1, file2, file3, file4))

    np.random.shuffle(datafiles)  # shuffle the datafiles
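
    # because datafiles is shuffled, the equal-size chunks assigned to each
    # worker below are an arbitrary sample rather than the metadata ordering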
    print(f'Parallel tokenizing data with {PREPROC_WORKERS} workers')

    # split datafiles into chunks for each worker
    os.makedirs('./data', exist_ok=True)  # ensure the output directory exists
    chunks = np.array_split(datafiles, PREPROC_WORKERS)
    outputs = [f'./data/output_{i}.txt' for i in range(PREPROC_WORKERS)]

    # run tokenize3 on each chunk in parallel
    args = [(list(chunk), out, True) for chunk, out in zip(chunks, outputs)]
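
    # pass a shared lock to tqdm in each worker so multiple progress bars
    # can write to the terminal without clobbering one another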
    with Pool(processes=PREPROC_WORKERS, initargs=(RLock(),), initializer=tqdm.set_lock) as pool:
        results = []

        # overall progress bar for chunks
        total_pbar = tqdm(total=len(args), desc='Overall tokenization', position=0, leave=True)

        # collect each worker's stats tuple and advance the overall bar
        def _cb(res):
            results.append(res)
            total_pbar.update(1)

        # launch tasks
        for arg in args:
            pool.apply_async(tokenize3, args=arg, callback=_cb)
        pool.close()
        pool.join()

        total_pbar.close()
    # merge the per-worker output shards into a single file
    merged = './data/output.txt'
    with open(merged, 'w') as wf:
        for out in outputs:
            with open(out, 'r') as rf:
                wf.write(rf.read())
            os.remove(out)
    # aggregate the per-worker stats (tokenize3 returns a tuple of counts)
    seq_count, rest_count, too_short, too_long, too_manyinstr, discarded_seqs, truncations = map(sum, zip(*results))
    rest_ratio = round(100*float(rest_count)/(seq_count*M), 2)

    trunc_type = 'duration'  # 'interarrival' if args.interarrival else 'duration'
    trunc_ratio = round(100*float(truncations)/(seq_count*M), 2)
    print('Tokenization complete.')
    print(f'  => Processed {seq_count} training sequences')
    print(f'  => Inserted {rest_count} REST tokens ({rest_ratio}% of events)')
    print(f'  => Discarded {too_short+too_long+too_manyinstr} event sequences')
    print(f'      - {too_short} too short')
    print(f'      - {too_long} too long')
    print(f'      - {too_manyinstr} too many instruments')
    print(f'  => Discarded {discarded_seqs} training sequences')
    print(f'  => Truncated {truncations} {trunc_type} times ({trunc_ratio}% of {trunc_type}s)')
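
    # shards were concatenated in worker order, so the merged file is not
    # shuffled; downstream training should reshuffle the sequences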
    print('Remember to shuffle the training split!')


if __name__ == '__main__':
    main()