File size: 3,757 Bytes
88d6872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8893fb7
 
 
88d6872
 
 
8893fb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88d6872
8893fb7
 
 
 
 
 
88d6872
 
 
8893fb7
88d6872
8893fb7
 
88d6872
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
encoder.py
----------
Extracts wav2vec2 frame-level embeddings from a waveform.
Uses facebook/wav2vec2-base-960h — same model used during training.

Output: [N_frames, 768] float32 tensor
"""

import torch
import torchaudio
import soundfile as sf
import numpy as np
import logging
from pathlib import Path

logger = logging.getLogger(name=__name__)

# Sample rate (Hz) that every waveform is brought to before encoding.
TARGET_SR: int = 16000
# Hugging Face checkpoint id; per the module docstring it must match training.
PRETRAINED_MODEL: str = "facebook/wav2vec2-base-960h"


class Wav2Vec2Encoder:
    """
    Loads facebook/wav2vec2-base-960h once and extracts
    last_hidden_state embeddings per utterance.
    """

    def __init__(self, device: torch.device = None):
        if device is None:
            device = torch.device("cpu")
        self.device = device
        self._model = None

    def load(self):
        if self._model is not None:
            return
        logger.info(f"Loading encoder: {PRETRAINED_MODEL}")
        from transformers import Wav2Vec2Model
        self._model = Wav2Vec2Model.from_pretrained(PRETRAINED_MODEL)
        self._model.eval()
        self._model.to(self.device)
        for p in self._model.parameters():
            p.requires_grad = False
        logger.info("Encoder loaded and frozen.")

    @torch.inference_mode()
    def encode(self, waveform: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveform: [1, T] float32 at 16kHz
        Returns:
            embeddings: [N_frames, 768] float32
        """
        self.load()
        x = waveform.squeeze(0).unsqueeze(0).to(self.device)  # [1, T]
        out = self._model(input_values=x)
        emb = out.last_hidden_state.squeeze(0).cpu()           # [N, 768]
        return emb.float()


def _ffmpeg_to_wav(src: Path, dst: Path) -> None:
    """Convert the audio file at *src* into a mono TARGET_SR WAV file at *dst* using ffmpeg.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    import subprocess

    cmd = [
        "ffmpeg",
        "-nostdin",            # never read from this process's stdin
        "-hide_banner",
        "-loglevel", "error",  # keep stderr limited to errors for the message below
        "-y",
        "-i", str(src),
        "-vn",                 # ignore any video stream in the container
        "-ar", str(TARGET_SR),
        "-ac", "1",
        "-f", "wav",
        str(dst),
    ]
    result = subprocess.run(cmd, capture_output=True, timeout=30)
    if result.returncode != 0:
        # errors="replace": ffmpeg output is not guaranteed to be valid UTF-8,
        # and a decode error here would hide the real failure.
        stderr = result.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg failed: {stderr[-300:]}")


def load_waveform(audio_bytes: bytes) -> torch.Tensor:
    """
    Decode raw audio bytes into a mono waveform at TARGET_SR.

    Any format ffmpeg can read is accepted (e.g. WebM/Opus, which browser
    MediaRecorder outputs by default, plus WAV, OGG, MP3). ffmpeg converts the
    input to a mono WAV at TARGET_SR, which is then read with soundfile.

    Args:
        audio_bytes: Raw bytes of the encoded audio file.
    Returns:
        [1, T] float32 tensor at TARGET_SR.
    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ffmpeg executable cannot be found.
        subprocess.TimeoutExpired: if ffmpeg runs longer than 30 seconds.
    """
    import tempfile

    # One temporary directory holds both the input and the converted WAV, so
    # everything is removed even if writing, converting or reading fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_in_path = Path(tmp_dir) / "input.audio"
        tmp_out_path = Path(tmp_dir) / "output.wav"
        tmp_in_path.write_bytes(audio_bytes)
        _ffmpeg_to_wav(tmp_in_path, tmp_out_path)
        data, sr = sf.read(str(tmp_out_path), dtype="float32", always_2d=True)

    waveform = torch.from_numpy(np.ascontiguousarray(data.T))  # [C, T]

    # Mix down to mono (ffmpeg already does -ac 1; kept as a safeguard)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed (ffmpeg already does -ar TARGET_SR; kept as a safeguard)
    if sr != TARGET_SR:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=TARGET_SR)
        waveform = resampler(waveform)

    return waveform  # [1, T]


# Module-level singleton, created lazily by get_encoder().
_encoder: Wav2Vec2Encoder | None = None


def get_encoder(device: str | torch.device | None = None) -> Wav2Vec2Encoder:
    """
    Return the shared Wav2Vec2Encoder, creating it on the first call.

    ``device`` only takes effect on the first call. Later calls return the
    existing instance unchanged and log a warning if a different device is
    requested. Creation is not guarded by a lock.
    """
    global _encoder
    if _encoder is None:
        _encoder = Wav2Vec2Encoder(device=device)
    elif device is not None and torch.device(device) != torch.device(_encoder.device):
        # Previously this mismatch was silently ignored.
        logger.warning(
            "get_encoder(device=%s) ignored: encoder already created on %s",
            device,
            _encoder.device,
        )
    return _encoder