"""
utils.py - VAD-based audio segmentation & dataset utilities.

Public API
----------
TARGET_SR       : int
Vad             : VAD wrapper
AudioDataset    : torch Dataset (VAD-segmented chunks)
AudioSplitter   : context manager that splits large files into M4A chunks
collate         : DataLoader collate function
get_segments    : (wav, vad) -> list[(start_sample, end_sample)]
preprocess_audio: (wav, sr) -> mono 16 kHz tensor
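
Example
-------
Illustrative sketch; "speech.wav" is a placeholder path::

    import torchaudio
    wav, sr = torchaudio.load("speech.wav")
    ds = AudioDataset(wav, sr)      # VAD-segmented chunks
    seg, meta = ds[0]               # (tensor, {"start_s": ..., "end_s": ...})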
"""
from __future__ import annotations
import tempfile
from pathlib import Path
import _webrtcvad  # C extension bundled with the py-webrtcvad package
import ffmpeg  # ffmpeg-python bindings
import numpy as np
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
__all__ = [
"TARGET_SR",
"Vad",
"AudioDataset",
"AudioSplitter",
"collate",
"get_segments",
"preprocess_audio",
]
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
TARGET_SR = 16_000
_VAD_FRAME_LEN = 480  # 30 ms @ 16 kHz (WebRTC requirement)
_SMOOTH_WIN = 5  # moving-average window (frames) for VAD smoothing
_COST_W_SPEECH = 0.7  # weight of smoothed VAD decisions in the cut cost
_COST_W_ENERGY = 0.3  # weight of normalized RMS energy in the cut cost
_MIN_SEG_S = 8.0  # minimum segment length, seconds
_DEFAULT_MAX_S = 18.0  # default maximum segment length, seconds
_DEFAULT_VAD_MODE = 3  # WebRTC aggressiveness: 0 (least) .. 3 (most)
# ---------------------------------------------------------------------------
# VAD
# ---------------------------------------------------------------------------
class Vad:
"""Thin, stateless-feeling wrapper around native WebRTC VAD."""
def __init__(self, mode: int = _DEFAULT_VAD_MODE) -> None:
self._vad = _webrtcvad.create()
_webrtcvad.init(self._vad)
self.set_mode(mode)
def set_mode(self, mode: int) -> None:
_webrtcvad.set_mode(self._vad, mode)
    def is_speech(self, buf: bytes, sr: int = TARGET_SR) -> bool:
        # len(buf) // 2 converts byte count to 16-bit sample count
        return _webrtcvad.process(self._vad, sr, buf, len(buf) // 2)
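# Example (sketch): classify one 30 ms frame of 16-bit little-endian PCM.
#
#     vad = Vad(mode=2)
#     silence = bytes(_VAD_FRAME_LEN * 2)  # 480 zero samples
#     vad.is_speech(silence)               # expected: falsy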
# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def preprocess_audio(wav: torch.Tensor, sr: int) -> torch.Tensor:
"""Resample β†’ mono β†’ contiguous 1-D float tensor @ TARGET_SR."""
if sr != TARGET_SR:
wav = torchaudio.functional.resample(wav, sr, TARGET_SR)
if wav.ndim > 1:
wav = wav.mean(0)
wav = wav.squeeze()
return (wav if wav.ndim == 1 else wav.unsqueeze(0)).contiguous()
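# Example (sketch): a (channels, samples) float tensor becomes 1-D @ 16 kHz.
#
#     stereo = torch.randn(2, 44_100)           # 1 s of 44.1 kHz stereo
#     mono = preprocess_audio(stereo, 44_100)   # shape: (16_000,)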
def _frame_features(
w16: np.ndarray,
vad: Vad,
n: int,
) -> tuple[np.ndarray, np.ndarray]:
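    """Per-frame VAD speech flags and max-normalized RMS energy."""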
speech = np.zeros(n)
energy = np.zeros(n)
for i in range(n):
frame = w16[i * _VAD_FRAME_LEN : (i + 1) * _VAD_FRAME_LEN]
        try:
            speech[i] = vad.is_speech(frame.tobytes())
        except Exception:
            pass  # frames the native VAD rejects count as non-speech
        energy[i] = np.sqrt(np.mean(frame.astype(np.float32) ** 2) + 1e-8)  # RMS
mx = energy.max()
if mx >= 1e-5:
energy /= mx
return speech, energy
def _cost_signal(speech: np.ndarray, energy: np.ndarray) -> np.ndarray:
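    """Blend smoothed VAD decisions with normalized energy into a cut cost."""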
smooth = np.convolve(speech, np.ones(_SMOOTH_WIN) / _SMOOTH_WIN, mode="same")
return _COST_W_SPEECH * smooth + _COST_W_ENERGY * energy
def get_segments(
wav: torch.Tensor,
vad: Vad,
max_s: float = _DEFAULT_MAX_S,
) -> list[tuple[int, int]]:
"""
Segment *wav* into [(start_sample, end_sample), ...] by VAD cost minima.
Always returns at least one segment covering the full waveform.
"""
w16 = (wav.clamp(-1, 1) * 32_767).to(torch.int16).cpu().numpy()
n_frames = len(w16) // _VAD_FRAME_LEN
total = len(wav)
if n_frames == 0:
return [(0, total)]
speech, energy = _frame_features(w16, vad, n_frames)
if energy.max() < 1e-5: # silent file
return [(0, total)]
cost = _cost_signal(speech, energy)
    min_frames = int(_MIN_SEG_S * TARGET_SR / _VAD_FRAME_LEN)
max_frames = int(max_s * TARGET_SR / _VAD_FRAME_LEN)
segments: list[tuple[int, int]] = []
s = 0
while s < n_frames:
if n_frames - s <= max_frames:
start, end = s * _VAD_FRAME_LEN, total
if start < end:
segments.append((start, end))
break
lo = s + min_frames
hi = min(s + max_frames, n_frames)
cut = (lo + int(np.argmin(cost[lo:hi]))) if lo < hi else hi
cut = max(cut, s + 1)
start, end = s * _VAD_FRAME_LEN, min(cut * _VAD_FRAME_LEN, total)
if start < end:
segments.append((start, end))
s = cut
return segments or [(0, total)]
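# Example (sketch): split a preprocessed waveform into <= 18 s spans.
#
#     wav = preprocess_audio(*torchaudio.load("speech.wav"))  # placeholder path
#     spans = get_segments(wav, Vad())
#     secs = [(s / TARGET_SR, e / TARGET_SR) for s, e in spans]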
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
class AudioDataset(Dataset):
"""
VAD-segmented waveform as a torch Dataset.
Each item: (segment_tensor, {"start_s": float, "end_s": float})
Timestamps are LOCAL (relative to the start of *wav*).
"""
def __init__(
self,
wav: torch.Tensor,
sr: int,
vad_mode: int = _DEFAULT_VAD_MODE,
max_s: float = _DEFAULT_MAX_S,
) -> None:
self.wav = preprocess_audio(wav, sr)
self.sr = TARGET_SR
# Each worker needs its own Vad instance (not thread-safe to share)
self.ts = get_segments(self.wav, Vad(vad_mode), max_s)
# ------------------------------------------------------------------
def __len__(self) -> int:
return len(self.ts)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict]:
s, e = self.ts[idx]
segment = self.wav[s:e]
if segment.numel() == 0:
segment = self.wav[:1]
return segment, {"start_s": s / self.sr, "end_s": e / self.sr}
# Convenience: safe first / last (handles single-segment edge case)
@property
def first(self) -> tuple[torch.Tensor, dict]:
return self[0]
@property
def last(self) -> tuple[torch.Tensor, dict]:
return self[len(self) - 1]
@property
def is_single_segment(self) -> bool:
return len(self) == 1
# ---------------------------------------------------------------------------
# Collate
# ---------------------------------------------------------------------------
def collate(batch: list) -> tuple[torch.Tensor, torch.Tensor, dict]:
"""Pad variable-length segments; stack metadata tensors."""
audio, metas = zip(*batch)
lengths = torch.tensor([x.numel() for x in audio])
padded = pad_sequence(audio, batch_first=True)
meta = {
"start_s": torch.tensor([m["start_s"] for m in metas]),
"end_s": torch.tensor([m["end_s"] for m in metas]),
}
return padded, lengths, meta
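# Example (sketch): batch VAD segments with a DataLoader, given an
# AudioDataset `ds` as above.
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(ds, batch_size=4, collate_fn=collate)
#     for padded, lengths, meta in loader:  # (B, T), (B,), dict of (B,) tensors
#         ...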
# ---------------------------------------------------------------------------
# AudioSplitter
# ---------------------------------------------------------------------------
class AudioSplitter:
"""
Split a large audio/video file into fixed-duration M4A chunks via FFmpeg.
Usage
-----
with AudioSplitter("long.mp4", split_min=10) as splitter:
for path in splitter.chunks: # sorted list[Path]
wav, sr = torchaudio.load(path)
...
"""
def __init__(self, path: str | Path, split_min: float) -> None:
self.path = Path(path)
self.split_sec = int(split_min * 60)
self.chunks: list[Path] = []
self._tmp = None
if not self.path.exists():
raise FileNotFoundError(self.path)
def __enter__(self) -> "AudioSplitter":
self._tmp = tempfile.TemporaryDirectory()
work = Path(self._tmp.name)
pattern = str(work / f"{self.path.stem}_chunk_%03d.m4a")
try:
(
ffmpeg
.input(str(self.path))
.output(
pattern,
                    format="segment",             # FFmpeg segment muxer
                    segment_time=self.split_sec,  # chunk length in seconds
                    reset_timestamps=1,           # each chunk starts at t=0
                    map="0:a",                    # audio streams only
                    acodec="aac",
                    ar=TARGET_SR,                 # resample to 16 kHz
                    ac=1,                         # downmix to mono
)
.overwrite_output()
.run(quiet=True)
)
except ffmpeg.Error as e:
# Decode stderr for a useful error message
stderr = e.stderr.decode(errors="replace") if e.stderr else "(no stderr)"
raise RuntimeError(
f"FFmpeg failed while splitting '{self.path}':\n{stderr}"
) from e
self.chunks = sorted(work.glob("*_chunk_*.m4a"))
if not self.chunks:
raise RuntimeError(f"FFmpeg produced no chunks from {self.path}")
return self
def __exit__(self, *_) -> None:
if self._tmp:
self._tmp.cleanup()
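# Example (sketch): end-to-end pipeline over a long recording
# ("long.mp4" is a placeholder path).
#
#     with AudioSplitter("long.mp4", split_min=10) as sp:
#         for chunk in sp.chunks:
#             ds = AudioDataset(*torchaudio.load(str(chunk)))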