rikhoffbauer2 commited on
Commit
1be2d0f
·
verified ·
1 Parent(s): 63c7ca6

Upload lyric_sync/refine.py

Browse files
Files changed (1) hide show
  1. lyric_sync/refine.py +340 -0
lyric_sync/refine.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio-based timing refinement using onset/offset detection.
3
+
4
+ Refines coarse word timestamps (from ASR alignment) to sub-10ms precision
5
+ using signal-domain analysis of the vocals waveform:
6
+
7
+ 1. Onset detection (spectral flux + librosa ODF) → snap word starts
8
+ 2. RMS energy envelope → find word ends (energy decay)
9
+ 3. Silence gap detection → refine inter-word boundaries
10
+ 4. Sanity constraints (minimum duration, no overlaps)
11
+
12
+ Reference: Standard MIR onset detection (librosa) combined with
13
+ forced-alignment-specific refinement heuristics.
14
+ """
15
+
16
+ import logging
17
+ from typing import Optional
18
+
19
+ import numpy as np
20
+
21
+ from lyric_sync.transcribe import TimedWord
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class TimingRefiner:
27
+ """
28
+ Refine word-level timestamps using audio signal analysis.
29
+
30
+ Operates on the isolated vocals waveform (post-Demucs separation).
31
+ Expects mono float32 audio at 44100 Hz for maximum temporal precision.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ sr: int = 44100,
37
+ hop_length: int = 256,
38
+ onset_search_window_sec: float = 0.08,
39
+ offset_search_window_sec: float = 0.05,
40
+ silence_threshold_db: float = -45.0,
41
+ min_word_duration_sec: float = 0.03,
42
+ fmin: float = 80.0,
43
+ fmax: float = 4000.0,
44
+ ):
45
+ """
46
+ Args:
47
+ sr: Sample rate of input audio (44100 recommended for precision)
48
+ hop_length: STFT hop length. 256 at 44100Hz → 5.8ms frame resolution.
49
+ onset_search_window_sec: Search window for onset snapping (±this around ASR time)
50
+ offset_search_window_sec: Search window for end-of-word refinement
51
+ silence_threshold_db: dB below peak RMS to consider "silence"
52
+ min_word_duration_sec: Minimum allowed word duration
53
+ fmin: Lowest frequency for vocal onset detection (Hz)
54
+ fmax: Highest frequency for vocal onset detection (Hz)
55
+ """
56
+ self.sr = sr
57
+ self.hop_length = hop_length
58
+ self.onset_search_window_sec = onset_search_window_sec
59
+ self.offset_search_window_sec = offset_search_window_sec
60
+ self.silence_threshold_db = silence_threshold_db
61
+ self.min_word_duration_sec = min_word_duration_sec
62
+ self.fmin = fmin
63
+ self.fmax = fmax
64
+
65
+ def refine(
66
+ self,
67
+ vocals: np.ndarray,
68
+ words: list[TimedWord],
69
+ ) -> list[TimedWord]:
70
+ """
71
+ Refine all word timestamps using audio analysis.
72
+
73
+ Args:
74
+ vocals: Mono float32 numpy array at self.sr Hz
75
+ words: Words with coarse timestamps from alignment
76
+
77
+ Returns:
78
+ Words with refined timestamps
79
+ """
80
+ import librosa
81
+
82
+ if len(vocals) == 0 or not words:
83
+ return words
84
+
85
+ # Pre-compute analysis signals
86
+ odf = self._compute_onset_envelope(vocals)
87
+ rms = self._compute_rms_envelope(vocals)
88
+ rms_smooth = self._smooth(rms, window_size=7)
89
+ silence_gaps = self._detect_silence_gaps(rms)
90
+ onset_frames = self._detect_onsets(odf)
91
+
92
+ logger.info(
93
+ f"Timing refinement: {len(onset_frames)} onsets, "
94
+ f"{len(silence_gaps)} silence gaps detected"
95
+ )
96
+
97
+ refined = []
98
+ for word in words:
99
+ w = TimedWord(
100
+ word=word.word,
101
+ start=word.start,
102
+ end=word.end,
103
+ confidence=word.confidence,
104
+ )
105
+
106
+ # Refine start → snap to nearest onset
107
+ w.start = self._snap_to_onset(
108
+ w.start, onset_frames, odf
109
+ )
110
+
111
+ # Refine end → find energy drop-off
112
+ w.end = self._refine_end(w.end, rms_smooth)
113
+
114
+ # Sanity: end must be after start with minimum duration
115
+ if w.end <= w.start + self.min_word_duration_sec:
116
+ w.end = w.start + self.min_word_duration_sec
117
+
118
+ refined.append(w)
119
+
120
+ # Silence gap snapping (final pass)
121
+ refined = self._snap_to_silence_gaps(refined, silence_gaps)
122
+
123
+ # Ensure no overlaps
124
+ refined = self._resolve_overlaps(refined)
125
+
126
+ return refined
127
+
128
+ def _compute_onset_envelope(self, y: np.ndarray) -> np.ndarray:
129
+ """Compute onset strength envelope tuned for vocals."""
130
+ import librosa
131
+
132
+ odf = librosa.onset.onset_strength(
133
+ y=y,
134
+ sr=self.sr,
135
+ hop_length=self.hop_length,
136
+ n_fft=1024,
137
+ fmin=self.fmin,
138
+ fmax=self.fmax,
139
+ aggregate=np.median,
140
+ detrend=True,
141
+ center=True,
142
+ )
143
+ return odf
144
+
145
+ def _compute_rms_envelope(self, y: np.ndarray) -> np.ndarray:
146
+ """Compute RMS energy per frame."""
147
+ import librosa
148
+
149
+ rms = librosa.feature.rms(
150
+ y=y,
151
+ frame_length=1024,
152
+ hop_length=self.hop_length,
153
+ center=True,
154
+ )[0]
155
+ return rms
156
+
157
+ def _detect_onsets(self, odf: np.ndarray) -> np.ndarray:
158
+ """Detect all onsets in the onset envelope."""
159
+ import librosa
160
+
161
+ onsets = librosa.onset.onset_detect(
162
+ onset_envelope=odf,
163
+ sr=self.sr,
164
+ hop_length=self.hop_length,
165
+ backtrack=True,
166
+ units='frames',
167
+ pre_max=2,
168
+ post_max=2,
169
+ pre_avg=2,
170
+ post_avg=4,
171
+ delta=0.05,
172
+ wait=8,
173
+ )
174
+ return onsets
175
+
176
+ def _detect_silence_gaps(
177
+ self,
178
+ rms: np.ndarray,
179
+ min_gap_frames: int = 3,
180
+ ) -> list[tuple[float, float]]:
181
+ """
182
+ Detect silence regions in the RMS envelope.
183
+ Returns list of (gap_start_sec, gap_end_sec).
184
+ """
185
+ import librosa
186
+
187
+ rms_db = librosa.amplitude_to_db(rms + 1e-10, ref=rms.max() + 1e-10)
188
+ is_silent = rms_db < self.silence_threshold_db
189
+
190
+ gaps = []
191
+ in_gap = False
192
+ gap_start = 0
193
+
194
+ for i, silent in enumerate(is_silent):
195
+ if silent and not in_gap:
196
+ in_gap = True
197
+ gap_start = i
198
+ elif not silent and in_gap:
199
+ if i - gap_start >= min_gap_frames:
200
+ t_start = librosa.frames_to_time(gap_start, sr=self.sr, hop_length=self.hop_length)
201
+ t_end = librosa.frames_to_time(i, sr=self.sr, hop_length=self.hop_length)
202
+ gaps.append((t_start, t_end))
203
+ in_gap = False
204
+
205
+ return gaps
206
+
207
+ def _snap_to_onset(
208
+ self,
209
+ approx_time: float,
210
+ onset_frames: np.ndarray,
211
+ odf: np.ndarray,
212
+ ) -> float:
213
+ """Snap an approximate word-start to the nearest detected onset."""
214
+ import librosa
215
+
216
+ if len(onset_frames) == 0:
217
+ return approx_time
218
+
219
+ approx_frame = librosa.time_to_frames(
220
+ approx_time, sr=self.sr, hop_length=self.hop_length
221
+ )
222
+ window_frames = int(self.onset_search_window_sec * self.sr / self.hop_length)
223
+
224
+ # Find onsets within search window
225
+ lo = approx_frame - window_frames
226
+ hi = approx_frame + window_frames
227
+ candidates = onset_frames[(onset_frames >= lo) & (onset_frames <= hi)]
228
+
229
+ if len(candidates) == 0:
230
+ return approx_time
231
+
232
+ # Pick the onset nearest to the ASR timestamp
233
+ nearest_frame = candidates[np.argmin(np.abs(candidates - approx_frame))]
234
+ return librosa.frames_to_time(nearest_frame, sr=self.sr, hop_length=self.hop_length)
235
+
236
+ def _refine_end(self, approx_end: float, rms_smooth: np.ndarray) -> float:
237
+ """Refine word end by finding energy drop-off."""
238
+ import librosa
239
+
240
+ rms_db = librosa.amplitude_to_db(rms_smooth + 1e-10, ref=rms_smooth.max() + 1e-10)
241
+
242
+ end_frame = librosa.time_to_frames(
243
+ approx_end, sr=self.sr, hop_length=self.hop_length
244
+ )
245
+ search_frames = int(self.offset_search_window_sec * self.sr / self.hop_length)
246
+
247
+ lo = max(0, end_frame - search_frames)
248
+ hi = min(len(rms_db) - 1, end_frame + search_frames)
249
+
250
+ if lo >= hi:
251
+ return approx_end
252
+
253
+ # Find first frame where energy drops significantly
254
+ window_db = rms_db[lo:hi + 1]
255
+ threshold = self.silence_threshold_db + 5 # slightly above full silence
256
+
257
+ silent_frames = np.where(window_db < threshold)[0]
258
+ if len(silent_frames) > 0:
259
+ # First energy drop in the window
260
+ drop_frame = lo + silent_frames[0]
261
+ return librosa.frames_to_time(drop_frame, sr=self.sr, hop_length=self.hop_length)
262
+
263
+ # No clear drop: use energy minimum in window
264
+ min_frame = lo + np.argmin(rms_smooth[lo:hi + 1])
265
+ return librosa.frames_to_time(min_frame, sr=self.sr, hop_length=self.hop_length)
266
+
267
+ def _snap_to_silence_gaps(
268
+ self,
269
+ words: list[TimedWord],
270
+ gaps: list[tuple[float, float]],
271
+ snap_tolerance: float = 0.04,
272
+ ) -> list[TimedWord]:
273
+ """Snap word boundaries to nearby silence gaps."""
274
+ refined = []
275
+ for word in words:
276
+ w = TimedWord(
277
+ word=word.word,
278
+ start=word.start,
279
+ end=word.end,
280
+ confidence=word.confidence,
281
+ )
282
+ for gap_start, gap_end in gaps:
283
+ # Snap word start to end of gap (sound resumes)
284
+ if abs(gap_end - w.start) < snap_tolerance:
285
+ w.start = gap_end
286
+ # Snap word end to start of gap (sound stops)
287
+ if abs(gap_start - w.end) < snap_tolerance:
288
+ w.end = gap_start
289
+ refined.append(w)
290
+ return refined
291
+
292
+ def _resolve_overlaps(self, words: list[TimedWord]) -> list[TimedWord]:
293
+ """Ensure no word overlaps with the next, maintaining monotonic order."""
294
+ for i in range(len(words) - 1):
295
+ if words[i].end > words[i + 1].start:
296
+ # Split the overlap at the midpoint
297
+ mid = (words[i].end + words[i + 1].start) / 2
298
+ words[i] = TimedWord(
299
+ word=words[i].word,
300
+ start=words[i].start,
301
+ end=mid,
302
+ confidence=words[i].confidence,
303
+ )
304
+ words[i + 1] = TimedWord(
305
+ word=words[i + 1].word,
306
+ start=mid,
307
+ end=words[i + 1].end,
308
+ confidence=words[i + 1].confidence,
309
+ )
310
+ return words
311
+
312
+ @staticmethod
313
+ def _smooth(arr: np.ndarray, window_size: int = 5) -> np.ndarray:
314
+ """Simple uniform smoothing."""
315
+ if window_size <= 1:
316
+ return arr
317
+ kernel = np.ones(window_size) / window_size
318
+ return np.convolve(arr, kernel, mode='same')
319
+
320
+
321
+ def refine_timings(
322
+ vocals: np.ndarray,
323
+ sr: int,
324
+ words: list[TimedWord],
325
+ **kwargs,
326
+ ) -> list[TimedWord]:
327
+ """
328
+ Convenience function: refine word timestamps using audio analysis.
329
+
330
+ Args:
331
+ vocals: Mono float32 numpy array (ideally at 44100 Hz)
332
+ sr: Sample rate
333
+ words: Words with coarse timestamps
334
+ **kwargs: Additional args for TimingRefiner
335
+
336
+ Returns:
337
+ Words with refined timestamps
338
+ """
339
+ refiner = TimingRefiner(sr=sr, **kwargs)
340
+ return refiner.refine(vocals, words)