File size: 7,096 Bytes
dd27c0d
 
 
 
 
 
 
 
 
d0e224d
dd27c0d
 
 
 
 
 
 
 
d0e224d
dd27c0d
 
d0e224d
 
 
dd27c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0e224d
dd27c0d
 
 
d0e224d
 
 
 
dd27c0d
 
 
 
 
 
 
 
 
 
 
 
 
d0e224d
 
dd27c0d
 
 
 
 
 
 
d0e224d
dd27c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0e224d
dd27c0d
 
d0e224d
dd27c0d
 
 
d0e224d
dd27c0d
 
 
d0e224d
dd27c0d
 
 
d0e224d
dd27c0d
 
d0e224d
dd27c0d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""
Custom Inference Handler for VibeVoice-ASR on Hugging Face Inference Endpoints.

Setup:
1. Duplicate the microsoft/VibeVoice-ASR repo to your own HF account
2. Add this handler.py and the accompanying requirements.txt to the repo root
3. Deploy as an Inference Endpoint with a GPU instance (min ~18GB VRAM)
"""

import base64
import io
import os
import re
import tempfile
import logging
from typing import Any, Dict, List

import torch
import numpy as np

logger = logging.getLogger(__name__)


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Load the VibeVoice-ASR processor and model, and record the model's device.

        Args:
            path: Location of the model weights (supplied by HF Inference
                Endpoints); passed to both ``from_pretrained`` calls.
        """
        # Imported lazily so the module can be imported without the vibevoice package.
        from vibevoice.asr.modeling_vibevoice_asr import (
            VibeVoiceASRForConditionalGeneration as _ModelCls,
        )
        from vibevoice.asr.processing_vibevoice_asr import (
            VibeVoiceASRProcessor as _ProcessorCls,
        )

        logger.info(f"Loading VibeVoice-ASR model from: {path}")

        # The processor is loaded first, then the model.
        self.processor = _ProcessorCls.from_pretrained(path)

        load_options = dict(
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
            trust_remote_code=True,
        )
        self.model = _ModelCls.from_pretrained(path, **load_options)
        self.model.eval()

        # Input tensors are later moved to the device of the first model parameter.
        first_parameter = next(self.model.parameters())
        self.device = first_parameter.device
        logger.info(f"VibeVoice-ASR loaded on device: {self.device}")

    def _load_audio(self, audio_input) -> np.ndarray:
        """
        Load audio from various input formats.

        Supports:
        - base64-encoded string
        - raw bytes
        - file path string
        """
        import librosa

        if isinstance(audio_input, str):
            if os.path.isfile(audio_input):
                audio, _ = librosa.load(audio_input, sr=16000, mono=True)
                return audio
            else:
                # Assume base64
                audio_bytes = base64.b64decode(audio_input)
        elif isinstance(audio_input, bytes):
            audio_bytes = audio_input
        else:
            raise ValueError(
                f"Unsupported audio input type: {type(audio_input)}. "
                "Expected base64 string, bytes, or file path."
            )

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name

        try:
            audio, _ = librosa.load(tmp_path, sr=16000, mono=True)
        finally:
            os.unlink(tmp_path)

        return audio

    def _parse_transcription(self, raw_text: str) -> List[Dict[str, Any]]:
        """
        Parse the raw model output into structured segments.

        VibeVoice-ASR outputs text in the format:
        <speaker:0><start:0.00><end:13.43> Hello, how are you?
        """
        segments = []
        pattern = r"<speaker:(\d+)><start:([\d.]+)><end:([\d.]+)>\s*(.*?)(?=<speaker:|\Z)"
        matches = re.finditer(pattern, raw_text, re.DOTALL)

        for match in matches:
            speaker_id = int(match.group(1))
            start_time = float(match.group(2))
            end_time = float(match.group(3))
            text = match.group(4).strip()

            if text:
                segments.append({
                    "speaker": f"Speaker {speaker_id}",
                    "start": start_time,
                    "end": end_time,
                    "timestamp": f"{start_time:.2f} - {end_time:.2f}",
                    "text": text,
                })

        return segments

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Request body:
        {
            "inputs": "<base64-encoded-audio>",
            "parameters": {                     # all optional
                "hotwords": "term1, term2",
                "max_new_tokens": 8192,
                "temperature": 0.0,
                "top_p": 0.9,
                "repetition_penalty": 1.0
            }
        }

        Returns:
        {
            "transcription": "plain text transcription",
            "raw": "raw model output with tags",
            "segments": [
                {
                    "speaker": "Speaker 0",
                    "start": 0.0,
                    "end": 13.43,
                    "timestamp": "0.00 - 13.43",
                    "text": "Hello, how are you?"
                }
            ],
            "duration": 78.3
        }
        """
        audio_input = data.get("inputs", data)
        parameters = data.get("parameters", {})

        hotwords = parameters.get("hotwords", "")
        max_new_tokens = parameters.get("max_new_tokens", 8192)
        temperature = parameters.get("temperature", 0.0)
        top_p = parameters.get("top_p", 0.9)
        repetition_penalty = parameters.get("repetition_penalty", 1.0)

        # Load audio
        try:
            audio = self._load_audio(audio_input)
        except Exception as e:
            return {"error": f"Failed to load audio: {str(e)}"}

        duration = len(audio) / 16000
        logger.info(f"Audio loaded: {duration:.1f}s")

        if duration > 3600:
            return {"error": "Audio exceeds 60 minute limit"}

        # Preprocess
        try:
            inputs = self.processor(
                audio=audio,
                sampling_rate=16000,
                context=hotwords if hotwords else None,
                return_tensors="pt",
            )
            inputs = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }
        except Exception as e:
            return {"error": f"Failed to preprocess audio: {str(e)}"}

        # Generate
        try:
            generate_kwargs = {
                "max_new_tokens": max_new_tokens,
                "do_sample": temperature > 0,
            }
            if temperature > 0:
                generate_kwargs["temperature"] = temperature
                generate_kwargs["top_p"] = top_p
            if repetition_penalty != 1.0:
                generate_kwargs["repetition_penalty"] = repetition_penalty

            with torch.inference_mode():
                output_ids = self.model.generate(**inputs, **generate_kwargs)

            raw_text = self.processor.batch_decode(
                output_ids, skip_special_tokens=False
            )[0]

            for token in ["<s>", "</s>", "<pad>", "<eos>", "<bos>"]:
                raw_text = raw_text.replace(token, "")
            raw_text = raw_text.strip()

        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            return {"error": f"Transcription failed: {str(e)}"}

        segments = self._parse_transcription(raw_text)
        plain_text = " ".join(seg["text"] for seg in segments) if segments else raw_text

        return {
            "transcription": plain_text,
            "raw": raw_text,
            "segments": segments,
            "duration": round(duration, 2),
        }