# lyric-sync / lyric_sync/refine.py
# Uploaded by rikhoffbauer2 ("Upload lyric_sync/refine.py"), commit 1be2d0f (verified)
"""
Audio-based timing refinement using onset/offset detection.
Refines coarse word timestamps (from ASR alignment) to sub-10ms precision
using signal-domain analysis of the vocals waveform:
1. Onset detection (spectral flux + librosa ODF) → snap word starts
2. RMS energy envelope → find word ends (energy decay)
3. Silence gap detection → refine inter-word boundaries
4. Sanity constraints (minimum duration, no overlaps)
Reference: Standard MIR onset detection (librosa) combined with
forced-alignment-specific refinement heuristics.
"""
import logging
from typing import Optional
import numpy as np
from lyric_sync.transcribe import TimedWord
logger = logging.getLogger(__name__)
class TimingRefiner:
    """
    Refine word-level timestamps using audio signal analysis.

    Operates on the isolated vocals waveform (post-Demucs separation).
    Expects mono float32 audio at 44100 Hz for maximum temporal precision.

    Pipeline (see ``refine``):
      1. snap each word start to the nearest detected onset,
      2. move each word end to the first energy drop near the ASR end,
      3. enforce a minimum word duration,
      4. snap word boundaries to nearby silence gaps,
      5. split any remaining overlap between consecutive words.
    """

    def __init__(
        self,
        sr: int = 44100,
        hop_length: int = 256,
        onset_search_window_sec: float = 0.08,
        offset_search_window_sec: float = 0.05,
        silence_threshold_db: float = -45.0,
        min_word_duration_sec: float = 0.03,
        fmin: float = 80.0,
        fmax: float = 4000.0,
    ):
        """
        Args:
            sr: Sample rate of input audio (44100 recommended for precision)
            hop_length: STFT hop length. 256 at 44100Hz → 5.8ms frame resolution.
            onset_search_window_sec: Search window for onset snapping (±this around ASR time)
            offset_search_window_sec: Search window for end-of-word refinement
            silence_threshold_db: dB below peak RMS to consider "silence"
            min_word_duration_sec: Minimum allowed word duration
            fmin: Lowest frequency for vocal onset detection (Hz)
            fmax: Highest frequency for vocal onset detection (Hz)
        """
        self.sr = sr
        self.hop_length = hop_length
        self.onset_search_window_sec = onset_search_window_sec
        self.offset_search_window_sec = offset_search_window_sec
        self.silence_threshold_db = silence_threshold_db
        self.min_word_duration_sec = min_word_duration_sec
        self.fmin = fmin
        self.fmax = fmax

    def refine(
        self,
        vocals: np.ndarray,
        words: list[TimedWord],
    ) -> list[TimedWord]:
        """
        Refine all word timestamps using audio analysis.

        Args:
            vocals: Mono float32 numpy array at self.sr Hz
            words: Words with coarse timestamps from alignment

        Returns:
            New TimedWord objects with refined timestamps; the input words are
            not modified. If ``vocals`` or ``words`` is empty, ``words`` is
            returned as-is.
        """
        if len(vocals) == 0 or not words:
            return words

        # Pre-compute analysis signals once for all words.
        odf = self._compute_onset_envelope(vocals)
        rms = self._compute_rms_envelope(vocals)
        rms_smooth = self._smooth(rms, window_size=7)
        rms_smooth_db = self._to_relative_db(rms_smooth)
        silence_gaps = self._detect_silence_gaps(rms)
        onset_frames = self._detect_onsets(odf)

        logger.info(
            "Timing refinement: %d onsets, %d silence gaps detected",
            len(onset_frames),
            len(silence_gaps),
        )

        min_dur = self.min_word_duration_sec
        refined = []
        for word in words:
            # Refine start → snap to nearest onset
            start = self._snap_to_onset(word.start, onset_frames, odf)
            # Refine end → find energy drop-off. The search is bounded below by
            # start + min duration so that silence *preceding* a short word is
            # not mistaken for its end.
            end = self._refine_end(
                word.end,
                rms_smooth,
                rms_db=rms_smooth_db,
                not_before=start + min_dur,
            )
            # Sanity: end must be after start with minimum duration
            if end <= start + min_dur:
                end = start + min_dur
            refined.append(
                TimedWord(
                    word=word.word,
                    start=start,
                    end=end,
                    confidence=word.confidence,
                )
            )

        # Silence gap snapping (final pass), then guarantee no overlaps.
        refined = self._snap_to_silence_gaps(refined, silence_gaps)
        return self._resolve_overlaps(refined)

    @staticmethod
    def _to_relative_db(envelope: np.ndarray) -> np.ndarray:
        """Convert an amplitude envelope to dB relative to its own peak."""
        import librosa
        # The 1e-10 offsets keep an all-zero envelope from producing log(0).
        return librosa.amplitude_to_db(envelope + 1e-10, ref=envelope.max() + 1e-10)

    def _compute_onset_envelope(self, y: np.ndarray) -> np.ndarray:
        """Compute an onset strength envelope restricted to the fmin–fmax band."""
        import librosa
        return librosa.onset.onset_strength(
            y=y,
            sr=self.sr,
            hop_length=self.hop_length,
            n_fft=1024,
            fmin=self.fmin,
            fmax=self.fmax,
            aggregate=np.median,
            detrend=True,
            center=True,
        )

    def _compute_rms_envelope(self, y: np.ndarray) -> np.ndarray:
        """Compute RMS energy per frame (1024-sample frames, self.hop_length hop)."""
        import librosa
        return librosa.feature.rms(
            y=y,
            frame_length=1024,
            hop_length=self.hop_length,
            center=True,
        )[0]

    def _detect_onsets(self, odf: np.ndarray) -> np.ndarray:
        """Detect all onsets in the onset envelope; returns frame indices."""
        import librosa
        return librosa.onset.onset_detect(
            onset_envelope=odf,
            sr=self.sr,
            hop_length=self.hop_length,
            backtrack=True,
            units='frames',
            pre_max=2,
            post_max=2,
            pre_avg=2,
            post_avg=4,
            delta=0.05,
            wait=8,
        )

    def _detect_silence_gaps(
        self,
        rms: np.ndarray,
        min_gap_frames: int = 3,
    ) -> list[tuple[float, float]]:
        """
        Detect silence regions in the RMS envelope.

        A frame is silent when its level, relative to the envelope peak, is
        below ``self.silence_threshold_db``. Runs of at least
        ``min_gap_frames`` silent frames are reported, including a run that
        extends to the end of the signal.

        Returns:
            List of (gap_start_sec, gap_end_sec), where gap_end_sec is the
            time of the first non-silent frame (or of the end of the envelope).
        """
        import librosa
        if len(rms) == 0:
            return []
        is_silent = self._to_relative_db(rms) < self.silence_threshold_db
        # Surround the mask with non-silent sentinels so every silent run has
        # both a rising and a falling edge in the diff.
        padded = np.concatenate(([0], is_silent.astype(np.int8), [0]))
        edges = np.flatnonzero(np.diff(padded))
        run_starts, run_ends = edges[0::2], edges[1::2]  # end index is exclusive
        keep = (run_ends - run_starts) >= min_gap_frames
        starts_sec = librosa.frames_to_time(
            run_starts[keep], sr=self.sr, hop_length=self.hop_length
        )
        ends_sec = librosa.frames_to_time(
            run_ends[keep], sr=self.sr, hop_length=self.hop_length
        )
        return [(float(s), float(e)) for s, e in zip(starts_sec, ends_sec)]

    def _snap_to_onset(
        self,
        approx_time: float,
        onset_frames: np.ndarray,
        odf: np.ndarray,
    ) -> float:
        """
        Snap an approximate word-start to the nearest detected onset.

        Only onsets within ±onset_search_window_sec of ``approx_time`` are
        considered; otherwise ``approx_time`` is returned unchanged.
        ``odf`` is currently unused (kept for interface compatibility).
        """
        import librosa
        if len(onset_frames) == 0:
            return approx_time
        approx_frame = librosa.time_to_frames(
            approx_time, sr=self.sr, hop_length=self.hop_length
        )
        window_frames = int(self.onset_search_window_sec * self.sr / self.hop_length)
        in_window = np.abs(onset_frames - approx_frame) <= window_frames
        candidates = onset_frames[in_window]
        if len(candidates) == 0:
            return approx_time
        # Pick the onset nearest to the ASR timestamp
        nearest_frame = candidates[np.argmin(np.abs(candidates - approx_frame))]
        return librosa.frames_to_time(nearest_frame, sr=self.sr, hop_length=self.hop_length)

    def _refine_end(
        self,
        approx_end: float,
        rms_smooth: np.ndarray,
        rms_db: Optional[np.ndarray] = None,
        not_before: Optional[float] = None,
    ) -> float:
        """
        Refine a word end by finding the energy drop-off near ``approx_end``.

        Searches ±offset_search_window_sec around ``approx_end`` for the first
        frame whose level is below ``silence_threshold_db + 5`` dB (relative to
        peak). If none is found, the frame of minimum energy in the window is
        used. If the window is empty, ``approx_end`` is returned unchanged.

        Args:
            approx_end: Coarse end time in seconds.
            rms_smooth: Smoothed RMS envelope.
            rms_db: Optional precomputed ``_to_relative_db(rms_smooth)``; pass
                it when refining many words to avoid recomputing it per call.
            not_before: Optional time (seconds) before which the search window
                may not start, e.g. the word start plus its minimum duration.
        """
        import librosa
        if rms_db is None:
            rms_db = self._to_relative_db(rms_smooth)
        end_frame = librosa.time_to_frames(
            approx_end, sr=self.sr, hop_length=self.hop_length
        )
        search_frames = int(self.offset_search_window_sec * self.sr / self.hop_length)
        lo = max(0, end_frame - search_frames)
        if not_before is not None:
            lo = max(lo, int(np.ceil(not_before * self.sr / self.hop_length)))
        hi = min(len(rms_db) - 1, end_frame + search_frames)
        if lo >= hi:
            return approx_end

        # Find first frame where energy drops significantly
        window_db = rms_db[lo:hi + 1]
        threshold = self.silence_threshold_db + 5  # slightly above full silence
        quiet = np.flatnonzero(window_db < threshold)
        if len(quiet) > 0:
            drop_frame = lo + quiet[0]
            return librosa.frames_to_time(drop_frame, sr=self.sr, hop_length=self.hop_length)

        # No clear drop: use energy minimum in window
        min_frame = lo + np.argmin(rms_smooth[lo:hi + 1])
        return librosa.frames_to_time(min_frame, sr=self.sr, hop_length=self.hop_length)

    def _snap_to_silence_gaps(
        self,
        words: list[TimedWord],
        gaps: list[tuple[float, float]],
        snap_tolerance: float = 0.04,
    ) -> list[TimedWord]:
        """
        Snap word boundaries to nearby silence gaps.

        A word start moves to the nearest gap end (sound resumes), and a word
        end moves to the nearest gap start (sound stops), if that boundary is
        closer than ``snap_tolerance`` seconds. If the snapped word would be
        shorter than ``min_word_duration_sec`` (e.g. a tiny gap lies inside a
        short word), both snaps are discarded for that word.

        Returns:
            New TimedWord objects; the inputs are not modified.
        """
        gap_starts = np.array([g[0] for g in gaps], dtype=float)
        gap_ends = np.array([g[1] for g in gaps], dtype=float)
        refined = []
        for word in words:
            start, end = word.start, word.end
            if len(gaps) > 0:
                i = int(np.argmin(np.abs(gap_ends - start)))
                if abs(gap_ends[i] - start) < snap_tolerance:
                    start = float(gap_ends[i])
                j = int(np.argmin(np.abs(gap_starts - end)))
                if abs(gap_starts[j] - end) < snap_tolerance:
                    end = float(gap_starts[j])
                if end < start + self.min_word_duration_sec:
                    start, end = word.start, word.end
            refined.append(
                TimedWord(
                    word=word.word,
                    start=start,
                    end=end,
                    confidence=word.confidence,
                )
            )
        return refined

    def _resolve_overlaps(self, words: list[TimedWord]) -> list[TimedWord]:
        """
        Ensure no word overlaps with the next, maintaining monotonic order.

        An overlap is split at the midpoint between the current word's end and
        the next word's start. The split point is clamped so that neither word
        ends up with negative duration. Works on a copy of ``words``.
        """
        resolved = list(words)
        # Index loop on purpose: an adjusted word i+1 must be the one compared
        # against word i+2 on the next iteration.
        for i in range(len(resolved) - 1):
            cur, nxt = resolved[i], resolved[i + 1]
            if cur.end <= nxt.start:
                continue
            mid = (cur.end + nxt.start) / 2
            boundary = max(cur.start, min(mid, nxt.end))
            resolved[i] = TimedWord(
                word=cur.word,
                start=cur.start,
                end=boundary,
                confidence=cur.confidence,
            )
            resolved[i + 1] = TimedWord(
                word=nxt.word,
                start=boundary,
                end=max(nxt.end, boundary),
                confidence=nxt.confidence,
            )
        return resolved

    @staticmethod
    def _smooth(arr: np.ndarray, window_size: int = 5) -> np.ndarray:
        """
        Simple uniform (moving-average) smoothing.

        Uses the same window alignment as ``np.convolve(..., mode='same')``.
        It edge-pads instead of zero-padding, so the result keeps ``len(arr)``
        samples and does not taper toward zero at the signal boundaries.
        """
        if window_size <= 1 or len(arr) == 0:
            return arr
        kernel = np.ones(window_size) / window_size
        padded = np.pad(arr, (window_size // 2, (window_size - 1) // 2), mode='edge')
        return np.convolve(padded, kernel, mode='valid')
def refine_timings(
    vocals: np.ndarray,
    sr: int,
    words: list[TimedWord],
    **kwargs,
) -> list[TimedWord]:
    """
    Shortcut that builds a TimingRefiner and runs it once.

    Args:
        vocals: Mono float32 numpy array (ideally at 44100 Hz)
        sr: Sample rate of ``vocals``; forwarded as TimingRefiner's ``sr``
        words: Words with coarse timestamps
        **kwargs: Any other TimingRefiner constructor options

    Returns:
        Result of ``TimingRefiner(sr=sr, **kwargs).refine(vocals, words)``
    """
    return TimingRefiner(sr=sr, **kwargs).refine(vocals, words)