File size: 3,295 Bytes
afe40c9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""

Custom HuggingFace Inference Endpoint handler for badrex Ethio-ASR

(wav2vec2-bert CTC) models — the endpoint counterpart of src/transcribers/badrex.py.



Protocol mirrors the MMS endpoint (endpoint/handler.py):

  request : {"inputs": <base64 audio>}          # no language param — the model IS the language

  response: {"text": "..."}



Unlike the MMS handler there is NO per-request adapter switching and NO manual

logit-trim chunking: badrex models go through the stock HF ASR *pipeline*, whose

chunk_length_s/stride_length_s does the CTC logit-level stitching internally

(the "replicate the logit-trim approach" caveat applies only to hand-rolled

chunking — the pipeline already does it correctly). That keeps long broadcasts

from OOMing while staying byte-equivalent to the local backend's pipeline call.



Which model is served is set by the BADREX_MODEL env var on the endpoint

(default: badrex/Ethio-ASR-multilingual-1B). The handler repo holds no weights —

they're pulled from the Hub on cold start, same as the MMS handler.



Deploy: push this directory to a HF model repo and create a Custom Inference

Endpoint from it. See endpoint-badrex/README.md.

"""
from __future__ import annotations

import os
import re
from typing import Any, Dict, Optional

import torch

from _audio_io import AudioRequestError, load_and_validate_audio

# Sampling rate (Hz) declared to the ASR pipeline together with the waveform.
# NOTE(review): assumes load_and_validate_audio returns audio at this rate — confirm in _audio_io.
SAMPLE_RATE = 16000
CHUNK_SEC = 30   # pipeline window, passed to the pipeline call as chunk_length_s
STRIDE_SEC = 5   # overlap, passed as stride_length_s; pipeline trims the duplicated logits at each boundary

_LANG_TAG_RE = re.compile(r"^\s*\[[A-Za-z]{2,4}\]\s*")


def _strip_lang_tag(text: Optional[str]) -> str:
    return _LANG_TAG_RE.sub("", text or "").strip()


class EndpointHandler:
    """Inference Endpoint handler that transcribes audio with one ASR pipeline.

    The pipeline is built once in ``__init__``; every request goes through
    ``__call__``, which never raises for the failure modes it knows about and
    instead returns an ``{"error": ...}`` dict. Errors caused by the request
    itself additionally carry ``"status": "bad_request"``.
    """

    def __init__(self, path: str = "", **kwargs):
        """Build the ASR pipeline for the configured model.

        Args:
            path: Local repo dir (ignored — the model is loaded by id, so the
                weights come from the Hub rather than from this directory).
            **kwargs: Accepted for interface compatibility and ignored.

        The served model id is read from the ``BADREX_MODEL`` environment
        variable, defaulting to ``badrex/Ethio-ASR-multilingual-1B``, and is
        kept on ``self.model_id``.
        """
        # Imported here rather than at module top, so importing this module
        # alone does not require transformers.
        from transformers import pipeline

        self.model_id = os.environ.get(
            "BADREX_MODEL", "badrex/Ethio-ASR-multilingual-1B"
        )
        # First CUDA device when one is available, otherwise -1 (CPU).
        device = 0 if torch.cuda.is_available() else -1
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model_id,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe the audio carried by one request payload.

        Args:
            data: Request payload. The audio is read from ``"inputs"``, falling
                back to ``"audio"`` when ``"inputs"`` is missing or empty.

        Returns:
            ``{"text": ...}`` on success, with any leading ``[xx]`` tag
            stripped. On failure, ``{"error": ...}``; malformed payloads,
            missing audio and ``AudioRequestError`` also set
            ``"status": "bad_request"``.
        """
        # A payload without .get() (e.g. None) is the caller's mistake; answer
        # it like every other bad request instead of raising AttributeError.
        if not callable(getattr(data, "get", None)):
            return {
                "error": "Request payload must be an object with base64 audio under 'inputs'.",
                "status": "bad_request",
            }

        audio_input = data.get("inputs") or data.get("audio")
        if audio_input is None:
            return {
                "error": "No audio provided. Send base64 audio under 'inputs'.",
                "status": "bad_request",
            }

        # NOTE(review): load_and_validate_audio presumably decodes the payload
        # into a waveform at SAMPLE_RATE — confirm in _audio_io.
        try:
            audio = load_and_validate_audio(audio_input)
        except AudioRequestError as exc:
            return {"error": str(exc), "status": "bad_request"}
        except Exception as exc:
            # Anything else is unexpected, so it is not labelled as a bad request.
            return {"error": f"Audio decoding failed: {exc}"}

        try:
            out = self.pipe(
                {"array": audio, "sampling_rate": SAMPLE_RATE},
                chunk_length_s=CHUNK_SEC,
                stride_length_s=STRIDE_SEC,
            )
        except Exception as exc:
            return {"error": f"Transcription failed: {exc}"}

        # Any non-dict pipeline result is treated as an empty transcript.
        raw_text = out.get("text") if isinstance(out, dict) else ""
        return {"text": _strip_lang_tag(raw_text)}