"""
encoder.py
----------
Extracts wav2vec2 frame-level embeddings from a waveform.
Uses facebook/wav2vec2-base-960h — same model used during training.

Output: [N_frames, 768] float32 tensor
"""

import torch
import torchaudio
import soundfile as sf
import numpy as np
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

TARGET_SR = 16_000
PRETRAINED_MODEL = "facebook/wav2vec2-base-960h"


class Wav2Vec2Encoder:
    """
    Loads facebook/wav2vec2-base-960h once and extracts last_hidden_state
    embeddings per utterance.

    Construction is cheap: the model itself is only loaded on the first call
    to load() or encode().
    """

    def __init__(self, device: torch.device | None = None):
        """
        Args:
            device: device the model is moved to when loaded; defaults to CPU.
        """
        if device is None:
            device = torch.device("cpu")
        self.device = device
        self._model = None

    def load(self):
        """Load, freeze and cache the pretrained model. No-op once loaded."""
        if self._model is not None:
            return

        logger.info("Loading encoder: %s", PRETRAINED_MODEL)
        # Imported lazily so that importing this module does not pull in
        # transformers until the model is actually needed.
        from transformers import Wav2Vec2Model

        # Configure a local first and publish to self._model only at the end:
        # if .to(device) raises, no half-initialised model stays cached and a
        # later load() call retries from scratch.
        model = Wav2Vec2Model.from_pretrained(PRETRAINED_MODEL)
        model.eval()
        model.to(self.device)
        for p in model.parameters():
            p.requires_grad = False

        self._model = model
        logger.info("Encoder loaded and frozen.")

    @torch.inference_mode()
    def encode(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: [1, T] float32 at 16kHz
        Returns:
            embeddings: [N_frames, 768] float32
        """
        self.load()
        # squeeze(0) + unsqueeze(0) normalises both [T] and [1, T] to [1, T].
        x = waveform.squeeze(0).unsqueeze(0).to(self.device)  # [1, T]
        out = self._model(input_values=x)
        emb = out.last_hidden_state.squeeze(0).cpu()  # [N, 768]
        return emb.float()


def load_waveform(audio_bytes: bytes) -> torch.Tensor:
    """
    Load audio from raw bytes — handles WebM, Opus, WAV, OGG, MP3.
    Browser MediaRecorder outputs WebM/Opus by default.
    Uses ffmpeg to convert any format → WAV PCM before reading.

    Args:
        audio_bytes: encoded audio in any container/codec ffmpeg can read.

    Returns:
        [1, T] float32 tensor at 16kHz.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status; the message
            carries the last 300 characters of its stderr.
        subprocess.TimeoutExpired: if the conversion exceeds 30 seconds.
    """
    import subprocess
    import tempfile

    # Both files live in a private temp directory so they are removed on every
    # exit path, including a failed write or a failed conversion.
    with tempfile.TemporaryDirectory() as tmp_dir:
        in_path = Path(tmp_dir) / "input.audio"  # unknown input format
        out_path = Path(tmp_dir) / "output.wav"
        in_path.write_bytes(audio_bytes)

        # ffmpeg: convert any format → 16kHz mono WAV PCM
        result = subprocess.run(
            [
                "ffmpeg", "-y",
                "-i", str(in_path),
                "-ar", str(TARGET_SR),
                "-ac", "1",
                "-f", "wav",
                str(out_path),
            ],
            capture_output=True,
            timeout=30,
        )
        if result.returncode != 0:
            # errors="replace": ffmpeg may echo non-UTF-8 bytes (e.g. file
            # metadata); a decode failure must not mask the real error.
            stderr_tail = result.stderr.decode(errors="replace")[-300:]
            raise RuntimeError(f"ffmpeg failed: {stderr_tail}")

        # Read the converted WAV before the temp directory is removed
        data, sr = sf.read(str(out_path), dtype="float32", always_2d=True)

    waveform = torch.from_numpy(data.T)  # [C, T]

    # Mix down to mono (ffmpeg already does -ac 1, but just in case)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed (ffmpeg already does -ar 16000, but just in case)
    if sr != TARGET_SR:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
        waveform = resampler(waveform)

    return waveform  # [1, T]


# Singleton
_encoder: Wav2Vec2Encoder | None = None


def get_encoder(device: torch.device | None = None) -> Wav2Vec2Encoder:
    """
    Return the process-wide Wav2Vec2Encoder, creating it on first use.

    Args:
        device: device for the encoder. Only honoured by the call that
            creates the singleton; later calls return the existing instance
            and log a warning if a different device is requested.

    Returns:
        The shared Wav2Vec2Encoder instance.
    """
    global _encoder
    if _encoder is None:
        _encoder = Wav2Vec2Encoder(device=device)
    elif device is not None and torch.device(device) != torch.device(_encoder.device):
        # Keep returning the existing instance, but make the mismatch visible
        # instead of silently ignoring the argument.
        logger.warning(
            "get_encoder(device=%s) ignored: encoder already created on %s",
            device,
            _encoder.device,
        )
    return _encoder