File size: 6,217 Bytes
79cf5f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import json
import math
import warnings
from collections import OrderedDict

import librosa
import numpy as np
import tqdm
import pathlib
from csv import DictReader, DictWriter

import click

from get_pitch import get_pitch_parselmouth

# Module-level accumulator of possible labeling errors found while correcting notes.
# Each entry is a dict with 'position', 'note_index' and 'note_value' keys; the list
# is written to warnings.csv by save_warnings().
warns = []


def get_aligned_pitch(wav_path: pathlib.Path, total_secs: float, timestep: float):
    """
    Extract a MIDI-scale pitch curve from a waveform, padded to cover ``total_secs``.

    The audio is loaded as mono at 44.1 kHz and passed to ``get_pitch_parselmouth``
    with a hop of 512 samples. The resulting f0 (Hz) is converted to MIDI numbers.
    If the curve has fewer than ``total_secs / timestep`` frames, it is extended to
    ``ceil(total_secs / timestep)`` frames by repeating its last value. A longer
    curve is returned as-is (it is never truncated).

    NOTE(review): hop size (512) and sample rate (44100) are fixed here, so
    ``timestep`` only affects the length check; callers in this file pass 512 / 44100.

    :param wav_path: path of the audio file to analyse
    :param total_secs: minimum duration, in seconds, that the curve must cover
    :param timestep: seconds per pitch frame, used to convert ``total_secs`` to frames
    :return: array of MIDI pitch values, one per frame
    """
    audio, _ = librosa.load(wav_path, sr=44100, mono=True)
    _, f0_hz, _ = get_pitch_parselmouth(audio, 512, 44100)
    midi_curve = librosa.hz_to_midi(f0_hz)

    frames_needed = total_secs / timestep
    if midi_curve.shape[0] < frames_needed:
        # Append copies of the final pitch value; nothing is prepended (leading width 0).
        shortfall = math.ceil(frames_needed) - midi_curve.shape[0]
        midi_curve = np.pad(
            midi_curve, [0, shortfall], mode='constant', constant_values=[0, midi_curve[-1]]
        )
    return midi_curve


def correct_cents_item(
        name: str, item: OrderedDict, ref_pitch: np.ndarray,
        timestep: float, error_ratio: float
):
    """
    Replace every non-rest note in ``item['note_seq']`` with a cents-accurate note name.

    Note boundaries come from the cumulative ``item['note_dur']`` values (seconds),
    mapped to frame indices of ``ref_pitch`` via ``timestep``. The window is widened
    outward (floor of the start, ceil of the end). For each note, the frames within
    [-50, +50) cents of the labelled note are averaged. The average is written back
    as a note name with a cents suffix.

    An entry is appended to the module-level ``warns`` list when the note covers no
    frames, or when fewer than ``error_ratio`` of its frames are in range. If no
    in-range frame exists, the original label is kept unchanged.

    :param name: identifier recorded in warning entries
    :param item: mapping with space-separated ``note_seq`` and ``note_dur`` strings;
                 ``item['note_seq']`` is modified in place
    :param ref_pitch: pitch curve in MIDI numbers, one value per ``timestep`` seconds;
                      index 0 is treated as the start of ``item``
    :param timestep: duration of one pitch frame in seconds
    :param error_ratio: minimum fraction of in-range frames below which a warning is recorded
    :raises ValueError: if ``note_seq`` and ``note_dur`` contain different numbers of entries
    """
    note_seq = item['note_seq'].split()
    note_dur = [float(d) for d in item['note_dur'].split()]
    # An explicit raise instead of ``assert``: assertions are stripped under ``python -O``,
    # after which zip() would silently truncate the longer sequence and corrupt the labels.
    if len(note_seq) != len(note_dur):
        raise ValueError(
            f'{name}: note_seq has {len(note_seq)} entries but note_dur has {len(note_dur)}.'
        )

    start = 0.
    note_seq_correct = []
    for i, (note, dur) in enumerate(zip(note_seq, note_dur)):
        end = start + dur
        corrected = note  # rests and notes without usable pitch keep their label
        if note != 'rest':
            midi = librosa.note_to_midi(note, round_midi=False)
            note_pitch = ref_pitch[math.floor(start / timestep): math.ceil(end / timestep)]
            note_pitch_close = note_pitch[(note_pitch >= midi - 0.5) & (note_pitch < midi + 0.5)]
            if len(note_pitch) == 0 or len(note_pitch_close) < len(note_pitch) * error_ratio:
                warns.append({
                    'position': name,
                    'note_index': i,
                    'note_value': note
                })
            # Only re-label when at least one frame agrees with the labelled note.
            if len(note_pitch_close) > 0:
                corrected = librosa.midi_to_note(
                    np.mean(note_pitch_close), cents=True, unicode=False
                )
        note_seq_correct.append(corrected)
        start = end

    item['note_seq'] = ' '.join(note_seq_correct)


def save_warnings(save_dir: pathlib.Path):
    """
    Write the collected ``warns`` entries to ``<save_dir>/warnings.csv`` and notify the user.

    Does nothing when no warnings were collected. Otherwise the CSV is (over)written
    with the columns ``position``, ``note_index`` and ``note_value``, and a
    ``UserWarning`` pointing at the file is emitted.

    :param save_dir: directory in which ``warnings.csv`` is created (resolved to an absolute path)
    """
    if not warns:
        return

    report_path = save_dir.resolve() / 'warnings.csv'
    with open(report_path, 'w', encoding='utf8', newline='') as report:
        csv_writer = DictWriter(report, fieldnames=['position', 'note_index', 'note_value'])
        csv_writer.writeheader()
        csv_writer.writerows(warns)

    warnings.warn(
        message=f'possible labeling errors saved in {report_path}',
        category=UserWarning
    )
    # NOTE(review): this filter is installed after the warning above has already been
    # issued, so it cannot affect that warning; it only alters the process-wide filter
    # list for any later warnings.
    warnings.filterwarnings(action='default')


@click.group(help='Apply cents correction to note sequences')
def correct_cents():
    """Root click group; the work is done by subcommands registered via ``@correct_cents.command``."""


@correct_cents.command(help='Apply cents correction to note sequences in transcriptions.csv')
@click.argument('transcriptions', metavar='TRANSCRIPTIONS')
@click.argument('waveforms', metavar='WAVS')
@click.option('--error_ratio', metavar='RATIO', type=float, default=0.4,
              help='If the percentage of pitch points within a deviation of 50 cents compared to the note label '
                   'is lower than this value, a warning will be raised.')
def csv(
        transcriptions,
        waveforms,
        error_ratio
):
    """
    Correct the cents of every note in a transcriptions CSV file, rewriting it in place.

    Each row's pitch curve is extracted from ``<WAVS>/<name>.wav``. Possible labeling
    errors are written to ``warnings.csv`` in the directory of the CSV file.
    The standard columns are written first, in a fixed order. Any additional columns
    present in the input are kept and appended after them.
    """
    standard_fields = ['name', 'ph_seq', 'ph_dur', 'ph_num', 'note_seq', 'note_dur']
    transcriptions = pathlib.Path(transcriptions).resolve()
    waveforms = pathlib.Path(waveforms).resolve()
    with open(transcriptions, 'r', encoding='utf8') as f:
        reader = DictReader(f)
        items: list[OrderedDict] = [OrderedDict(item) for item in reader]
        # Keep extra input columns. With a hard-coded header, DictWriter raises ValueError
        # on unknown keys -- after open(..., 'w') below has already truncated the file.
        fieldnames = standard_fields + [
            k for k in (reader.fieldnames or []) if k not in standard_fields
        ]

    timestep = 512 / 44100
    for item in tqdm.tqdm(items):
        ref_pitch = get_aligned_pitch(
            wav_path=waveforms / (item['name'] + '.wav'),
            total_secs=sum(float(d) for d in item['note_dur'].split()),
            timestep=timestep
        )
        correct_cents_item(
            name=item['name'], item=item, ref_pitch=ref_pitch,
            timestep=timestep, error_ratio=error_ratio
        )

    with open(transcriptions, 'w', encoding='utf8', newline='') as f:
        writer = DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(items)
    save_warnings(transcriptions.parent)


@correct_cents.command(help='Apply cents correction to note sequences in DS files')
@click.argument('ds_dir', metavar='DS_DIR')
@click.option('--error_ratio', metavar='RATIO', type=float, default=0.4,
              help='If the percentage of pitch points within a deviation of 50 cents compared to the note label '
                   'is lower than this value, a warning will be raised.')
def ds(
        ds_dir,
        error_ratio
):
    """
    Correct the cents of every note in each ``*.ds`` file of DS_DIR, rewriting the files in place.

    Each ``<stem>.ds`` needs a sibling ``<stem>.wav``. A DS file may hold a single
    segment object or a list of them; it is always written back as a list. Files
    with an empty segment list are skipped. Possible labeling errors are written to
    ``warnings.csv`` in DS_DIR.

    :raises FileNotFoundError: if DS_DIR or a required ``.wav`` file does not exist
    """
    ds_dir = pathlib.Path(ds_dir).resolve()
    # Explicit raises instead of asserts, which are stripped under ``python -O``.
    if not ds_dir.exists():
        raise FileNotFoundError('The directory of DS files does not exist.')

    timestep = 512 / 44100
    for ds_file in tqdm.tqdm(ds_dir.glob('*.ds')):
        if not ds_file.is_file():
            continue

        wav_file = ds_file.with_suffix('.wav')
        if not wav_file.exists():
            raise FileNotFoundError(f'Missing corresponding .wav file of {ds_file.name}.')
        with open(ds_file, 'r', encoding='utf8') as f:
            params = json.load(f)
        if not isinstance(params, list):
            params = [params]
        params = [OrderedDict(p) for p in params]
        if not params:
            # Nothing to correct, and no segment to size the pitch curve from.
            continue

        # (start, end) of every segment, in seconds relative to the start of the wav.
        seg_bounds = [
            (p['offset'], p['offset'] + sum(float(d) for d in p['note_dur'].split()))
            for p in params
        ]
        # Size the curve by the latest segment end so unsorted segments are fully covered.
        ref_pitch = get_aligned_pitch(
            wav_path=wav_file,
            total_secs=max(seg_end for _, seg_end in seg_bounds),
            timestep=timestep
        )
        for i, (param, (seg_start, seg_end)) in enumerate(zip(params, seg_bounds)):
            # NOTE(review): flooring drops the fractional frame of the offset, so index 0
            # of the slice can be up to one frame earlier than the segment offset.
            start_idx = math.floor(seg_start / timestep)
            end_idx = math.ceil(seg_end / timestep)
            correct_cents_item(
                name=f'{ds_file.stem}#{i}', item=param, ref_pitch=ref_pitch[start_idx: end_idx],
                timestep=timestep, error_ratio=error_ratio
            )

        with open(ds_file, 'w', encoding='utf8') as f:
            json.dump(params, f, ensure_ascii=False, indent=2)
    save_warnings(ds_dir)


if __name__ == '__main__':
    # Script entry point: dispatch to the click group (subcommands: csv, ds).
    correct_cents()