""" utils.py — VAD-based audio segmentation & dataset utilities. Public API ---------- TARGET_SR : int Vad : VAD wrapper AudioDataset : torch Dataset (VAD-segmented chunks) AudioSplitter : context manager — splits large files into M4A chunks collate : DataLoader collate function get_segments : (wav, vad) → list[(start_sample, end_sample)] preprocess_audio: (wav, sr) → mono 16 kHz tensor """ from __future__ import annotations import tempfile from pathlib import Path import _webrtcvad import ffmpeg import numpy as np import torch import torchaudio from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset __all__ = [ "TARGET_SR", "Vad", "AudioDataset", "AudioSplitter", "collate", "get_segments", "preprocess_audio", ] # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- TARGET_SR = 16_000 _VAD_FRAME_LEN = 480 # 30 ms @ 16 kHz — WebRTC requirement _SMOOTH_WIN = 5 _COST_W_SPEECH = 0.7 _COST_W_ENERGY = 0.3 _MIN_SEG_S = 8.0 _DEFAULT_MAX_S = 18.0 _DEFAULT_VAD_MODE = 3 # --------------------------------------------------------------------------- # VAD # --------------------------------------------------------------------------- class Vad: """Thin, stateless-feeling wrapper around native WebRTC VAD.""" def __init__(self, mode: int = _DEFAULT_VAD_MODE) -> None: self._vad = _webrtcvad.create() _webrtcvad.init(self._vad) self.set_mode(mode) def set_mode(self, mode: int) -> None: _webrtcvad.set_mode(self._vad, mode) def is_speech(self, buf: bytes, sr: int = TARGET_SR) -> bool: return _webrtcvad.process(self._vad, sr, buf, len(buf) // 2) # --------------------------------------------------------------------------- # Audio helpers # --------------------------------------------------------------------------- def preprocess_audio(wav: torch.Tensor, sr: int) -> torch.Tensor: """Resample → mono → contiguous 1-D float tensor @ TARGET_SR.""" if sr != TARGET_SR: wav = torchaudio.functional.resample(wav, sr, TARGET_SR) if wav.ndim > 1: wav = wav.mean(0) wav = wav.squeeze() return (wav if wav.ndim == 1 else wav.unsqueeze(0)).contiguous() def _frame_features( w16: np.ndarray, vad: Vad, n: int, ) -> tuple[np.ndarray, np.ndarray]: speech = np.zeros(n) energy = np.zeros(n) for i in range(n): frame = w16[i * _VAD_FRAME_LEN : (i + 1) * _VAD_FRAME_LEN] try: speech[i] = vad.is_speech(frame.tobytes()) except Exception: pass energy[i] = np.sqrt(np.mean(frame.astype(np.float32) ** 2) + 1e-8) mx = energy.max() if mx >= 1e-5: energy /= mx return speech, energy def _cost_signal(speech: np.ndarray, energy: np.ndarray) -> np.ndarray: smooth = np.convolve(speech, np.ones(_SMOOTH_WIN) / _SMOOTH_WIN, mode="same") return _COST_W_SPEECH * smooth + _COST_W_ENERGY * energy def get_segments( wav: torch.Tensor, vad: Vad, max_s: float = _DEFAULT_MAX_S, ) -> list[tuple[int, int]]: """ Segment *wav* into [(start_sample, end_sample), ...] by VAD cost minima. Always returns at least one segment covering the full waveform. """ w16 = (wav.clamp(-1, 1) * 32_767).to(torch.int16).cpu().numpy() n_frames = len(w16) // _VAD_FRAME_LEN total = len(wav) if n_frames == 0: return [(0, total)] speech, energy = _frame_features(w16, vad, n_frames) if energy.max() < 1e-5: # silent file return [(0, total)] cost = _cost_signal(speech, energy) min_frames = int(_MIN_SEG_S / 0.03) max_frames = int(max_s * TARGET_SR / _VAD_FRAME_LEN) segments: list[tuple[int, int]] = [] s = 0 while s < n_frames: if n_frames - s <= max_frames: start, end = s * _VAD_FRAME_LEN, total if start < end: segments.append((start, end)) break lo = s + min_frames hi = min(s + max_frames, n_frames) cut = (lo + int(np.argmin(cost[lo:hi]))) if lo < hi else hi cut = max(cut, s + 1) start, end = s * _VAD_FRAME_LEN, min(cut * _VAD_FRAME_LEN, total) if start < end: segments.append((start, end)) s = cut return segments or [(0, total)] # --------------------------------------------------------------------------- # Dataset # --------------------------------------------------------------------------- class AudioDataset(Dataset): """ VAD-segmented waveform as a torch Dataset. Each item: (segment_tensor, {"start_s": float, "end_s": float}) Timestamps are LOCAL (relative to the start of *wav*). """ def __init__( self, wav: torch.Tensor, sr: int, vad_mode: int = _DEFAULT_VAD_MODE, max_s: float = _DEFAULT_MAX_S, ) -> None: self.wav = preprocess_audio(wav, sr) self.sr = TARGET_SR # Each worker needs its own Vad instance (not thread-safe to share) self.ts = get_segments(self.wav, Vad(vad_mode), max_s) # ------------------------------------------------------------------ def __len__(self) -> int: return len(self.ts) def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict]: s, e = self.ts[idx] segment = self.wav[s:e] if segment.numel() == 0: segment = self.wav[:1] return segment, {"start_s": s / self.sr, "end_s": e / self.sr} # Convenience: safe first / last (handles single-segment edge case) @property def first(self) -> tuple[torch.Tensor, dict]: return self[0] @property def last(self) -> tuple[torch.Tensor, dict]: return self[len(self) - 1] @property def is_single_segment(self) -> bool: return len(self) == 1 # --------------------------------------------------------------------------- # Collate # --------------------------------------------------------------------------- def collate(batch: list) -> tuple[torch.Tensor, torch.Tensor, dict]: """Pad variable-length segments; stack metadata tensors.""" audio, metas = zip(*batch) lengths = torch.tensor([x.numel() for x in audio]) padded = pad_sequence(audio, batch_first=True) meta = { "start_s": torch.tensor([m["start_s"] for m in metas]), "end_s": torch.tensor([m["end_s"] for m in metas]), } return padded, lengths, meta # --------------------------------------------------------------------------- # AudioSplitter # --------------------------------------------------------------------------- class AudioSplitter: """ Split a large audio/video file into fixed-duration M4A chunks via FFmpeg. Usage ----- with AudioSplitter("long.mp4", split_min=10) as splitter: for path in splitter.chunks: # sorted list[Path] wav, sr = torchaudio.load(path) ... """ def __init__(self, path: str | Path, split_min: float) -> None: self.path = Path(path) self.split_sec = int(split_min * 60) self.chunks: list[Path] = [] self._tmp = None if not self.path.exists(): raise FileNotFoundError(self.path) def __enter__(self) -> "AudioSplitter": self._tmp = tempfile.TemporaryDirectory() work = Path(self._tmp.name) pattern = str(work / f"{self.path.stem}_chunk_%03d.m4a") try: ( ffmpeg .input(str(self.path)) .output( pattern, format = "segment", segment_time = self.split_sec, reset_timestamps = 1, map = "0:a", acodec = "aac", ar = TARGET_SR, ac = 1, ) .overwrite_output() .run(quiet=True) ) except ffmpeg.Error as e: # Decode stderr for a useful error message stderr = e.stderr.decode(errors="replace") if e.stderr else "(no stderr)" raise RuntimeError( f"FFmpeg failed while splitting '{self.path}':\n{stderr}" ) from e self.chunks = sorted(work.glob("*_chunk_*.m4a")) if not self.chunks: raise RuntimeError(f"FFmpeg produced no chunks from {self.path}") return self def __exit__(self, *_) -> None: if self._tmp: self._tmp.cleanup()