dmartu committed on
Commit dd27c0d · verified · 1 Parent(s): a4a126d

Update handler.py

Files changed (1):
  handler.py +200 -29
handler.py CHANGED
@@ -1,48 +1,219 @@
- from typing import Dict, Any
- import torch
- import soundfile as sf
- import io
  import base64
  import numpy as np
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq


  class EndpointHandler:
-     def __init__(self, path=""):
-         self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
-         self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
              path,
              trust_remote_code=True,
-             torch_dtype=torch.float16,
-             device_map="auto"
          )
          self.model.eval()

-     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-         audio_input = data.get("inputs")

          if isinstance(audio_input, str):
-             audio_bytes = base64.b64decode(audio_input)
-         else:
              audio_bytes = audio_input

-         audio_array, sample_rate = sf.read(io.BytesIO(audio_bytes))

-         if audio_array.ndim > 1:
-             audio_array = audio_array.mean(axis=1)

-         inputs = self.processor(
-             audio_array,
-             sampling_rate=sample_rate,
-             return_tensors="pt"
-         ).to(self.model.device)

-         with torch.no_grad():
-             generated_ids = self.model.generate(**inputs)

-         transcription = self.processor.batch_decode(
-             generated_ids,
-             skip_special_tokens=True
-         )[0]

-         return {"text": transcription}
+ """
+ Custom Inference Handler for VibeVoice-ASR on Hugging Face Inference Endpoints.
+
+ Setup:
+ 1. Duplicate the microsoft/VibeVoice-ASR repo to your own HF account
+ 2. Add this handler.py and the accompanying requirements.txt to the repo root
+ 3. Deploy as an Inference Endpoint with a GPU instance (min ~18GB VRAM)
+ """
+
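The setup steps above mention an accompanying requirements.txt. A plausible minimal sketch, inferred from the imports in this handler (the package names, pins, and the source of the vibevoice package are assumptions, not part of this commit):

    # requirements.txt (hypothetical sketch)
    torch
    numpy
    librosa
    flash-attn   # needed for attn_implementation="flash_attention_2" below
    # plus whatever package provides vibevoice.asr.* (install source not shown in this commit)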
  import base64
+ import io
+ import os
+ import re
+ import tempfile
+ import logging
+ from typing import Any, Dict, List
+
+ import torch
  import numpy as np
+
+ logger = logging.getLogger(__name__)


  class EndpointHandler:
25
+ def __init__(self, path: str = ""):
26
+ """
27
+ Initialize the VibeVoice-ASR model and processor.
28
+
29
+ Args:
30
+ path: Path to model weights (provided by HF Inference Endpoints).
31
+ """
32
+ from vibevoice.asr.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
33
+ from vibevoice.asr.processing_vibevoice_asr import VibeVoiceASRProcessor
34
+
35
+ logger.info(f"Loading VibeVoice-ASR model from: {path}")
36
+
37
+ self.processor = VibeVoiceASRProcessor.from_pretrained(path)
38
+
39
+ self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
40
  path,
41
+ torch_dtype=torch.bfloat16,
42
+ attn_implementation="flash_attention_2",
43
+ device_map="auto",
44
  trust_remote_code=True,
 
 
45
  )
46
  self.model.eval()
47
 
48
+ self.device = next(self.model.parameters()).device
49
+ logger.info(f"VibeVoice-ASR loaded on device: {self.device}")
50
+
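Note that attn_implementation="flash_attention_2" fails at load time if flash-attn is not installed or the GPU does not support it. A hedged fallback sketch (an assumption, not part of this commit) retries with PyTorch's built-in SDPA attention:

    try:
        self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
            trust_remote_code=True,
        )
    except (ImportError, ValueError) as e:
        # flash-attn unavailable: fall back to PyTorch's scaled-dot-product attention.
        logger.warning(f"FlashAttention 2 unavailable ({e}); falling back to SDPA")
        self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
            device_map="auto",
            trust_remote_code=True,
        )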
+     def _load_audio(self, audio_input) -> np.ndarray:
+         """
+         Load audio from various input formats.
+
+         Supports:
+         - base64-encoded string
+         - raw bytes
+         - file path string
+         """
+         import librosa

          if isinstance(audio_input, str):
+             if os.path.isfile(audio_input):
+                 audio, _ = librosa.load(audio_input, sr=16000, mono=True)
+                 return audio
+             else:
+                 # Assume base64
+                 audio_bytes = base64.b64decode(audio_input)
+         elif isinstance(audio_input, bytes):
              audio_bytes = audio_input
+         else:
+             raise ValueError(
+                 f"Unsupported audio input type: {type(audio_input)}. "
+                 "Expected base64 string, bytes, or file path."
+             )
+
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+             tmp.write(audio_bytes)
+             tmp_path = tmp.name
+
+         try:
+             audio, _ = librosa.load(tmp_path, sr=16000, mono=True)
+         finally:
+             os.unlink(tmp_path)
+
+         return audio
+
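For reference, a client typically supplies the base64 string this method decodes. A minimal sketch of building such a payload (the file name and hotwords are placeholders):

    import base64

    with open("sample.wav", "rb") as f:
        payload = {
            "inputs": base64.b64encode(f.read()).decode("utf-8"),
            "parameters": {"hotwords": "VibeVoice, ASR"},
        }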
+     def _parse_transcription(self, raw_text: str) -> List[Dict[str, Any]]:
+         """
+         Parse the raw model output into structured segments.
+
+         VibeVoice-ASR outputs text in the format:
+             <speaker:0><start:0.00><end:13.43> Hello, how are you?
+         """
+         segments = []
+         pattern = r"<speaker:(\d+)><start:([\d.]+)><end:([\d.]+)>\s*(.*?)(?=<speaker:|\Z)"
+         matches = re.finditer(pattern, raw_text, re.DOTALL)
+
+         for match in matches:
+             speaker_id = int(match.group(1))
+             start_time = float(match.group(2))
+             end_time = float(match.group(3))
+             text = match.group(4).strip()
+
+             if text:
+                 segments.append({
+                     "speaker": f"Speaker {speaker_id}",
+                     "start": start_time,
+                     "end": end_time,
+                     "timestamp": f"{start_time:.2f} - {end_time:.2f}",
+                     "text": text,
+                 })
+
+         return segments
+
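To illustrate the parsing, feeding a two-speaker string through _parse_transcription (an invented example, not real model output):

    raw = (
        "<speaker:0><start:0.00><end:2.10> Hello there. "
        "<speaker:1><start:2.10><end:4.55> Hi, how are you?"
    )
    # _parse_transcription(raw) returns:
    # [
    #     {"speaker": "Speaker 0", "start": 0.0, "end": 2.1,
    #      "timestamp": "0.00 - 2.10", "text": "Hello there."},
    #     {"speaker": "Speaker 1", "start": 2.1, "end": 4.55,
    #      "timestamp": "2.10 - 4.55", "text": "Hi, how are you?"},
    # ]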
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Process an inference request.
+
+         Request body:
+             {
+                 "inputs": "<base64-encoded-audio>",
+                 "parameters": {  # all optional
+                     "hotwords": "term1, term2",
+                     "max_new_tokens": 8192,
+                     "temperature": 0.0,
+                     "top_p": 0.9,
+                     "repetition_penalty": 1.0
+                 }
+             }
+
+         Returns:
+             {
+                 "transcription": "plain text transcription",
+                 "raw": "raw model output with tags",
+                 "segments": [
+                     {
+                         "speaker": "Speaker 0",
+                         "start": 0.0,
+                         "end": 13.43,
+                         "timestamp": "0.00 - 13.43",
+                         "text": "Hello, how are you?"
+                     }
+                 ],
+                 "duration": 78.3
+             }
+         """
+         audio_input = data.get("inputs", data)
+         parameters = data.get("parameters", {})
+
+         hotwords = parameters.get("hotwords", "")
+         max_new_tokens = parameters.get("max_new_tokens", 8192)
+         temperature = parameters.get("temperature", 0.0)
+         top_p = parameters.get("top_p", 0.9)
+         repetition_penalty = parameters.get("repetition_penalty", 1.0)
+
+         # Load audio
+         try:
+             audio = self._load_audio(audio_input)
+         except Exception as e:
+             return {"error": f"Failed to load audio: {str(e)}"}
+
+         duration = len(audio) / 16000
+         logger.info(f"Audio loaded: {duration:.1f}s")
+
+         if duration > 3600:
+             return {"error": "Audio exceeds 60 minute limit"}
+
+         # Preprocess
+         try:
+             inputs = self.processor(
+                 audio=audio,
+                 sampling_rate=16000,
+                 context=hotwords if hotwords else None,
+                 return_tensors="pt",
+             )
+             inputs = {
+                 k: v.to(self.device) if isinstance(v, torch.Tensor) else v
+                 for k, v in inputs.items()
+             }
+         except Exception as e:
+             return {"error": f"Failed to preprocess audio: {str(e)}"}
+
+         # Generate
+         try:
+             generate_kwargs = {
+                 "max_new_tokens": max_new_tokens,
+                 "do_sample": temperature > 0,
+             }
+             if temperature > 0:
+                 generate_kwargs["temperature"] = temperature
+                 generate_kwargs["top_p"] = top_p
+             if repetition_penalty != 1.0:
+                 generate_kwargs["repetition_penalty"] = repetition_penalty
+
+             with torch.inference_mode():
+                 output_ids = self.model.generate(**inputs, **generate_kwargs)
+
+             raw_text = self.processor.batch_decode(
+                 output_ids, skip_special_tokens=False
+             )[0]
+
+             for token in ["<s>", "</s>", "<pad>", "<eos>", "<bos>"]:
+                 raw_text = raw_text.replace(token, "")
+             raw_text = raw_text.strip()
+
+         except Exception as e:
+             logger.error(f"Generation failed: {str(e)}")
+             return {"error": f"Transcription failed: {str(e)}"}
+
+         segments = self._parse_transcription(raw_text)
+         plain_text = " ".join(seg["text"] for seg in segments) if segments else raw_text
+
+         return {
+             "transcription": plain_text,
+             "raw": raw_text,
+             "segments": segments,
+             "duration": round(duration, 2),
+         }
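Once deployed, the endpoint can be exercised end to end with a sketch like this (the URL, token, and file name are placeholders):

    import base64
    import requests

    ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
    HF_TOKEN = "hf_..."  # placeholder

    with open("meeting.wav", "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode("utf-8")

    resp = requests.post(
        ENDPOINT_URL,
        headers={"Authorization": f"Bearer {HF_TOKEN}"},
        json={
            "inputs": audio_b64,
            "parameters": {"temperature": 0.0, "max_new_tokens": 8192},
        },
    )
    result = resp.json()
    for seg in result.get("segments", []):
        print(f"[{seg['timestamp']}] {seg['speaker']}: {seg['text']}")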