# handler.py — custom inference handler for the VibeVoice-ASR-HF fork
# (place this file at the fork root; see module docstring below for deploy steps).
"""
HuggingFace Inference Endpoints custom handler for microsoft/VibeVoice-ASR-HF.
Deploy steps:
1. Fork microsoft/VibeVoice-ASR-HF on HuggingFace
2. Copy THIS file into the fork root as `handler.py`
3. Create HF Inference Endpoint pointing at your fork
4. Set VIBEVOICE_HF_ENDPOINT_URL + HF_TOKEN in .env
Input : raw audio bytes (wav/mp3/flac/m4a/ogg)
Output : {"transcript": str, "segments": [{"Start": float, "End": float, "Speaker": int, "Content": str}]}
Docs:
https://huggingface.co/docs/inference-endpoints/guides/custom_handler
"""
from __future__ import annotations
import tempfile
import os
from typing import Any, Dict
SAMPLE_RATE = 24_000  # VibeVoice-ASR-HF expects 24 kHz audio; used both when decoding the request and when writing the temp wav
class EndpointHandler:
    """HF Inference Endpoints handler wrapping VibeVoice-ASR.

    Accepts raw audio bytes (any ffmpeg-readable container) and returns a
    full transcript plus per-speaker timestamped segments.
    """

    def __init__(self, path: str = ""):
        """Load processor and model once at endpoint startup.

        Args:
            path: Local checkout of the model repo (HF Endpoints passes the
                repo directory); falls back to the hub id when empty.
        """
        # Heavy deps imported lazily so importing this module stays cheap.
        from transformers import AutoProcessor, VibeVoiceAsrForConditionalGeneration
        import torch

        model_path = path or "microsoft/VibeVoice-ASR-HF"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = VibeVoiceAsrForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto",  # place/shard across available accelerators
            torch_dtype=torch.bfloat16,
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe one audio payload.

        Args:
            data["inputs"]: raw audio bytes (any ffmpeg-supported format).
                A bare ``bytes`` payload (no dict wrapper) is also accepted.
            data["parameters"]: optional dict with
                prompt (str): context hint, e.g. "Medical midwife consultation"
                tokenizer_chunk_size (int): samples per chunk — reduce for low VRAM

        Returns:
            {"transcript": str,
             "segments": [{"Start": float, "End": float,
                           "Speaker": int | None, "Content": str}]}

        Raises:
            ValueError: if no audio payload is present or nothing could be
                decoded from it.
        """
        import torch
        from transformers.pipelines.audio_utils import ffmpeg_read

        # HF endpoints normally wrap the body as {"inputs": <bytes>}, but some
        # deserializers hand the handler the raw bytes directly. Fail loudly on
        # a missing/non-bytes payload instead of passing garbage to ffmpeg.
        if isinstance(data, (bytes, bytearray)):
            audio_bytes: bytes = bytes(data)
            parameters: dict = {}
        else:
            audio_bytes = data.pop("inputs", None)
            parameters = data.pop("parameters", {}) or {}
        if not isinstance(audio_bytes, (bytes, bytearray)) or not audio_bytes:
            raise ValueError("Expected raw audio bytes under data['inputs'].")

        prompt = parameters.get("prompt", "Midwife medical consultation in German")
        tokenizer_chunk_size = parameters.get("tokenizer_chunk_size", None)

        # Decode audio bytes -> float array at 24 kHz -> temp wav file, because
        # processor.apply_transcription_request() requires a file path.
        audio_np = ffmpeg_read(audio_bytes, SAMPLE_RATE)
        if audio_np.size == 0:
            raise ValueError("Could not decode any audio from the request body.")

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
        import soundfile as sf
        sf.write(tmp_path, audio_np, SAMPLE_RATE)
        try:
            inputs = self.processor.apply_transcription_request(
                audio=tmp_path,
                prompt=prompt,
            ).to(self.model.device, self.model.dtype)

            generate_kwargs: dict = {}
            if tokenizer_chunk_size is not None:
                generate_kwargs["tokenizer_chunk_size"] = int(tokenizer_chunk_size)

            with torch.inference_mode():
                output_ids = self.model.generate(**inputs, **generate_kwargs)
            # Drop the prompt tokens; keep only the newly generated ids.
            generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
            decoded = self.processor.decode(generated_ids, return_format="parsed")
            segments: list[dict] = decoded[0] if decoded else []
        finally:
            os.unlink(tmp_path)  # always remove the temp wav, even on failure

        # A parsed segment may carry "Content": None — `.get`'s default does
        # not apply to present-but-None keys, so coalesce before stripping.
        transcript = " ".join(
            (s.get("Content") or "").strip() for s in segments
        ).strip()
        return {
            "transcript": transcript,
            "segments": [
                {
                    "Start": float(s.get("Start", 0)),
                    "End": float(s.get("End", 0)),
                    "Speaker": s.get("Speaker"),
                    "Content": (s.get("Content") or "").strip(),
                }
                for s in segments
            ],
        }