""" Audio-based timing refinement using onset/offset detection. Refines coarse word timestamps (from ASR alignment) to sub-10ms precision using signal-domain analysis of the vocals waveform: 1. Onset detection (spectral flux + librosa ODF) → snap word starts 2. RMS energy envelope → find word ends (energy decay) 3. Silence gap detection → refine inter-word boundaries 4. Sanity constraints (minimum duration, no overlaps) Reference: Standard MIR onset detection (librosa) combined with forced-alignment-specific refinement heuristics. """ import logging from typing import Optional import numpy as np from lyric_sync.transcribe import TimedWord logger = logging.getLogger(__name__) class TimingRefiner: """ Refine word-level timestamps using audio signal analysis. Operates on the isolated vocals waveform (post-Demucs separation). Expects mono float32 audio at 44100 Hz for maximum temporal precision. """ def __init__( self, sr: int = 44100, hop_length: int = 256, onset_search_window_sec: float = 0.08, offset_search_window_sec: float = 0.05, silence_threshold_db: float = -45.0, min_word_duration_sec: float = 0.03, fmin: float = 80.0, fmax: float = 4000.0, ): """ Args: sr: Sample rate of input audio (44100 recommended for precision) hop_length: STFT hop length. 256 at 44100Hz → 5.8ms frame resolution. onset_search_window_sec: Search window for onset snapping (±this around ASR time) offset_search_window_sec: Search window for end-of-word refinement silence_threshold_db: dB below peak RMS to consider "silence" min_word_duration_sec: Minimum allowed word duration fmin: Lowest frequency for vocal onset detection (Hz) fmax: Highest frequency for vocal onset detection (Hz) """ self.sr = sr self.hop_length = hop_length self.onset_search_window_sec = onset_search_window_sec self.offset_search_window_sec = offset_search_window_sec self.silence_threshold_db = silence_threshold_db self.min_word_duration_sec = min_word_duration_sec self.fmin = fmin self.fmax = fmax def refine( self, vocals: np.ndarray, words: list[TimedWord], ) -> list[TimedWord]: """ Refine all word timestamps using audio analysis. Args: vocals: Mono float32 numpy array at self.sr Hz words: Words with coarse timestamps from alignment Returns: Words with refined timestamps """ import librosa if len(vocals) == 0 or not words: return words # Pre-compute analysis signals odf = self._compute_onset_envelope(vocals) rms = self._compute_rms_envelope(vocals) rms_smooth = self._smooth(rms, window_size=7) silence_gaps = self._detect_silence_gaps(rms) onset_frames = self._detect_onsets(odf) logger.info( f"Timing refinement: {len(onset_frames)} onsets, " f"{len(silence_gaps)} silence gaps detected" ) refined = [] for word in words: w = TimedWord( word=word.word, start=word.start, end=word.end, confidence=word.confidence, ) # Refine start → snap to nearest onset w.start = self._snap_to_onset( w.start, onset_frames, odf ) # Refine end → find energy drop-off w.end = self._refine_end(w.end, rms_smooth) # Sanity: end must be after start with minimum duration if w.end <= w.start + self.min_word_duration_sec: w.end = w.start + self.min_word_duration_sec refined.append(w) # Silence gap snapping (final pass) refined = self._snap_to_silence_gaps(refined, silence_gaps) # Ensure no overlaps refined = self._resolve_overlaps(refined) return refined def _compute_onset_envelope(self, y: np.ndarray) -> np.ndarray: """Compute onset strength envelope tuned for vocals.""" import librosa odf = librosa.onset.onset_strength( y=y, sr=self.sr, hop_length=self.hop_length, n_fft=1024, fmin=self.fmin, fmax=self.fmax, aggregate=np.median, detrend=True, center=True, ) return odf def _compute_rms_envelope(self, y: np.ndarray) -> np.ndarray: """Compute RMS energy per frame.""" import librosa rms = librosa.feature.rms( y=y, frame_length=1024, hop_length=self.hop_length, center=True, )[0] return rms def _detect_onsets(self, odf: np.ndarray) -> np.ndarray: """Detect all onsets in the onset envelope.""" import librosa onsets = librosa.onset.onset_detect( onset_envelope=odf, sr=self.sr, hop_length=self.hop_length, backtrack=True, units='frames', pre_max=2, post_max=2, pre_avg=2, post_avg=4, delta=0.05, wait=8, ) return onsets def _detect_silence_gaps( self, rms: np.ndarray, min_gap_frames: int = 3, ) -> list[tuple[float, float]]: """ Detect silence regions in the RMS envelope. Returns list of (gap_start_sec, gap_end_sec). """ import librosa rms_db = librosa.amplitude_to_db(rms + 1e-10, ref=rms.max() + 1e-10) is_silent = rms_db < self.silence_threshold_db gaps = [] in_gap = False gap_start = 0 for i, silent in enumerate(is_silent): if silent and not in_gap: in_gap = True gap_start = i elif not silent and in_gap: if i - gap_start >= min_gap_frames: t_start = librosa.frames_to_time(gap_start, sr=self.sr, hop_length=self.hop_length) t_end = librosa.frames_to_time(i, sr=self.sr, hop_length=self.hop_length) gaps.append((t_start, t_end)) in_gap = False return gaps def _snap_to_onset( self, approx_time: float, onset_frames: np.ndarray, odf: np.ndarray, ) -> float: """Snap an approximate word-start to the nearest detected onset.""" import librosa if len(onset_frames) == 0: return approx_time approx_frame = librosa.time_to_frames( approx_time, sr=self.sr, hop_length=self.hop_length ) window_frames = int(self.onset_search_window_sec * self.sr / self.hop_length) # Find onsets within search window lo = approx_frame - window_frames hi = approx_frame + window_frames candidates = onset_frames[(onset_frames >= lo) & (onset_frames <= hi)] if len(candidates) == 0: return approx_time # Pick the onset nearest to the ASR timestamp nearest_frame = candidates[np.argmin(np.abs(candidates - approx_frame))] return librosa.frames_to_time(nearest_frame, sr=self.sr, hop_length=self.hop_length) def _refine_end(self, approx_end: float, rms_smooth: np.ndarray) -> float: """Refine word end by finding energy drop-off.""" import librosa rms_db = librosa.amplitude_to_db(rms_smooth + 1e-10, ref=rms_smooth.max() + 1e-10) end_frame = librosa.time_to_frames( approx_end, sr=self.sr, hop_length=self.hop_length ) search_frames = int(self.offset_search_window_sec * self.sr / self.hop_length) lo = max(0, end_frame - search_frames) hi = min(len(rms_db) - 1, end_frame + search_frames) if lo >= hi: return approx_end # Find first frame where energy drops significantly window_db = rms_db[lo:hi + 1] threshold = self.silence_threshold_db + 5 # slightly above full silence silent_frames = np.where(window_db < threshold)[0] if len(silent_frames) > 0: # First energy drop in the window drop_frame = lo + silent_frames[0] return librosa.frames_to_time(drop_frame, sr=self.sr, hop_length=self.hop_length) # No clear drop: use energy minimum in window min_frame = lo + np.argmin(rms_smooth[lo:hi + 1]) return librosa.frames_to_time(min_frame, sr=self.sr, hop_length=self.hop_length) def _snap_to_silence_gaps( self, words: list[TimedWord], gaps: list[tuple[float, float]], snap_tolerance: float = 0.04, ) -> list[TimedWord]: """Snap word boundaries to nearby silence gaps.""" refined = [] for word in words: w = TimedWord( word=word.word, start=word.start, end=word.end, confidence=word.confidence, ) for gap_start, gap_end in gaps: # Snap word start to end of gap (sound resumes) if abs(gap_end - w.start) < snap_tolerance: w.start = gap_end # Snap word end to start of gap (sound stops) if abs(gap_start - w.end) < snap_tolerance: w.end = gap_start refined.append(w) return refined def _resolve_overlaps(self, words: list[TimedWord]) -> list[TimedWord]: """Ensure no word overlaps with the next, maintaining monotonic order.""" for i in range(len(words) - 1): if words[i].end > words[i + 1].start: # Split the overlap at the midpoint mid = (words[i].end + words[i + 1].start) / 2 words[i] = TimedWord( word=words[i].word, start=words[i].start, end=mid, confidence=words[i].confidence, ) words[i + 1] = TimedWord( word=words[i + 1].word, start=mid, end=words[i + 1].end, confidence=words[i + 1].confidence, ) return words @staticmethod def _smooth(arr: np.ndarray, window_size: int = 5) -> np.ndarray: """Simple uniform smoothing.""" if window_size <= 1: return arr kernel = np.ones(window_size) / window_size return np.convolve(arr, kernel, mode='same') def refine_timings( vocals: np.ndarray, sr: int, words: list[TimedWord], **kwargs, ) -> list[TimedWord]: """ Convenience function: refine word timestamps using audio analysis. Args: vocals: Mono float32 numpy array (ideally at 44100 Hz) sr: Sample rate words: Words with coarse timestamps **kwargs: Additional args for TimingRefiner Returns: Words with refined timestamps """ refiner = TimingRefiner(sr=sr, **kwargs) return refiner.refine(vocals, words)