| | """ |
| | Custom Inference Handler for VibeVoice-ASR on Hugging Face Inference Endpoints. |
| | |
| | Setup: |
| | 1. Duplicate the microsoft/VibeVoice-ASR repo to your own HF account |
| | 2. Add this handler.py and the accompanying requirements.txt to the repo root |
| | 3. Deploy as an Inference Endpoint with a GPU instance (min ~18GB VRAM) |
| | """ |
| |
|
| | import base64 |
| | import io |
| | import os |
| | import re |
| | import tempfile |
| | import logging |
| | from typing import Any, Dict, List |
| |
|
| | import torch |
| | import numpy as np |
| |
|
# Module-wide logger; its name is this module's import name.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
class EndpointHandler:
    """Inference Endpoints handler that transcribes audio with VibeVoice-ASR.

    An instance is built once from the model repository path; each request
    body is then passed to ``__call__``, which returns a JSON-serialisable
    dict, or ``{"error": "..."}`` when a stage fails.
    """

    # Sample rate (Hz) audio is resampled to and reported to the processor.
    _SAMPLE_RATE = 16000
    # Longer audio is rejected before any model work (60 minutes).
    _MAX_DURATION_S = 3600
    # One match per "<speaker:N><start:S><end:E> text" segment; the text runs
    # up to the next speaker tag or the end of the string.
    _SEGMENT_PATTERN = re.compile(
        r"<speaker:(\d+)><start:([\d.]+)><end:([\d.]+)>\s*(.*?)(?=<speaker:|\Z)",
        re.DOTALL,
    )
    # Generic special-token spellings always removed from decoded text; the
    # tokenizer's own bos/eos/pad tokens are added at decode time.
    _DEFAULT_STRIP_TOKENS = ("<s>", "</s>", "<pad>", "<eos>", "<bos>")

    def __init__(self, path: str = ""):
        """
        Initialize the VibeVoice-ASR model and processor.

        Args:
            path: Path to model weights (provided by HF Inference Endpoints).
        """
        from vibevoice.asr.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
        from vibevoice.asr.processing_vibevoice_asr import VibeVoiceASRProcessor

        logger.info("Loading VibeVoice-ASR model from: %s", path)

        self.processor = VibeVoiceASRProcessor.from_pretrained(path)

        # attn_implementation="flash_attention_2" needs the flash-attn package
        # installed on the endpoint image.
        self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

        # With device_map="auto" the weights are placed automatically; request
        # tensors are sent to the device holding the first parameter.
        self.device = next(self.model.parameters()).device
        logger.info("VibeVoice-ASR loaded on device: %s", self.device)

    @staticmethod
    def _decode_base64(payload: str) -> bytes:
        """Decode a base64 string, tolerating a ``data:<mime>;base64,`` prefix.

        Without stripping the prefix, the non-validating decoder would keep
        characters such as ``/`` and letters from it and corrupt the audio.
        """
        if payload.startswith("data:"):
            _, sep, rest = payload.partition(",")
            if sep:
                payload = rest
        return base64.b64decode(payload)

    def _read_audio_file(self, path: str) -> np.ndarray:
        """Decode the audio file at ``path`` to mono, resampled to ``_SAMPLE_RATE``."""
        import librosa  # deferred: only needed once a request arrives

        audio, _ = librosa.load(path, sr=self._SAMPLE_RATE, mono=True)
        return audio

    def _load_audio(self, audio_input) -> np.ndarray:
        """
        Load audio from various input formats.

        Supports:
            - base64-encoded string (optionally as a ``data:<mime>;base64,`` URI)
            - raw bytes (``bytes``, ``bytearray`` or ``memoryview``)
            - file path string

        Returns:
            Mono waveform resampled to ``_SAMPLE_RATE``.

        Raises:
            ValueError: if ``audio_input`` is of any other type.
        """
        if isinstance(audio_input, str):
            # NOTE(review): any string naming an existing file on the server is
            # read from disk. If clients are untrusted, consider disabling this.
            if os.path.isfile(audio_input):
                return self._read_audio_file(audio_input)
            audio_bytes = self._decode_base64(audio_input)
        elif isinstance(audio_input, (bytes, bytearray, memoryview)):
            audio_bytes = bytes(audio_input)
        else:
            raise ValueError(
                f"Unsupported audio input type: {type(audio_input)}. "
                "Expected base64 string, bytes, or file path."
            )

        # librosa is given a path, so the bytes are spilled to a temp file.
        fd, tmp_path = tempfile.mkstemp(suffix=".wav")
        try:
            with os.fdopen(fd, "wb") as tmp:
                tmp.write(audio_bytes)
            return self._read_audio_file(tmp_path)
        finally:
            # Covers the write too, so a failed write leaves no file behind.
            os.unlink(tmp_path)

    def _parse_transcription(self, raw_text: str) -> List[Dict[str, Any]]:
        """
        Parse the raw model output into structured segments.

        VibeVoice-ASR outputs text in the format:
            <speaker:0><start:0.00><end:13.43> Hello, how are you?

        Segments with empty text, or whose timestamps are not valid numbers
        (e.g. ``<start:.>``), are skipped.
        """
        segments = []
        for match in self._SEGMENT_PATTERN.finditer(raw_text):
            speaker, start, end, text = match.groups()
            text = text.strip()
            if not text:
                continue
            try:
                start_time, end_time = float(start), float(end)
            except ValueError:
                # "[\d.]+" also matches strings like "." or "1.2.3".
                continue
            segments.append({
                "speaker": f"Speaker {int(speaker)}",
                "start": start_time,
                "end": end_time,
                "timestamp": f"{start_time:.2f} - {end_time:.2f}",
                "text": text,
            })
        return segments

    @staticmethod
    def _build_generate_kwargs(parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Translate request ``parameters`` into ``model.generate`` kwargs.

        Values are coerced (e.g. ``"0.7"`` -> ``0.7``); a missing or ``null``
        value falls back to its default.

        Raises:
            TypeError, ValueError: if a value cannot be coerced, or if
                ``max_new_tokens`` is not positive.
        """
        def pick(key: str, default: Any, cast: Any) -> Any:
            value = parameters.get(key)
            return default if value is None else cast(value)

        max_new_tokens = pick("max_new_tokens", 8192, int)
        temperature = pick("temperature", 0.0, float)
        top_p = pick("top_p", 0.9, float)
        repetition_penalty = pick("repetition_penalty", 1.0, float)

        if max_new_tokens <= 0:
            raise ValueError("max_new_tokens must be a positive integer")

        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
        }
        # Sampling knobs are only sent when sampling; temperature <= 0 is greedy.
        if temperature > 0:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = top_p
        if repetition_penalty != 1.0:
            generate_kwargs["repetition_penalty"] = repetition_penalty
        return generate_kwargs

    @staticmethod
    def _strip_prompt(output_ids: Any, input_ids: Any) -> Any:
        """Drop echoed prompt tokens from ``generate`` output, if present.

        Decoder-only ``generate`` returns prompt + completion. The prefix is
        removed only when the output really starts with ``input_ids``, so an
        implementation that returns only new tokens is left untouched.
        """
        if not (isinstance(output_ids, torch.Tensor) and isinstance(input_ids, torch.Tensor)):
            return output_ids
        if output_ids.dim() != 2 or input_ids.dim() != 2:
            return output_ids
        prompt_len = input_ids.shape[1]
        if output_ids.shape[0] != input_ids.shape[0] or output_ids.shape[1] < prompt_len:
            return output_ids
        prompt = input_ids.to(device=output_ids.device, dtype=output_ids.dtype)
        if torch.equal(output_ids[:, :prompt_len], prompt):
            return output_ids[:, prompt_len:]
        return output_ids

    def _special_tokens_to_strip(self) -> List[str]:
        """Return the default strip list plus the tokenizer's bos/eos/pad tokens."""
        tokens = list(self._DEFAULT_STRIP_TOKENS)
        tokenizer = getattr(self.processor, "tokenizer", None)
        for attr in ("bos_token", "eos_token", "pad_token"):
            token = getattr(tokenizer, attr, None)
            if isinstance(token, str) and token and token not in tokens:
                tokens.append(token)
        return tokens

    def _decode_output(self, output_ids: Any, inputs: Dict[str, Any]) -> str:
        """Decode the first generated sequence into cleaned raw text."""
        generated = self._strip_prompt(output_ids, inputs.get("input_ids"))
        # skip_special_tokens=False is kept from the original, presumably so
        # the speaker/timestamp tags survive decoding (confirm). Only the known
        # bos/eos/pad spellings are removed afterwards.
        text = self.processor.batch_decode(generated, skip_special_tokens=False)[0]
        for token in self._special_tokens_to_strip():
            text = text.replace(token, "")
        return text.strip()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Request body:
            {
                "inputs": "<base64-encoded-audio>",   # or raw bytes / data URI
                "parameters": {  # all optional; null values use the default
                    "hotwords": "term1, term2",
                    "max_new_tokens": 8192,
                    "temperature": 0.0,
                    "top_p": 0.9,
                    "repetition_penalty": 1.0
                }
            }

        Returns:
            {
                "transcription": "plain text transcription",
                "raw": "raw model output with tags",
                "segments": [
                    {
                        "speaker": "Speaker 0",
                        "start": 0.0,
                        "end": 13.43,
                        "timestamp": "0.00 - 13.43",
                        "text": "Hello, how are you?"
                    }
                ],
                "duration": 78.3
            }
            or {"error": "<message>"} if parameters, loading, preprocessing
            or generation fail.
        """
        audio_input = data.get("inputs", data)
        # "parameters" may be missing or an explicit JSON null.
        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            return {"error": "'parameters' must be a JSON object"}

        hotwords = parameters.get("hotwords", "")
        try:
            generate_kwargs = self._build_generate_kwargs(parameters)
        except (TypeError, ValueError) as e:
            return {"error": f"Invalid generation parameters: {e}"}

        try:
            audio = self._load_audio(audio_input)
        except Exception as e:
            return {"error": f"Failed to load audio: {str(e)}"}

        if len(audio) == 0:
            return {"error": "Audio is empty"}

        duration = len(audio) / self._SAMPLE_RATE
        logger.info("Audio loaded: %.1fs", duration)

        if duration > self._MAX_DURATION_S:
            return {"error": "Audio exceeds 60 minute limit"}

        try:
            inputs = self.processor(
                audio=audio,
                sampling_rate=self._SAMPLE_RATE,
                context=hotwords if hotwords else None,
                return_tensors="pt",
            )
            inputs = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }
        except Exception as e:
            return {"error": f"Failed to preprocess audio: {str(e)}"}

        try:
            with torch.inference_mode():
                output_ids = self.model.generate(**inputs, **generate_kwargs)
            raw_text = self._decode_output(output_ids, inputs)
        except Exception as e:
            # logger.exception keeps the traceback in the endpoint logs.
            logger.exception("Generation failed")
            return {"error": f"Transcription failed: {str(e)}"}

        segments = self._parse_transcription(raw_text)
        plain_text = " ".join(seg["text"] for seg in segments) if segments else raw_text

        return {
            "transcription": plain_text,
            "raw": raw_text,
            "segments": segments,
            "duration": round(duration, 2),
        }