"""
Custom Inference Handler for VibeVoice-ASR on Hugging Face Inference Endpoints.

Setup:
    1. Duplicate the microsoft/VibeVoice-ASR repo to your own HF account.
    2. Add this handler.py and the accompanying requirements.txt to the repo root.
    3. Deploy as an Inference Endpoint with a GPU instance (min ~18GB VRAM).
"""

import base64
import io
import os
import re
import tempfile
import logging
from typing import Any, Dict, List

import torch
import numpy as np

logger = logging.getLogger(__name__)

# Every loader resamples to this rate (mono) before handing audio to the processor.
SAMPLE_RATE = 16000
# Longest accepted input, in seconds (60 minutes).
MAX_DURATION_S = 3600
# Generation budget used when the request does not override it.
DEFAULT_MAX_NEW_TOKENS = 8192

# NOTE(review): the literal tag syntax emitted by VibeVoice-ASR was lost from
# this copy of the handler (angle-bracketed text was stripped). This pattern
# ASSUMES each segment starts with a speaker tag followed by a timestamp tag,
# e.g. "<|Speaker 0|><|0.00 - 13.43|>Hello, how are you?", and accepts both
# "<...>" and "<|...|>" delimiters. Confirm against real model output.
_SEGMENT_PATTERN = re.compile(
    r"<\|?\s*Speaker\s*(\d+)\s*\|?>"  # speaker tag -> numeric speaker id
    r"\s*<\|?\s*(\d+(?:\.\d+)?)\s*-\s*(\d+(?:\.\d+)?)\s*\|?>"  # "start - end"
    r"(.*?)"  # segment text (lazy)
    r"(?=<\|?\s*Speaker\s*\d+|\Z)",  # ...up to the next speaker tag or the end
    re.DOTALL,
)

# NOTE(review): the original list of control tokens to strip was also lost.
# This removes Qwen-style control tokens such as "<|im_end|>" or
# "<|endoftext|>" (a lowercase name between "<|" and "|>"), which cannot
# collide with the speaker/timestamp tags assumed above. Confirm against
# processor.tokenizer.all_special_tokens.
_CONTROL_TOKEN_PATTERN = re.compile(r"<\|[a-z_]+\|>")


class EndpointHandler:
    """Inference Endpoints entry point: loads VibeVoice-ASR once, then serves requests."""

    def __init__(self, path: str = ""):
        """
        Initialize the VibeVoice-ASR model and processor.

        Args:
            path: Path to model weights (provided by HF Inference Endpoints).
        """
        from vibevoice.asr.modeling_vibevoice_asr import VibeVoiceASRForConditionalGeneration
        from vibevoice.asr.processing_vibevoice_asr import VibeVoiceASRProcessor

        logger.info("Loading VibeVoice-ASR model from: %s", path)
        self.processor = VibeVoiceASRProcessor.from_pretrained(path)

        load_kwargs: Dict[str, Any] = {
            "torch_dtype": torch.bfloat16,
            "device_map": "auto",
            "trust_remote_code": True,
        }
        try:
            self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
                path, attn_implementation="flash_attention_2", **load_kwargs
            )
        except (ImportError, ValueError) as err:
            # flash-attn not installed or not supported for this model: SDPA
            # is slower but needs no extra dependency, so the endpoint still starts.
            logger.warning("flash_attention_2 unavailable (%s); falling back to sdpa", err)
            self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
                path, attn_implementation="sdpa", **load_kwargs
            )
        self.model.eval()
        # Inputs are moved to the device of the first parameter; with
        # device_map="auto" this is presumably where the input embeddings live.
        self.device = next(self.model.parameters()).device
        logger.info("VibeVoice-ASR loaded on device: %s", self.device)

    @staticmethod
    def _decode_base64(encoded: str) -> bytes:
        """
        Strictly decode a base64 audio payload.

        Accepts an optional data-URI prefix ("data:audio/wav;base64,...") and
        whitespace/line breaks inside the payload.

        Raises:
            ValueError: if the string is not valid base64.
        """
        if encoded.startswith("data:") and "," in encoded:
            encoded = encoded.split(",", 1)[1]
        try:
            # validate=True rejects non-alphabet characters instead of silently
            # discarding them, so a mistyped file path is not decoded as junk.
            return base64.b64decode("".join(encoded.split()), validate=True)
        except ValueError as err:  # binascii.Error subclasses ValueError
            raise ValueError(
                "String input is neither an existing file path nor valid base64 audio."
            ) from err

    def _load_audio(self, audio_input) -> np.ndarray:
        """
        Load audio as a mono float array at SAMPLE_RATE.

        Supports:
            - file path string (an existing file)
            - base64-encoded string (optionally a data URI)
            - raw bytes / bytearray

        Raises:
            ValueError: for unsupported input types, invalid base64, or empty input.
        """
        if isinstance(audio_input, str) and os.path.isfile(audio_input):
            import librosa

            audio, _ = librosa.load(audio_input, sr=SAMPLE_RATE, mono=True)
            return audio

        if isinstance(audio_input, str):
            audio_bytes = self._decode_base64(audio_input)
        elif isinstance(audio_input, (bytes, bytearray)):
            audio_bytes = bytes(audio_input)
        else:
            raise ValueError(
                f"Unsupported audio input type: {type(audio_input)}. "
                "Expected base64 string, bytes, or file path."
            )
        if not audio_bytes:
            raise ValueError("Audio input is empty.")

        # Imported only once the input is known to be decodable audio bytes.
        import librosa

        # librosa's backends need a real file; create it closed, then write and
        # read inside try/finally so the temp file is always removed.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            with open(tmp_path, "wb") as fh:
                fh.write(audio_bytes)
            audio, _ = librosa.load(tmp_path, sr=SAMPLE_RATE, mono=True)
        finally:
            os.unlink(tmp_path)
        return audio

    @staticmethod
    def _parse_parameters(parameters: Any) -> Dict[str, Any]:
        """
        Validate and normalize the optional request ``parameters`` object.

        ``hotwords`` may be a comma-separated string or a list of terms.
        A temperature <= 0 means greedy decoding.

        Raises:
            TypeError: if ``parameters`` is not an object or a value has the wrong type.
            ValueError: if a value cannot be converted or ``max_new_tokens`` < 1.
        """
        if parameters is None:
            parameters = {}
        if not isinstance(parameters, dict):
            raise TypeError("'parameters' must be a JSON object")

        hotwords = parameters.get("hotwords") or ""
        if isinstance(hotwords, (list, tuple)):
            hotwords = ", ".join(str(term) for term in hotwords)

        max_new_tokens = int(parameters.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS))
        if max_new_tokens < 1:
            raise ValueError("max_new_tokens must be >= 1")

        return {
            "hotwords": str(hotwords),
            "max_new_tokens": max_new_tokens,
            "temperature": float(parameters.get("temperature", 0.0)),
            "top_p": float(parameters.get("top_p", 0.9)),
            "repetition_penalty": float(parameters.get("repetition_penalty", 1.0)),
        }

    @staticmethod
    def _strip_prompt(output_ids: Any, input_ids: Any) -> Any:
        """
        Return only the newly generated tokens from ``generate`` output.

        Decoder-only ``generate`` returns the prompt followed by the new tokens.
        Decoding the whole sequence would leak the prompt (presumably audio
        placeholders and hotword context) into the transcription. The prefix
        is dropped only when it provably equals ``input_ids``. Anything else
        (e.g. encoder-decoder output) is returned unchanged.
        """
        if not (
            isinstance(output_ids, torch.Tensor)
            and isinstance(input_ids, torch.Tensor)
            and output_ids.dim() == 2
            and input_ids.dim() == 2
        ):
            return output_ids
        prompt_len = input_ids.shape[1]
        if output_ids.shape[1] < prompt_len:
            return output_ids
        prefix = output_ids[:, :prompt_len]
        if torch.equal(prefix, input_ids.to(prefix.device)):
            return output_ids[:, prompt_len:]
        return output_ids

    def _parse_transcription(self, raw_text: str) -> List[Dict[str, Any]]:
        """
        Parse the raw model output into structured segments.

        Each match of _SEGMENT_PATTERN (speaker tag, timestamp tag, text up to
        the next speaker tag) becomes one segment. Segments with empty text are
        skipped. Returns an empty list when no tagged segment is found; the
        caller then falls back to the raw text.
        """
        segments: List[Dict[str, Any]] = []
        for match in _SEGMENT_PATTERN.finditer(raw_text or ""):
            speaker_id, start_str, end_str, text = match.groups()
            text = text.strip()
            if not text:
                continue
            start, end = float(start_str), float(end_str)
            segments.append(
                {
                    "speaker": f"Speaker {int(speaker_id)}",
                    "start": start,
                    "end": end,
                    "timestamp": f"{start:.2f} - {end:.2f}",
                    "text": text,
                }
            )
        return segments

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Request body:
            {
                "inputs": "<base64-encoded audio, or a file path>",
                "parameters": {  # all optional
                    "hotwords": "term1, term2",
                    "max_new_tokens": 8192,
                    "temperature": 0.0,
                    "top_p": 0.9,
                    "repetition_penalty": 1.0
                }
            }

        Returns:
            {
                "transcription": "plain text transcription",
                "raw": "raw model output with tags",
                "segments": [
                    {
                        "speaker": "Speaker 0",
                        "start": 0.0,
                        "end": 13.43,
                        "timestamp": "0.00 - 13.43",
                        "text": "Hello, how are you?"
                    }
                ],
                "duration": 78.3
            }
            or {"error": "..."} when the request cannot be served.
        """
        if isinstance(data, dict):
            audio_input = data.get("inputs", data)
            raw_parameters = data.get("parameters")
        else:
            # Tolerate a bare payload (e.g. raw bytes) instead of a JSON object.
            audio_input, raw_parameters = data, None

        try:
            params = self._parse_parameters(raw_parameters)
        except (TypeError, ValueError) as e:
            return {"error": f"Invalid parameters: {str(e)}"}

        # Load audio
        try:
            audio = self._load_audio(audio_input)
        except Exception as e:
            return {"error": f"Failed to load audio: {str(e)}"}

        duration = len(audio) / SAMPLE_RATE
        logger.info("Audio loaded: %.1fs", duration)
        if duration == 0:
            return {"error": "Audio is empty"}
        if duration > MAX_DURATION_S:
            return {"error": "Audio exceeds 60 minute limit"}

        # Preprocess
        try:
            inputs = self.processor(
                audio=audio,
                sampling_rate=SAMPLE_RATE,
                context=params["hotwords"] or None,
                return_tensors="pt",
            )
            inputs = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }
        except Exception as e:
            return {"error": f"Failed to preprocess audio: {str(e)}"}

        # Sampling knobs are only passed when sampling is enabled, so greedy
        # decoding does not trigger "unused generation flag" warnings.
        temperature = params["temperature"]
        generate_kwargs: Dict[str, Any] = {
            "max_new_tokens": params["max_new_tokens"],
            "do_sample": temperature > 0,
        }
        if temperature > 0:
            generate_kwargs["temperature"] = temperature
            generate_kwargs["top_p"] = params["top_p"]
        if params["repetition_penalty"] != 1.0:
            generate_kwargs["repetition_penalty"] = params["repetition_penalty"]

        # Generate
        try:
            with torch.inference_mode():
                output_ids = self.model.generate(**inputs, **generate_kwargs)
            generated_ids = self._strip_prompt(output_ids, inputs.get("input_ids"))
            # Special tokens are kept so that tag tokens survive; only control
            # tokens are removed below.
            raw_text = self.processor.batch_decode(
                generated_ids, skip_special_tokens=False
            )[0]
        except Exception as e:
            logger.exception("Generation failed")
            return {"error": f"Transcription failed: {str(e)}"}

        raw_text = _CONTROL_TOKEN_PATTERN.sub("", raw_text).strip()

        segments = self._parse_transcription(raw_text)
        plain_text = " ".join(seg["text"] for seg in segments) if segments else raw_text

        return {
            "transcription": plain_text,
            "raw": raw_text,
            "segments": segments,
            "duration": round(duration, 2),
        }