File size: 3,647 Bytes
d2d8f35
f93e144
d2d8f35
f93e144
 
 
 
 
d2d8f35
f93e144
 
d2d8f35
 
 
 
 
 
f93e144
d2d8f35
f93e144
d2d8f35
f93e144
d2d8f35
 
f93e144
d2d8f35
f93e144
d2d8f35
f93e144
d2d8f35
f93e144
d2d8f35
f93e144
 
 
 
 
d2d8f35
f93e144
 
 
 
 
 
 
 
 
 
d2d8f35
f93e144
d2d8f35
f93e144
 
d2d8f35
f93e144
 
d2d8f35
f93e144
 
 
d2d8f35
f93e144
 
 
 
d2d8f35
f93e144
 
 
 
 
d2d8f35
f93e144
 
 
d2d8f35
f93e144
 
d2d8f35
f93e144
 
 
d2d8f35
f93e144
 
d2d8f35
f93e144
d2d8f35
 
f93e144
 
 
 
 
 
 
 
 
 
d2d8f35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
HuggingFace Inference Endpoints custom handler for microsoft/VibeVoice-ASR-HF.

Deploy steps:
  1. Fork microsoft/VibeVoice-ASR-HF on HuggingFace
  2. Copy THIS file into the fork root as `handler.py`
  3. Create HF Inference Endpoint pointing at your fork
  4. Set VIBEVOICE_HF_ENDPOINT_URL + HF_TOKEN in .env

Input  : raw audio bytes (wav/mp3/flac/m4a/ogg)
Output : {"transcript": str, "segments": [{"Start": float, "End": float, "Speaker": int, "Content": str}]}

Docs:
  https://huggingface.co/docs/inference-endpoints/guides/custom_handler
"""
from __future__ import annotations

import tempfile
import os
from typing import Any, Dict

SAMPLE_RATE = 24_000   # VibeVoice-ASR-HF requires 24 kHz


class EndpointHandler:
    """Custom Inference Endpoints handler wrapping microsoft/VibeVoice-ASR-HF.

    The Inference Endpoints runtime instantiates this class once at start-up
    (``__init__`` loads the model) and invokes the instance per request
    (``__call__`` transcribes one audio payload).
    """

    def __init__(self, path: str = ""):
        """Load processor and model.

        Args:
            path: Local checkpoint directory injected by the Inference
                Endpoints runtime; falls back to the hub repo id when empty.
        """
        # Lazy imports: keep module import cheap and avoid requiring
        # torch/transformers merely to import this file.
        from transformers import AutoProcessor, VibeVoiceAsrForConditionalGeneration
        import torch

        model_path = path or "microsoft/VibeVoice-ASR-HF"

        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = VibeVoiceAsrForConditionalGeneration.from_pretrained(
            model_path,
            device_map="auto",            # place layers on available GPU(s)
            torch_dtype=torch.bfloat16,   # halve memory vs fp32
        )
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe one audio payload.

        Args:
            data: Either raw audio bytes, or a dict with:
                ``inputs``     : raw audio bytes (any ffmpeg-supported format)
                ``parameters`` : optional dict
                    ``prompt``               (str) context hint,
                        e.g. "Medical midwife consultation"
                    ``tokenizer_chunk_size`` (int) samples per chunk —
                        reduce for low VRAM

        Returns:
            ``{"transcript": str, "segments": [{"Start", "End", "Speaker",
            "Content"}, ...]}``

        Raises:
            ValueError: if no audio bytes are present in the payload.
        """
        import torch
        import soundfile as sf
        from transformers.pipelines.audio_utils import ffmpeg_read

        # FIX: the previous `data.pop("inputs", data)` returned the request
        # dict itself when "inputs" was absent (opaque failure inside
        # ffmpeg_read) and crashed with AttributeError when the runtime
        # delivered raw bytes. Accept both shapes and fail fast otherwise;
        # use .get() so the caller's dict is not mutated.
        if isinstance(data, (bytes, bytearray)):
            audio_bytes: bytes = bytes(data)
            parameters: dict = {}
        else:
            audio_bytes = data.get("inputs")
            parameters = data.get("parameters") or {}

        if not isinstance(audio_bytes, (bytes, bytearray)):
            raise ValueError(
                "Expected raw audio bytes as the payload or under data['inputs']"
            )

        prompt               = parameters.get("prompt", "Midwife medical consultation in German")
        tokenizer_chunk_size = parameters.get("tokenizer_chunk_size", None)

        # Decode audio bytes → numpy array at 24 kHz → temp wav file.
        # processor.apply_transcription_request() requires a file path.
        audio_np = ffmpeg_read(audio_bytes, SAMPLE_RATE)

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp_path = tmp.name
            sf.write(tmp_path, audio_np, SAMPLE_RATE)

        try:
            inputs = self.processor.apply_transcription_request(
                audio=tmp_path,
                prompt=prompt,
            ).to(self.model.device, self.model.dtype)

            generate_kwargs: dict = {}
            if tokenizer_chunk_size is not None:
                generate_kwargs["tokenizer_chunk_size"] = int(tokenizer_chunk_size)

            with torch.inference_mode():
                output_ids = self.model.generate(**inputs, **generate_kwargs)

            # Strip the prompt tokens; decode only newly generated ids.
            generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
            decoded = self.processor.decode(generated_ids, return_format="parsed")
            segments: list[dict] = decoded[0] if decoded else []

        finally:
            # delete=False above so sf.write + reader both see the file;
            # clean up explicitly even if generation fails.
            os.unlink(tmp_path)

        # FIX: skip empty Content so the transcript has no doubled spaces.
        transcript = " ".join(
            content
            for s in segments
            if (content := s.get("Content", "").strip())
        )

        return {
            "transcript": transcript,
            "segments": [
                {
                    "Start":   float(s.get("Start", 0)),
                    "End":     float(s.get("End",   0)),
                    "Speaker": s.get("Speaker"),
                    "Content": s.get("Content", "").strip(),
                }
                for s in segments
            ],
        }