heldtomaturity's picture
fix: handle WebM audio from browser
8893fb7
"""
encoder.py
----------
Extracts wav2vec2 frame-level embeddings from a waveform.
Uses facebook/wav2vec2-base-960h — same model used during training.
Output: [N_frames, 768] float32 tensor
"""
import torch
import torchaudio
import soundfile as sf
import numpy as np
import logging
from pathlib import Path
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Sample rate in Hz that load_waveform asks ffmpeg for and resamples to.
TARGET_SR = 16_000
# Checkpoint identifier passed to Wav2Vec2Model.from_pretrained.
PRETRAINED_MODEL = "facebook/wav2vec2-base-960h"
class Wav2Vec2Encoder:
    """
    Frame-level embedding extractor built on facebook/wav2vec2-base-960h.

    The checkpoint is loaded lazily on first use and kept on the instance,
    so repeated calls to encode() reuse the same frozen model.
    """

    def __init__(self, device: torch.device = None):
        # Fall back to CPU when the caller does not choose a device.
        self.device = torch.device("cpu") if device is None else device
        # Filled in by load(); None means "not loaded yet".
        self._model = None

    def load(self):
        """Load the checkpoint, switch it to eval mode, move it to
        ``self.device`` and freeze its parameters. No-op once loaded."""
        if self._model is None:
            logger.info(f"Loading encoder: {PRETRAINED_MODEL}")
            # Imported here so transformers is only needed once a model
            # is actually loaded.
            from transformers import Wav2Vec2Model

            self._model = Wav2Vec2Model.from_pretrained(PRETRAINED_MODEL)
            self._model.eval()
            self._model.to(self.device)
            for param in self._model.parameters():
                param.requires_grad = False
            logger.info("Encoder loaded and frozen.")

    @torch.inference_mode()
    def encode(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Run the encoder on one utterance.

        Args:
            waveform: [1, T] float32 at 16kHz
        Returns:
            embeddings: [N_frames, 768] float32, on CPU
        """
        self.load()
        # Drop a leading singleton dim if present, then add the batch dim back.
        batch = waveform.squeeze(0).unsqueeze(0)  # [1, T]
        batch = batch.to(self.device)
        hidden = self._model(input_values=batch).last_hidden_state
        return hidden.squeeze(0).cpu().float()  # [N, 768]
def load_waveform(audio_bytes: bytes) -> torch.Tensor:
    """
    Decode raw audio bytes into a mono float32 waveform at 16kHz.

    The bytes are written to a scratch file and converted with ffmpeg to
    16kHz mono WAV PCM before reading, so any format ffmpeg can decode is
    accepted (WebM, Opus, WAV, OGG, MP3, ...). Browser MediaRecorder
    outputs WebM/Opus by default.

    Args:
        audio_bytes: encoded audio in any format ffmpeg can decode
    Returns:
        waveform: [1, T] float32 tensor at 16kHz
    Raises:
        RuntimeError: ffmpeg exited with a non-zero status; the message
            carries the last 300 characters of its stderr
        FileNotFoundError: the ffmpeg executable could not be found
        subprocess.TimeoutExpired: conversion took longer than 30 seconds
    """
    import os
    import subprocess
    import tempfile

    # A scratch directory removes both the input and the converted file on
    # exit, even when writing the input or running ffmpeg fails part-way.
    with tempfile.TemporaryDirectory() as workdir:
        src_path = os.path.join(workdir, "input.audio")
        wav_path = os.path.join(workdir, "output.wav")

        # Unknown format: ffmpeg probes the content, not the suffix.
        # The file is closed before ffmpeg opens it.
        with open(src_path, "wb") as src_file:
            src_file.write(audio_bytes)

        # ffmpeg: convert any format -> 16kHz mono WAV PCM
        result = subprocess.run(
            [
                "ffmpeg", "-y",
                "-i", src_path,
                "-ar", str(TARGET_SR),
                "-ac", "1",
                "-f", "wav",
                wav_path,
            ],
            # Never let ffmpeg read (or block on) the parent's stdin.
            stdin=subprocess.DEVNULL,
            capture_output=True,
            timeout=30,
        )
        if result.returncode != 0:
            # stderr may echo non-UTF-8 metadata from the input file;
            # decode leniently so the RuntimeError is not masked by a
            # UnicodeDecodeError.
            stderr_tail = result.stderr.decode(errors="replace")[-300:]
            raise RuntimeError(f"ffmpeg failed: {stderr_tail}")

        # Read the converted WAV: data is [T, C], transposed to [C, T].
        data, sr = sf.read(wav_path, dtype="float32", always_2d=True)
        waveform = torch.from_numpy(data.T)

    # Mix down to mono (ffmpeg already does -ac 1, but just in case)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed (ffmpeg already does -ar 16000, but just in case)
    if sr != TARGET_SR:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
        waveform = resampler(waveform)

    return waveform  # [1, T]
# Singleton: the one shared encoder instance, created on first request.
_encoder: Wav2Vec2Encoder | None = None


def get_encoder(device: torch.device | None = None) -> Wav2Vec2Encoder:
    """
    Return the module-level shared Wav2Vec2Encoder, creating it on first call.

    Args:
        device: device handed to Wav2Vec2Encoder when the instance is
            created. Later calls reuse the existing instance, so a
            different device passed afterwards has no effect (a warning
            is logged instead).
    Returns:
        the shared Wav2Vec2Encoder instance
    """
    global _encoder
    if _encoder is None:
        _encoder = Wav2Vec2Encoder(device=device)
    elif device is not None and str(device) != str(_encoder.device):
        # Compare string forms so an unusual device value can never raise
        # here; the request is only reported, never acted on.
        logger.warning(
            "get_encoder(device=%s) ignored: encoder already created on %s",
            device,
            _encoder.device,
        )
    return _encoder