|
|
import json |
|
|
import math |
|
|
import warnings |
|
|
from collections import OrderedDict |
|
|
|
|
|
import librosa |
|
|
import numpy as np |
|
|
import tqdm |
|
|
import pathlib |
|
|
from csv import DictReader, DictWriter |
|
|
|
|
|
import click |
|
|
|
|
|
from get_pitch import get_pitch_parselmouth |
|
|
|
|
|
# Module-level accumulator shared by the commands in this file:
# correct_cents_item() appends one dict per suspicious note (keys 'position',
# 'note_index', 'note_value') and save_warnings() writes them out as a CSV file.
warns = []
|
|
|
|
|
|
|
|
def get_aligned_pitch(wav_path: pathlib.Path, total_secs: float, timestep: float):
    """Extract a MIDI-scale pitch curve from a wav file, padded to cover the labels.

    The audio is loaded as mono at 44100 Hz, its f0 is taken from the second
    value returned by ``get_pitch_parselmouth(waveform, 512, 44100)`` and then
    converted from Hz to MIDI note numbers.

    Args:
        wav_path: Path of the audio file to analyse.
        total_secs: Total labelled duration, in seconds, the curve must cover.
        timestep: Duration of one pitch frame in seconds.
            NOTE(review): callers pass 512 / 44100, presumably the hop size over
            the sample rate used above - confirm against get_pitch_parselmouth.

    Returns:
        The pitch curve. When it holds fewer than ``total_secs / timestep``
        frames, it is extended at the end with copies of its last value up to
        ``ceil(total_secs / timestep)`` frames.
    """
    audio, _ = librosa.load(wav_path, sr=44100, mono=True)
    _, f0, _ = get_pitch_parselmouth(audio, 512, 44100)
    midi_curve = librosa.hz_to_midi(f0)

    n_frames = midi_curve.shape[0]
    required = total_secs / timestep
    if n_frames < required:
        # The audio is shorter than the labels: repeat the final pitch value
        # at the tail so every labelled note has frames to look at.
        missing = math.ceil(required) - n_frames
        midi_curve = np.pad(
            midi_curve, [0, missing], mode='constant', constant_values=[0, midi_curve[-1]]
        )
    return midi_curve
|
|
|
|
|
|
|
|
def correct_cents_item(
        name: str, item: OrderedDict, ref_pitch: np.ndarray,
        timestep: float, error_ratio: float
):
    """Rewrite ``item['note_seq']`` in place with cents-corrected note labels.

    Notes are laid out back to back starting at time 0 using the durations in
    ``item['note_dur']``. For every non-rest note, the frames of ``ref_pitch``
    spanning the note are collected; those lying within [-0.5, +0.5) semitones
    of the labelled pitch are averaged, and the mean is rendered back to a note
    label with ``librosa.midi_to_note(..., cents=True, unicode=False)``.

    A note keeps its original label when it has no frames at all or when none
    of its frames is within half a semitone of the label. ``'rest'`` entries
    are always copied through unchanged.

    Args:
        name: Identifier stored as ``'position'`` in warning records.
        item: Mapping holding whitespace-separated ``'note_seq'`` and
            ``'note_dur'`` strings; ``'note_seq'`` is overwritten.
        ref_pitch: Per-frame pitch compared against MIDI note numbers
            (assumed to be a 1-D array in MIDI units - TODO confirm with caller).
        timestep: Duration of one ``ref_pitch`` frame in seconds.
        error_ratio: If the fraction of a note's frames within 50 cents of its
            label is below this value (or the note has no frames), a record is
            appended to the module-level ``warns`` list.

    Raises:
        AssertionError: If ``note_seq`` and ``note_dur`` differ in length.
    """
    labels = item['note_seq'].split()
    durations = [float(tok) for tok in item['note_dur'].split()]
    assert len(labels) == len(durations), \
        f'note_seq and note_dur of {name} have different lengths.'

    corrected = []
    cursor = 0.
    for index, (label, seconds) in enumerate(zip(labels, durations)):
        # `onset` is where this note starts; `cursor` now points at its end,
        # which is also the onset of the next note.
        onset, cursor = cursor, cursor + seconds
        if label == 'rest':
            corrected.append('rest')
            continue

        target = librosa.note_to_midi(label, round_midi=False)
        frames = ref_pitch[math.floor(onset / timestep): math.ceil(cursor / timestep)]
        in_tune = frames[(frames >= target - 0.5) & (frames < target + 0.5)]

        if len(frames) == 0 or len(in_tune) < len(frames) * error_ratio:
            warns.append({
                'position': name,
                'note_index': index,
                'note_value': label
            })

        # This check is deliberately independent of the warning condition:
        # with error_ratio <= 0 no warning fires, yet averaging an empty
        # selection would yield NaN, so the label must still be kept as is.
        if len(frames) == 0 or len(in_tune) == 0:
            corrected.append(label)
            continue

        corrected.append(librosa.midi_to_note(np.mean(in_tune), cents=True, unicode=False))

    item['note_seq'] = ' '.join(corrected)
|
|
|
|
|
|
|
|
def save_warnings(save_dir: pathlib.Path):
    """Write the collected warning records to ``warnings.csv`` in ``save_dir``.

    Nothing happens when the module-level ``warns`` list is empty. Otherwise
    the records are written with the columns ``position``, ``note_index`` and
    ``note_value``, a ``UserWarning`` naming the output file is issued, and a
    ``'default'`` warnings filter is installed afterwards.

    Args:
        save_dir: Directory that receives ``warnings.csv``.
    """
    if not warns:
        return

    out_path = save_dir.resolve() / 'warnings.csv'
    with open(out_path, 'w', encoding='utf8', newline='') as csv_file:
        table = DictWriter(csv_file, fieldnames=['position', 'note_index', 'note_value'])
        table.writeheader()
        for record in warns:
            table.writerow(record)

    warnings.warn(
        message=f'possible labeling errors saved in {out_path}',
        category=UserWarning
    )
    warnings.filterwarnings(action='default')
|
|
|
|
|
|
|
|
@click.group(help='Apply cents correction to note sequences')
def correct_cents():
    """Root click group; subcommands attach to it via ``correct_cents.command``."""
|
|
|
|
|
|
|
|
@correct_cents.command(help='Apply cents correction to note sequences in transcriptions.csv')
@click.argument('transcriptions', metavar='TRANSCRIPTIONS')
@click.argument('waveforms', metavar='WAVS')
@click.option('--error_ratio', metavar='RATIO', type=float, default=0.4,
              help='If the percentage of pitch points within a deviation of 50 cents compared to the note label '
                   'is lower than this value, a warning will be raised.')
def csv(
        transcriptions,
        waveforms,
        error_ratio
):
    """Correct the ``note_seq`` column of a transcriptions CSV file in place.

    Every row is matched with ``<WAVS>/<name>.wav``; the pitch extracted from
    that file is used to replace the row's note labels with cents-corrected
    ones. The CSV file is then overwritten and any collected warnings are
    saved next to it.

    Args:
        transcriptions: Path of the CSV file. Rows must provide ``name``,
            ``note_seq`` and ``note_dur``.
        waveforms: Directory containing one ``.wav`` file per row name.
        error_ratio: Threshold forwarded to ``correct_cents_item``.
    """
    transcriptions = pathlib.Path(transcriptions).resolve()
    waveforms = pathlib.Path(waveforms).resolve()

    # newline='' is what the csv module asks for, so that line breaks embedded
    # in quoted fields are handed to the parser untouched.
    with open(transcriptions, 'r', encoding='utf8', newline='') as f:
        reader = DictReader(f)
        items = [OrderedDict(row) for row in reader]
        source_columns = list(reader.fieldnames or [])

    # Keep the six standard columns first (so files that only have those are
    # written exactly as before) and carry any additional input columns along.
    # Without this, DictWriter rejects rows with unknown keys - and it would do
    # so only after the file had already been truncated for writing.
    fieldnames = ['name', 'ph_seq', 'ph_dur', 'ph_num', 'note_seq', 'note_dur']
    fieldnames += [column for column in source_columns if column not in fieldnames]

    timestep = 512 / 44100
    for item in tqdm.tqdm(items):
        ref_pitch = get_aligned_pitch(
            wav_path=waveforms / (item['name'] + '.wav'),
            total_secs=sum(float(d) for d in item['note_dur'].split()),
            timestep=timestep
        )
        correct_cents_item(
            name=item['name'], item=item, ref_pitch=ref_pitch,
            timestep=timestep, error_ratio=error_ratio
        )

    with open(transcriptions, 'w', encoding='utf8', newline='') as f:
        writer = DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(items)
    save_warnings(transcriptions.parent)
|
|
|
|
|
|
|
|
@correct_cents.command(help='Apply cents correction to note sequences in DS files')
@click.argument('ds_dir', metavar='DS_DIR')
@click.option('--error_ratio', metavar='RATIO', type=float, default=0.4,
              help='If the percentage of pitch points within a deviation of 50 cents compared to the note label '
                   'is lower than this value, a warning will be raised.')
def ds(
        ds_dir,
        error_ratio
):
    """Correct the ``note_seq`` of every ``*.ds`` file in a directory, in place.

    Each DS file is JSON holding either one segment object or a list of them;
    every segment must provide ``offset``, ``note_seq`` and ``note_dur``. The
    pitch is extracted once from the ``.wav`` file sharing the DS file's stem
    and sliced per segment by its offset and total note duration. Files are
    rewritten as JSON (always as a list of segments) and collected warnings are
    saved in ``DS_DIR``. Files containing no segments are left untouched.

    Args:
        ds_dir: Directory containing the ``.ds`` files and their ``.wav`` files.
        error_ratio: Threshold forwarded to ``correct_cents_item``.

    Raises:
        AssertionError: If the directory or a matching ``.wav`` file is missing.
    """
    ds_dir = pathlib.Path(ds_dir).resolve()
    assert ds_dir.exists(), 'The directory of DS files does not exist.'

    timestep = 512 / 44100
    # Snapshot the listing first: it gives the progress bar a total and avoids
    # iterating the directory while the files inside it are being rewritten.
    ds_files = [path for path in ds_dir.glob('*.ds') if path.is_file()]
    for ds_file in tqdm.tqdm(ds_files):
        wav_file = ds_file.with_suffix('.wav')
        assert wav_file.exists(), \
            f'Missing corresponding .wav file of {ds_file.name}.'
        with open(ds_file, 'r', encoding='utf8') as f:
            params = json.load(f)
        if not isinstance(params, list):
            params = [params]
        params = [OrderedDict(p) for p in params]
        if not params:
            # An empty segment list has nothing to correct; skipping it keeps
            # one degenerate file from aborting the whole batch.
            continue

        # The labelled length of the recording is taken from the last segment.
        last = params[-1]
        ref_pitch = get_aligned_pitch(
            wav_path=wav_file,
            total_secs=last['offset'] + sum(float(d) for d in last['note_dur'].split()),
            timestep=timestep
        )
        for i, param in enumerate(params):
            seg_end = param['offset'] + sum(float(d) for d in param['note_dur'].split())
            start_idx = math.floor(param['offset'] / timestep)
            end_idx = math.ceil(seg_end / timestep)
            # The slice starts at the segment's offset, so the note times inside
            # correct_cents_item are relative to the start of the segment.
            correct_cents_item(
                name=f'{ds_file.stem}#{i}', item=param, ref_pitch=ref_pitch[start_idx: end_idx],
                timestep=timestep, error_ratio=error_ratio
            )

        with open(ds_file, 'w', encoding='utf8') as f:
            json.dump(params, f, ensure_ascii=False, indent=2)
    save_warnings(ds_dir)
|
|
|
|
|
|
|
|
# Run the click command group only when this file is executed as a script.
if __name__ == '__main__':
    correct_cents()
|
|
|