| """
|
| Custom HuggingFace Inference Endpoint handler for badrex Ethio-ASR
|
| (wav2vec2-bert CTC) models — the endpoint counterpart of src/transcribers/badrex.py.
|
|
|
| Protocol mirrors the MMS endpoint (endpoint/handler.py):
|
| request : {"inputs": <base64 audio>} # no language param — the model IS the language
|
| response: {"text": "..."}
|
|
|
| Unlike the MMS handler there is NO per-request adapter switching and NO manual
|
| logit-trim chunking: badrex models go through the stock HF ASR *pipeline*, whose
|
| chunk_length_s/stride_length_s does the CTC logit-level stitching internally
|
| (the "replicate the logit-trim approach" caveat applies only to hand-rolled
|
| chunking — the pipeline already does it correctly). That keeps long broadcasts
|
| from OOMing while staying byte-equivalent to the local backend's pipeline call.
|
|
|
| Which model is served is set by the BADREX_MODEL env var on the endpoint
|
| (default: badrex/Ethio-ASR-multilingual-1B). The handler repo holds no weights —
|
| they're pulled from the Hub on cold start, same as the MMS handler.
|
|
|
| Deploy: push this directory to a HF model repo and create a Custom Inference
|
| Endpoint from it. See endpoint-badrex/README.md.
|
| """
|
| from __future__ import annotations
|
|
|
| import os
|
| import re
|
| from typing import Any, Dict, Optional
|
|
|
| import torch
|
|
|
| from _audio_io import AudioRequestError, load_and_validate_audio
|
|
|
# Sampling rate (Hz) declared to the ASR pipeline alongside the decoded audio.
SAMPLE_RATE = 16_000

# Chunk window and stride, in seconds, passed to the pipeline as
# chunk_length_s / stride_length_s.
CHUNK_SEC = 30
STRIDE_SEC = 5

# One leading bracketed tag of 2-4 ASCII letters (e.g. "[am] "), together with
# any whitespace around it. Anchored at the start of the string only.
_LANG_TAG_RE = re.compile(r"^\s*\[[A-Za-z]{2,4}\]\s*")
|
|
|
|
|
def _strip_lang_tag(text: Optional[str]) -> str:
    """Return *text* with a leading bracketed tag removed and whitespace trimmed.

    The tag shape is whatever ``_LANG_TAG_RE`` matches at the start of the
    string. ``None`` and other falsy values yield the empty string.
    """
    if not text:
        return ""
    untagged = _LANG_TAG_RE.sub("", text)
    return untagged.strip()
|
|
|
|
|
class EndpointHandler:
    """Endpoint entry point wrapping a Hugging Face ASR pipeline.

    The pipeline is built once in ``__init__``. Each call decodes the request
    audio, transcribes it with chunking enabled, and returns either
    ``{"text": ...}`` or an ``{"error": ...}`` dict; failures are reported in
    the returned dict rather than raised.
    """

    def __init__(self, path: str = "", **kwargs):
        """Build the ASR pipeline for the configured model.

        Args:
            path: Local repo dir. Ignored — the model is referenced by its id,
                not loaded from this directory.
            **kwargs: Accepted and ignored.

        The served model id is read from the ``BADREX_MODEL`` environment
        variable, defaulting to ``badrex/Ethio-ASR-multilingual-1B``.
        """
        # Function-scope import: transformers is only loaded when a handler
        # is actually constructed.
        from transformers import pipeline

        self.model_id = os.environ.get(
            "BADREX_MODEL", "badrex/Ethio-ASR-multilingual-1B"
        )

        # Pipeline device index: 0 selects the first CUDA device, -1 the CPU.
        if torch.cuda.is_available():
            target_device = 0
        else:
            target_device = -1

        self.pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model_id,
            device=target_device,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe the audio carried by one request payload.

        Args:
            data: Request dict. The audio is read from ``"inputs"``, falling
                back to ``"audio"`` when ``"inputs"`` is missing or falsy.

        Returns:
            ``{"text": <transcript>}`` on success, where the transcript has
            been passed through ``_strip_lang_tag``. Otherwise a dict with an
            ``"error"`` message; invalid-audio errors raised as
            ``AudioRequestError`` additionally carry
            ``"status": "bad_request"``.
        """
        payload = data.get("inputs") or data.get("audio")
        if payload is None:
            return {"error": "No audio provided. Send base64 audio under 'inputs'."}

        # Decode/validate first so request problems are reported separately
        # from transcription failures.
        try:
            waveform = load_and_validate_audio(payload)
        except AudioRequestError as exc:
            return {"error": str(exc), "status": "bad_request"}
        except Exception as exc:
            return {"error": f"Audio decoding failed: {exc}"}

        pipeline_input = {"array": waveform, "sampling_rate": SAMPLE_RATE}
        try:
            # chunk_length_s / stride_length_s turn on the pipeline's own
            # chunked inference for long recordings.
            result = self.pipe(
                pipeline_input,
                chunk_length_s=CHUNK_SEC,
                stride_length_s=STRIDE_SEC,
            )
        except Exception as exc:
            return {"error": f"Transcription failed: {exc}"}

        # Anything other than a dict result is treated as an empty transcript.
        if isinstance(result, dict):
            raw_text = result.get("text")
        else:
            raw_text = ""
        return {"text": _strip_lang_tag(raw_text)}
|
|
|