"""Apply cents correction to note sequences using pitch extracted from audio.

Each non-rest note label is replaced by the mean of the reference pitch frames
that fall within +/- 50 cents of the labeled note. Notes for which too few
frames agree with the label are collected and reported in ``warnings.csv``.
"""
import json
import math
import pathlib
import warnings
from collections import OrderedDict
from csv import DictReader, DictWriter

import click
import librosa
import numpy as np
import tqdm

from get_pitch import get_pitch_parselmouth

# Audio is loaded at this sample rate and pitch is requested with this hop
# size; both commands derive their frame duration from the same two values.
SAMPLE_RATE = 44100
HOP_SIZE = 512
TIMESTEP = HOP_SIZE / SAMPLE_RATE

# Columns that are always written to transcriptions.csv, in this order.
TRANSCRIPTION_FIELDS = ['name', 'ph_seq', 'ph_dur', 'ph_num', 'note_seq', 'note_dur']

_ERROR_RATIO_HELP = (
    'If the percentage of pitch points within a deviation of 50 cents compared to the note label '
    'is lower than this value, a warning will be raised.'
)

# Module-level accumulator of suspicious notes, flushed by save_warnings().
warns = []


def get_aligned_pitch(wav_path: pathlib.Path, total_secs: float, timestep: float) -> np.ndarray:
    """Extract a MIDI pitch curve from a wav file, padded to cover ``total_secs``.

    Args:
        wav_path: Audio file to analyse; loaded as mono at ``SAMPLE_RATE``.
        total_secs: Duration (seconds) the returned curve must cover.
        timestep: Duration (seconds) of one pitch frame.

    Returns:
        1-D array of MIDI pitch values with at least
        ``ceil(total_secs / timestep)`` frames. Missing trailing frames are
        filled with the last extracted value.

    Raises:
        ValueError: If padding is required but no pitch frame was extracted.
    """
    waveform, _ = librosa.load(wav_path, sr=SAMPLE_RATE, mono=True)
    # NOTE(review): get_pitch_parselmouth is defined elsewhere; its second
    # return value is treated as f0 in Hz, and the two positional arguments
    # are presumably hop size and sample rate -- confirm against get_pitch.
    _, f0, _ = get_pitch_parselmouth(waveform, HOP_SIZE, SAMPLE_RATE)
    pitch = librosa.hz_to_midi(f0)

    # For an integer frame count n, ``n < x`` is equivalent to ``n < ceil(x)``,
    # so a single target length serves for both the test and the pad width.
    missing = math.ceil(total_secs / timestep) - pitch.shape[0]
    if missing > 0:
        if pitch.shape[0] == 0:
            raise ValueError(f'No pitch frames could be extracted from {wav_path}.')
        pitch = np.pad(
            pitch, (0, missing), mode='constant', constant_values=(0, pitch[-1])
        )
    return pitch


def _correct_note(
        name: str, index: int, note: str, start: float, end: float,
        ref_pitch: np.ndarray, timestep: float, error_ratio: float
) -> str:
    """Return the cents-corrected label for one note spanning ``[start, end)``.

    Rests are returned unchanged. A warning entry is recorded in ``warns``
    when fewer than ``error_ratio`` of the note's frames lie within 50 cents
    of the label, or when the note covers no frames at all.
    """
    if note == 'rest':
        return note

    midi = librosa.note_to_midi(note, round_midi=False)
    frames = ref_pitch[math.floor(start / timestep): math.ceil(end / timestep)]
    close = frames[(frames >= midi - 0.5) & (frames < midi + 0.5)]

    if len(frames) == 0 or len(close) < len(frames) * error_ratio:
        warns.append({
            'position': name,
            'note_index': index,
            'note_value': note
        })
    # An empty ``frames`` implies an empty ``close``; in both cases there is
    # nothing to average, so the original label is kept.
    if len(close) == 0:
        return note
    return librosa.midi_to_note(np.mean(close), cents=True, unicode=False)


def correct_cents_item(
        name: str, item: OrderedDict, ref_pitch: np.ndarray,
        timestep: float, error_ratio: float
) -> None:
    """Rewrite ``item['note_seq']`` in place with cents-corrected note names.

    Args:
        name: Identifier used in warning entries for this item.
        item: Mapping with whitespace-separated ``note_seq`` and ``note_dur``.
        ref_pitch: MIDI pitch curve whose frame 0 is the start of the item.
        timestep: Duration (seconds) of one frame of ``ref_pitch``.
        error_ratio: Minimum fraction of a note's frames that must lie within
            50 cents of its label before a warning is recorded.

    Raises:
        ValueError: If ``note_seq`` and ``note_dur`` differ in length.
    """
    note_seq = item['note_seq'].split()
    note_dur = [float(d) for d in item['note_dur'].split()]
    # Explicit check instead of ``assert``: assertions vanish under ``-O`` and
    # ``zip`` would then silently drop the unmatched tail.
    if len(note_seq) != len(note_dur):
        raise ValueError(
            f'{name}: note_seq has {len(note_seq)} items but note_dur has {len(note_dur)}.'
        )

    corrected = []
    start = 0.
    for index, (note, dur) in enumerate(zip(note_seq, note_dur)):
        end = start + dur
        corrected.append(
            _correct_note(name, index, note, start, end, ref_pitch, timestep, error_ratio)
        )
        start = end
    item['note_seq'] = ' '.join(corrected)


def save_warnings(save_dir: pathlib.Path) -> None:
    """Write collected warnings to ``warnings.csv`` in ``save_dir``, if any.

    Does nothing when no warning was recorded. Otherwise the CSV is written,
    a ``UserWarning`` pointing at the file is emitted, and the warnings
    filter is set to ``'default'``.
    """
    if not warns:
        return
    save_path = save_dir.resolve() / 'warnings.csv'
    with open(save_path, 'w', encoding='utf8', newline='') as f:
        writer = DictWriter(f, fieldnames=['position', 'note_index', 'note_value'])
        writer.writeheader()
        writer.writerows(warns)
    warnings.warn(
        message=f'possible labeling errors saved in {save_path}',
        category=UserWarning
    )
    warnings.filterwarnings(action='default')


@click.group(help='Apply cents correction to note sequences')
def correct_cents():
    """Command group; the work is done by the ``csv`` and ``ds`` subcommands."""
    pass


@correct_cents.command(help='Apply cents correction to note sequences in transcriptions.csv')
@click.argument('transcriptions', metavar='TRANSCRIPTIONS')
@click.argument('waveforms', metavar='WAVS')
@click.option('--error_ratio', metavar='RATIO', type=float, default=0.4,
              help=_ERROR_RATIO_HELP)
def csv(
        transcriptions,
        waveforms,
        error_ratio
):
    """Correct every row of a transcriptions CSV and rewrite the file in place.

    The wav for each row is looked up as ``<WAVS>/<name>.wav``. Warnings are
    saved next to the CSV.
    """
    transcriptions = pathlib.Path(transcriptions).resolve()
    waveforms = pathlib.Path(waveforms).resolve()
    # newline='' is what the csv module requires for correct handling of
    # line breaks embedded in quoted fields.
    with open(transcriptions, 'r', encoding='utf8', newline='') as f:
        reader = DictReader(f)
        items: list[OrderedDict] = [OrderedDict(row) for row in reader]
        source_fields = reader.fieldnames or []

    for item in tqdm.tqdm(items):
        ref_pitch = get_aligned_pitch(
            wav_path=waveforms / (item['name'] + '.wav'),
            total_secs=sum(float(d) for d in item['note_dur'].split()),
            timestep=TIMESTEP
        )
        correct_cents_item(
            name=item['name'], item=item, ref_pitch=ref_pitch,
            timestep=TIMESTEP, error_ratio=error_ratio
        )

    # Keep the standard columns first, then carry over any additional columns
    # from the input. With a fixed list, DictWriter raises on unknown keys
    # after the file has already been truncated, destroying the transcriptions.
    fieldnames = TRANSCRIPTION_FIELDS + [
        column for column in source_fields if column not in TRANSCRIPTION_FIELDS
    ]
    with open(transcriptions, 'w', encoding='utf8', newline='') as f:
        writer = DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(items)
    save_warnings(transcriptions.parent)


@correct_cents.command(help='Apply cents correction to note sequences in DS files')
@click.argument('ds_dir', metavar='DS_DIR')
@click.option('--error_ratio', metavar='RATIO', type=float, default=0.4,
              help=_ERROR_RATIO_HELP)
def ds(
        ds_dir,
        error_ratio
):
    """Correct every ``*.ds`` file in a directory and rewrite each in place.

    Each DS file must have a ``.wav`` with the same stem beside it. A DS file
    may hold a single segment or a list of segments; every segment is located
    in the audio through its ``offset``. Warnings are saved in ``DS_DIR``.

    Raises:
        FileNotFoundError: If ``DS_DIR`` or a required ``.wav`` is missing.
    """
    ds_dir = pathlib.Path(ds_dir).resolve()
    if not ds_dir.exists():
        raise FileNotFoundError('The directory of DS files does not exist.')

    # Materialise the listing so tqdm can display a total and the processing
    # order is deterministic.
    ds_files = sorted(path for path in ds_dir.glob('*.ds') if path.is_file())
    for ds_file in tqdm.tqdm(ds_files):
        wav_file = ds_file.with_suffix('.wav')
        if not wav_file.exists():
            raise FileNotFoundError(f'Missing corresponding .wav file of {ds_file.name}.')

        with open(ds_file, 'r', encoding='utf8') as f:
            params = json.load(f)
        if not isinstance(params, list):
            params = [params]
        params = [OrderedDict(p) for p in params]
        if not params:
            # Nothing to correct; leave the file untouched.
            continue

        # End time of every segment. The pitch curve must reach the latest
        # one, which is not necessarily the last entry in the list.
        segment_ends = [
            p['offset'] + sum(float(d) for d in p['note_dur'].split())
            for p in params
        ]
        ref_pitch = get_aligned_pitch(
            wav_path=wav_file,
            total_secs=max(segment_ends),
            timestep=TIMESTEP
        )

        for i, (param, segment_end) in enumerate(zip(params, segment_ends)):
            start_idx = math.floor(param['offset'] / TIMESTEP)
            end_idx = math.ceil(segment_end / TIMESTEP)
            correct_cents_item(
                name=f'{ds_file.stem}#{i}', item=param,
                ref_pitch=ref_pitch[start_idx: end_idx],
                timestep=TIMESTEP, error_ratio=error_ratio
            )

        with open(ds_file, 'w', encoding='utf8') as f:
            json.dump(params, f, ensure_ascii=False, indent=2)
    save_warnings(ds_dir)


if __name__ == '__main__':
    correct_cents()