"""
utils.py — VAD-based audio segmentation & dataset utilities.

Public API
----------
TARGET_SR        : int
Vad              : VAD wrapper
AudioDataset     : torch Dataset (VAD-segmented chunks)
AudioSplitter    : context manager — splits large files into M4A chunks
collate          : DataLoader collate function
get_segments     : (wav, vad) -> list[(start_sample, end_sample)]
preprocess_audio : (wav, sr) -> mono 16 kHz tensor
"""
from __future__ import annotations

import tempfile
from pathlib import Path

import _webrtcvad
import ffmpeg
import numpy as np
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
| __all__ = [ | |
| "TARGET_SR", | |
| "Vad", | |
| "AudioDataset", | |
| "AudioSplitter", | |
| "collate", | |
| "get_segments", | |
| "preprocess_audio", | |
| ] | |
| # --------------------------------------------------------------------------- | |
| # Constants | |
| # --------------------------------------------------------------------------- | |
| TARGET_SR = 16_000 | |
| _VAD_FRAME_LEN = 480 # 30 ms @ 16 kHz β WebRTC requirement | |
| _SMOOTH_WIN = 5 | |
| _COST_W_SPEECH = 0.7 | |
| _COST_W_ENERGY = 0.3 | |
| _MIN_SEG_S = 8.0 | |
| _DEFAULT_MAX_S = 18.0 | |
| _DEFAULT_VAD_MODE = 3 | |
# ---------------------------------------------------------------------------
# VAD
# ---------------------------------------------------------------------------
class Vad:
    """Thin, stateless-feeling wrapper around native WebRTC VAD."""

    def __init__(self, mode: int = _DEFAULT_VAD_MODE) -> None:
        handle = _webrtcvad.create()
        _webrtcvad.init(handle)
        self._vad = handle
        self.set_mode(mode)

    def set_mode(self, mode: int) -> None:
        """Set WebRTC aggressiveness (0 = least, 3 = most aggressive)."""
        _webrtcvad.set_mode(self._vad, mode)

    def is_speech(self, buf: bytes, sr: int = TARGET_SR) -> bool:
        """Classify *buf* (raw 16-bit PCM) as speech / non-speech."""
        # len(buf) // 2 converts byte count to 16-bit sample count.
        return _webrtcvad.process(self._vad, sr, buf, len(buf) // 2)
# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def preprocess_audio(wav: torch.Tensor, sr: int) -> torch.Tensor:
    """Resample -> mono -> contiguous 1-D float tensor @ TARGET_SR."""
    if sr != TARGET_SR:
        wav = torchaudio.functional.resample(wav, sr, TARGET_SR)
    if wav.ndim > 1:
        wav = wav.mean(0)  # collapse channels to mono
    wav = wav.squeeze()
    # A zero-dim result (single-sample input) is promoted back to 1-D.
    if wav.ndim != 1:
        wav = wav.unsqueeze(0)
    return wav.contiguous()
def _frame_features(
    w16: np.ndarray,
    vad: Vad,
    n: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Per-frame speech flags and peak-normalised RMS energy for *n* frames."""
    speech = np.zeros(n)
    energy = np.zeros(n)
    for i in range(n):
        off = i * _VAD_FRAME_LEN
        frame = w16[off : off + _VAD_FRAME_LEN]
        try:
            speech[i] = vad.is_speech(frame.tobytes())
        except Exception:
            # Best-effort: a frame the native VAD rejects counts as non-speech.
            pass
        energy[i] = np.sqrt(np.mean(frame.astype(np.float32) ** 2) + 1e-8)
    peak = energy.max()
    if peak >= 1e-5:
        energy /= peak
    return speech, energy
def _cost_signal(speech: np.ndarray, energy: np.ndarray) -> np.ndarray:
    """Weighted blend of moving-average-smoothed speech flags and energy."""
    kernel = np.ones(_SMOOTH_WIN) / _SMOOTH_WIN
    smoothed = np.convolve(speech, kernel, mode="same")
    return _COST_W_SPEECH * smoothed + _COST_W_ENERGY * energy
def get_segments(
    wav: torch.Tensor,
    vad: Vad,
    max_s: float = _DEFAULT_MAX_S,
) -> list[tuple[int, int]]:
    """
    Segment *wav* into [(start_sample, end_sample), ...] by VAD cost minima.

    Always returns at least one segment covering the full waveform.
    """
    # 16-bit PCM view for the native VAD.
    w16 = (wav.clamp(-1, 1) * 32_767).to(torch.int16).cpu().numpy()
    total = len(wav)
    n_frames = len(w16) // _VAD_FRAME_LEN
    if n_frames == 0:  # shorter than one 30 ms frame
        return [(0, total)]

    speech, energy = _frame_features(w16, vad, n_frames)
    if energy.max() < 1e-5:  # silent file
        return [(0, total)]

    cost = _cost_signal(speech, energy)
    min_frames = int(_MIN_SEG_S / 0.03)  # 0.03 s == one 480-sample frame
    max_frames = int(max_s * TARGET_SR / _VAD_FRAME_LEN)

    segments: list[tuple[int, int]] = []
    frame = 0
    while frame < n_frames:
        # Remaining tail fits in one segment: emit it and stop.
        if n_frames - frame <= max_frames:
            start = frame * _VAD_FRAME_LEN
            if start < total:
                segments.append((start, total))
            break
        # Otherwise cut at the cheapest frame within [min, max] bounds.
        lo = frame + min_frames
        hi = min(frame + max_frames, n_frames)
        cut = lo + int(np.argmin(cost[lo:hi])) if lo < hi else hi
        cut = max(cut, frame + 1)  # guarantee forward progress
        start = frame * _VAD_FRAME_LEN
        end = min(cut * _VAD_FRAME_LEN, total)
        if start < end:
            segments.append((start, end))
        frame = cut
    return segments or [(0, total)]
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class AudioDataset(Dataset):
    """
    VAD-segmented waveform as a torch Dataset.

    Each item: (segment_tensor, {"start_s": float, "end_s": float})
    Timestamps are LOCAL (relative to the start of *wav*).
    """

    def __init__(
        self,
        wav: torch.Tensor,
        sr: int,
        vad_mode: int = _DEFAULT_VAD_MODE,
        max_s: float = _DEFAULT_MAX_S,
    ) -> None:
        self.wav = preprocess_audio(wav, sr)
        self.sr = TARGET_SR
        # Each worker needs its own Vad instance (not thread-safe to share).
        self.ts = get_segments(self.wav, Vad(vad_mode), max_s)

    # ------------------------------------------------------------------
    def __len__(self) -> int:
        return len(self.ts)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict]:
        start, end = self.ts[idx]
        segment = self.wav[start:end]
        if segment.numel() == 0:
            # Degenerate span: fall back to a single leading sample.
            segment = self.wav[:1]
        meta = {"start_s": start / self.sr, "end_s": end / self.sr}
        return segment, meta

    # Convenience: safe first / last (handles single-segment edge case)
    def first(self) -> tuple[torch.Tensor, dict]:
        return self[0]

    def last(self) -> tuple[torch.Tensor, dict]:
        return self[len(self) - 1]

    def is_single_segment(self) -> bool:
        return len(self) == 1
# ---------------------------------------------------------------------------
# Collate
# ---------------------------------------------------------------------------
def collate(batch: list) -> tuple[torch.Tensor, torch.Tensor, dict]:
    """Pad variable-length segments; stack metadata tensors."""
    audio, metas = zip(*batch)
    lengths = torch.tensor([seg.numel() for seg in audio])
    padded = pad_sequence(audio, batch_first=True)
    starts = torch.tensor([m["start_s"] for m in metas])
    ends = torch.tensor([m["end_s"] for m in metas])
    return padded, lengths, {"start_s": starts, "end_s": ends}
# ---------------------------------------------------------------------------
# AudioSplitter
# ---------------------------------------------------------------------------
class AudioSplitter:
    """
    Split a large audio/video file into fixed-duration M4A chunks via FFmpeg.

    Usage
    -----
    with AudioSplitter("long.mp4", split_min=10) as splitter:
        for path in splitter.chunks:  # sorted list[Path]
            wav, sr = torchaudio.load(path)
            ...

    Raises
    ------
    FileNotFoundError : input path does not exist (at construction).
    RuntimeError      : FFmpeg fails or produces no chunks (on __enter__).
    """

    def __init__(self, path: str | Path, split_min: float) -> None:
        """*path* is the source media file; *split_min* is chunk length in minutes."""
        self.path = Path(path)
        self.split_sec = int(split_min * 60)
        self.chunks: list[Path] = []
        self._tmp: tempfile.TemporaryDirectory | None = None
        if not self.path.exists():
            raise FileNotFoundError(self.path)

    def __enter__(self) -> "AudioSplitter":
        self._tmp = tempfile.TemporaryDirectory()
        work = Path(self._tmp.name)
        pattern = str(work / f"{self.path.stem}_chunk_%03d.m4a")
        try:
            (
                ffmpeg
                .input(str(self.path))
                .output(
                    pattern,
                    format="segment",
                    segment_time=self.split_sec,
                    reset_timestamps=1,
                    map="0:a",  # audio stream only
                    acodec="aac",
                    ar=TARGET_SR,
                    ac=1,  # mono
                )
                .overwrite_output()
                .run(quiet=True)
            )
            self.chunks = sorted(work.glob("*_chunk_*.m4a"))
            if not self.chunks:
                raise RuntimeError(f"FFmpeg produced no chunks from {self.path}")
        except ffmpeg.Error as e:
            # Decode stderr for a useful error message.
            stderr = e.stderr.decode(errors="replace") if e.stderr else "(no stderr)"
            # BUGFIX: __exit__ is never called when __enter__ raises, so the
            # temporary directory must be cleaned up explicitly here.
            self._cleanup()
            raise RuntimeError(
                f"FFmpeg failed while splitting '{self.path}':\n{stderr}"
            ) from e
        except Exception:
            # Same leak applies to the "no chunks" RuntimeError above.
            self._cleanup()
            raise
        return self

    def __exit__(self, *_) -> None:
        self._cleanup()

    def _cleanup(self) -> None:
        """Remove the temporary chunk directory (idempotent)."""
        if self._tmp is not None:
            self._tmp.cleanup()
            self._tmp = None