| """ |
| HuggingFace Inference Endpoints custom handler for microsoft/VibeVoice-ASR-HF. |
| |
| Deploy steps: |
| 1. Fork microsoft/VibeVoice-ASR-HF on HuggingFace |
| 2. Copy THIS file into the fork root as `handler.py` |
| 3. Create HF Inference Endpoint pointing at your fork |
| 4. Set VIBEVOICE_HF_ENDPOINT_URL + HF_TOKEN in .env |
| |
| Input : raw audio bytes (wav/mp3/flac/m4a/ogg) |
| Output : {"transcript": str, "segments": [{"Start": float, "End": float, "Speaker": int, "Content": str}]} |
| |
| Docs: |
| https://huggingface.co/docs/inference-endpoints/guides/custom_handler |
| """ |
| from __future__ import annotations |
|
|
| import tempfile |
| import os |
| from typing import Any, Dict |
|
|
# Target sample rate in Hz: incoming audio is decoded/resampled to this rate
# by ffmpeg_read and written back out at the same rate for the processor.
SAMPLE_RATE = 24_000
|
|
|
|
class EndpointHandler:
    """Custom Inference Endpoints handler wrapping the VibeVoice ASR model.

    The model and processor are loaded once when the endpoint starts
    (``__init__``); each request is served by ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Load processor and model.

        Args:
            path: Local model directory injected by the Inference Endpoints
                runtime; falls back to the hub repo id when empty.
        """
        # Heavy third-party imports stay function-local so merely importing
        # this module (e.g. for tooling/inspection) stays cheap.
        import torch
        from transformers import AutoProcessor, VibeVoiceAsrForConditionalGeneration

        model_path = path or "microsoft/VibeVoice-ASR-HF"

        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = VibeVoiceAsrForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe one audio payload.

        Args:
            data["inputs"]: raw audio bytes (any ffmpeg-supported format).
            data["parameters"]: optional dict with
                prompt (str): context hint, e.g. "Medical midwife consultation"
                tokenizer_chunk_size (int): samples per chunk — reduce for low VRAM

        Returns:
            {"transcript": str,
             "segments": [{"Start": float, "End": float,
                           "Speaker": Any, "Content": str}]}
        """
        import torch
        from transformers.pipelines.audio_utils import ffmpeg_read

        # Robustness: some clients POST the raw audio body directly, in which
        # case the serving toolkit may hand us bytes rather than a dict.
        if isinstance(data, (bytes, bytearray)):
            audio_bytes: bytes = bytes(data)
            parameters: dict = {}
        else:
            # .get() instead of .pop(): same values, but the caller's request
            # dict is no longer mutated as a side effect.
            audio_bytes = data.get("inputs", data)
            parameters = data.get("parameters") or {}

        prompt = parameters.get("prompt", "Midwife medical consultation in German")
        tokenizer_chunk_size = parameters.get("tokenizer_chunk_size", None)

        # Decode + resample to the model's expected rate via ffmpeg.
        audio_np = ffmpeg_read(audio_bytes, SAMPLE_RATE)

        # The processor's transcription API takes a file path, so round-trip
        # the decoded audio through a temporary wav. Create the name first and
        # write AFTER the handle is closed: soundfile reopens the path itself,
        # which fails on Windows while NamedTemporaryFile still holds it open.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
        import soundfile as sf
        sf.write(tmp_path, audio_np, SAMPLE_RATE)

        try:
            inputs = self.processor.apply_transcription_request(
                audio=tmp_path,
                prompt=prompt,
            ).to(self.model.device, self.model.dtype)

            generate_kwargs: dict = {}
            if tokenizer_chunk_size is not None:
                generate_kwargs["tokenizer_chunk_size"] = int(tokenizer_chunk_size)

            with torch.inference_mode():
                output_ids = self.model.generate(**inputs, **generate_kwargs)

            # Strip the prompt tokens; keep only the newly generated ids.
            generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
            decoded = self.processor.decode(generated_ids, return_format="parsed")
            segments: list[dict] = decoded[0] if decoded else []
        finally:
            # Always remove the temp wav, even when generation raises.
            os.unlink(tmp_path)

        transcript = " ".join(s.get("Content", "").strip() for s in segments).strip()

        return {
            "transcript": transcript,
            "segments": [
                {
                    "Start": float(s.get("Start", 0)),
                    "End": float(s.get("End", 0)),
                    # NOTE(review): the module docstring claims Speaker is an
                    # int, but the parsed value is passed through unchanged —
                    # confirm the processor's "parsed" format emits ints.
                    "Speaker": s.get("Speaker"),
                    "Content": s.get("Content", "").strip(),
                }
                for s in segments
            ],
        }
|
|