badrex-endpoint / handler.py
filipok's picture
redeploy: sync endpoint-badrex handler
afe40c9 verified
Raw
History Blame Contribute Delete
3.3 kB
"""
Custom HuggingFace Inference Endpoint handler for badrex Ethio-ASR
(wav2vec2-bert CTC) models — the endpoint counterpart of src/transcribers/badrex.py.
Protocol mirrors the MMS endpoint (endpoint/handler.py):
request : {"inputs": <base64 audio>} # no language param — the model IS the language
response: {"text": "..."}
Unlike the MMS handler there is NO per-request adapter switching and NO manual
logit-trim chunking: badrex models go through the stock HF ASR *pipeline*, whose
chunk_length_s/stride_length_s does the CTC logit-level stitching internally
(the "replicate the logit-trim approach" caveat applies only to hand-rolled
chunking — the pipeline already does it correctly). That keeps long broadcasts
from OOMing while staying byte-equivalent to the local backend's pipeline call.
Which model is served is set by the BADREX_MODEL env var on the endpoint
(default: badrex/Ethio-ASR-multilingual-1B). The handler repo holds no weights —
they're pulled from the Hub on cold start, same as the MMS handler.
Deploy: push this directory to a HF model repo and create a Custom Inference
Endpoint from it. See endpoint-badrex/README.md.
"""
from __future__ import annotations

import logging
import os
import re
from typing import Any, Dict, Optional

import torch

from _audio_io import AudioRequestError, load_and_validate_audio
SAMPLE_RATE = 16000
CHUNK_SEC = 30 # pipeline window
STRIDE_SEC = 5 # overlap; pipeline trims the duplicated logits at each boundary
_LANG_TAG_RE = re.compile(r"^\s*\[[A-Za-z]{2,4}\]\s*")
def _strip_lang_tag(text: Optional[str]) -> str:
return _LANG_TAG_RE.sub("", text or "").strip()
class EndpointHandler:
    """HF Inference Endpoint handler serving one badrex CTC ASR model.

    The model is loaded once in ``__init__`` as a stock transformers
    ``automatic-speech-recognition`` pipeline. Each ``__call__`` decodes one
    audio payload and returns ``{"text": ...}``, or an ``{"error": ...}`` dict
    instead of raising.
    """

    # Failures are converted into error payloads in __call__; log them with
    # their traceback so they stay diagnosable from the endpoint logs.
    _logger = logging.getLogger(__name__)

    def __init__(self, path: str = "", **kwargs):
        """Build the ASR pipeline for the configured model.

        Args:
            path: Local repo dir. Ignored — weights come from the Hub, like
                the MMS handler.
            **kwargs: Unused.

        The served model id is read from the ``BADREX_MODEL`` env var and
        defaults to the multilingual checkpoint.
        """
        # Deferred import: transformers is only loaded when a handler is built.
        from transformers import pipeline

        self.model_id = os.environ.get("BADREX_MODEL", "badrex/Ethio-ASR-multilingual-1B")
        # transformers device convention: 0 = first CUDA device, -1 = CPU.
        device = 0 if torch.cuda.is_available() else -1
        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model_id,
            device=device,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe a single request.

        Args:
            data: Request payload. Audio is taken from ``"inputs"``, falling
                back to ``"audio"``, and passed to ``load_and_validate_audio``.

        Returns:
            ``{"text": <transcript>}`` on success, with any leading language
            tag stripped. On failure an ``{"error": <message>}`` dict is
            returned; ``AudioRequestError`` failures also carry
            ``"status": "bad_request"``.
        """
        audio_input = data.get("inputs") or data.get("audio")
        if audio_input is None:
            return {"error": "No audio provided. Send base64 audio under 'inputs'."}

        try:
            audio = load_and_validate_audio(audio_input)
        except AudioRequestError as exc:
            return {"error": str(exc), "status": "bad_request"}
        except Exception as exc:
            self._logger.exception("Audio decoding failed")
            return {"error": f"Audio decoding failed: {exc}"}

        try:
            # Chunked inference: the pipeline windows the audio and stitches
            # the CTC logits across the overlapping strides itself.
            out = self.pipe(
                {"array": audio, "sampling_rate": SAMPLE_RATE},
                chunk_length_s=CHUNK_SEC,
                stride_length_s=STRIDE_SEC,
            )
        except Exception as exc:
            self._logger.exception("Transcription failed")
            return {"error": f"Transcription failed: {exc}"}

        # A single-input pipeline call is expected to return a dict; any other
        # result is treated as an empty transcript.
        raw_text = out.get("text") if isinstance(out, dict) else None
        return {"text": _strip_lang_tag(raw_text)}