"""
Custom Inference Handler for VibeVoice-ASR on Hugging Face Inference Endpoints.
Setup:
1. Duplicate the microsoft/VibeVoice-ASR repo to your own HF account
2. Add this handler.py and the accompanying requirements.txt to the repo root
3. Deploy as an Inference Endpoint with a GPU instance (min ~18GB VRAM)
"""
import base64
import io
import os
import re
import tempfile
import logging
from typing import Any, Dict, List
import torch
import numpy as np
logger = logging.getLogger(__name__)
class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Load the VibeVoice-ASR processor and model for serving.

        Args:
            path: Path to model weights (provided by HF Inference Endpoints).
        """
        # Imported lazily so the handler module can be inspected without the
        # vibevoice package installed.
        from vibevoice.asr.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
        from vibevoice.asr.processing_vibevoice_asr import VibeVoiceASRProcessor

        logger.info(f"Loading VibeVoice-ASR model from: {path}")

        self.processor = VibeVoiceASRProcessor.from_pretrained(path)

        # bf16 + flash-attention-2; device_map="auto" spreads weights across
        # whatever GPUs the endpoint instance provides.
        load_kwargs = dict(
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
            trust_remote_code=True,
        )
        self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(path, **load_kwargs)
        self.model.eval()

        # Remember where the (first shard of the) model landed so inputs can
        # be moved to the same device at request time.
        self.device = next(self.model.parameters()).device
        logger.info(f"VibeVoice-ASR loaded on device: {self.device}")
def _load_audio(self, audio_input) -> np.ndarray:
"""
Load audio from various input formats.
Supports:
- base64-encoded string
- raw bytes
- file path string
"""
import librosa
if isinstance(audio_input, str):
if os.path.isfile(audio_input):
audio, _ = librosa.load(audio_input, sr=16000, mono=True)
return audio
else:
# Assume base64
audio_bytes = base64.b64decode(audio_input)
elif isinstance(audio_input, bytes):
audio_bytes = audio_input
else:
raise ValueError(
f"Unsupported audio input type: {type(audio_input)}. "
"Expected base64 string, bytes, or file path."
)
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(audio_bytes)
tmp_path = tmp.name
try:
audio, _ = librosa.load(tmp_path, sr=16000, mono=True)
finally:
os.unlink(tmp_path)
return audio
def _parse_transcription(self, raw_text: str) -> List[Dict[str, Any]]:
"""
Parse the raw model output into structured segments.
VibeVoice-ASR outputs text in the format:
<speaker:0><start:0.00><end:13.43> Hello, how are you?
"""
segments = []
pattern = r"<speaker:(\d+)><start:([\d.]+)><end:([\d.]+)>\s*(.*?)(?=<speaker:|\Z)"
matches = re.finditer(pattern, raw_text, re.DOTALL)
for match in matches:
speaker_id = int(match.group(1))
start_time = float(match.group(2))
end_time = float(match.group(3))
text = match.group(4).strip()
if text:
segments.append({
"speaker": f"Speaker {speaker_id}",
"start": start_time,
"end": end_time,
"timestamp": f"{start_time:.2f} - {end_time:.2f}",
"text": text,
})
return segments
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process an inference request.
Request body:
{
"inputs": "<base64-encoded-audio>",
"parameters": { # all optional
"hotwords": "term1, term2",
"max_new_tokens": 8192,
"temperature": 0.0,
"top_p": 0.9,
"repetition_penalty": 1.0
}
}
Returns:
{
"transcription": "plain text transcription",
"raw": "raw model output with tags",
"segments": [
{
"speaker": "Speaker 0",
"start": 0.0,
"end": 13.43,
"timestamp": "0.00 - 13.43",
"text": "Hello, how are you?"
}
],
"duration": 78.3
}
"""
audio_input = data.get("inputs", data)
parameters = data.get("parameters", {})
hotwords = parameters.get("hotwords", "")
max_new_tokens = parameters.get("max_new_tokens", 8192)
temperature = parameters.get("temperature", 0.0)
top_p = parameters.get("top_p", 0.9)
repetition_penalty = parameters.get("repetition_penalty", 1.0)
# Load audio
try:
audio = self._load_audio(audio_input)
except Exception as e:
return {"error": f"Failed to load audio: {str(e)}"}
duration = len(audio) / 16000
logger.info(f"Audio loaded: {duration:.1f}s")
if duration > 3600:
return {"error": "Audio exceeds 60 minute limit"}
# Preprocess
try:
inputs = self.processor(
audio=audio,
sampling_rate=16000,
context=hotwords if hotwords else None,
return_tensors="pt",
)
inputs = {
k: v.to(self.device) if isinstance(v, torch.Tensor) else v
for k, v in inputs.items()
}
except Exception as e:
return {"error": f"Failed to preprocess audio: {str(e)}"}
# Generate
try:
generate_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": temperature > 0,
}
if temperature > 0:
generate_kwargs["temperature"] = temperature
generate_kwargs["top_p"] = top_p
if repetition_penalty != 1.0:
generate_kwargs["repetition_penalty"] = repetition_penalty
with torch.inference_mode():
output_ids = self.model.generate(**inputs, **generate_kwargs)
raw_text = self.processor.batch_decode(
output_ids, skip_special_tokens=False
)[0]
for token in ["<s>", "</s>", "<pad>", "<eos>", "<bos>"]:
raw_text = raw_text.replace(token, "")
raw_text = raw_text.strip()
except Exception as e:
logger.error(f"Generation failed: {str(e)}")
return {"error": f"Transcription failed: {str(e)}"}
segments = self._parse_transcription(raw_text)
plain_text = " ".join(seg["text"] for seg in segments) if segments else raw_text
return {
"transcription": plain_text,
"raw": raw_text,
"segments": segments,
"duration": round(duration, 2),
} |