""" Custom HuggingFace Inference Endpoint handler for badrex Ethio-ASR (wav2vec2-bert CTC) models — the endpoint counterpart of src/transcribers/badrex.py. Protocol mirrors the MMS endpoint (endpoint/handler.py): request : {"inputs": } # no language param — the model IS the language response: {"text": "..."} Unlike the MMS handler there is NO per-request adapter switching and NO manual logit-trim chunking: badrex models go through the stock HF ASR *pipeline*, whose chunk_length_s/stride_length_s does the CTC logit-level stitching internally (the "replicate the logit-trim approach" caveat applies only to hand-rolled chunking — the pipeline already does it correctly). That keeps long broadcasts from OOMing while staying byte-equivalent to the local backend's pipeline call. Which model is served is set by the BADREX_MODEL env var on the endpoint (default: badrex/Ethio-ASR-multilingual-1B). The handler repo holds no weights — they're pulled from the Hub on cold start, same as the MMS handler. Deploy: push this directory to a HF model repo and create a Custom Inference Endpoint from it. See endpoint-badrex/README.md. """ from __future__ import annotations import os import re from typing import Any, Dict, Optional import torch from _audio_io import AudioRequestError, load_and_validate_audio SAMPLE_RATE = 16000 CHUNK_SEC = 30 # pipeline window STRIDE_SEC = 5 # overlap; pipeline trims the duplicated logits at each boundary _LANG_TAG_RE = re.compile(r"^\s*\[[A-Za-z]{2,4}\]\s*") def _strip_lang_tag(text: Optional[str]) -> str: return _LANG_TAG_RE.sub("", text or "").strip() class EndpointHandler: def __init__(self, path: str = "", **kwargs): """path: local repo dir (ignored — weights come from the Hub, like the MMS handler). The served model id is BADREX_MODEL (env), defaulting to the multilingual checkpoint.""" from transformers import pipeline model_id = os.environ.get("BADREX_MODEL", "badrex/Ethio-ASR-multilingual-1B") self.model_id = model_id device = 0 if torch.cuda.is_available() else -1 self.pipe = pipeline( "automatic-speech-recognition", model=model_id, device=device, ) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: audio_input = data.get("inputs") or data.get("audio") if audio_input is None: return {"error": "No audio provided. Send base64 audio under 'inputs'."} try: audio = load_and_validate_audio(audio_input) except AudioRequestError as exc: return {"error": str(exc), "status": "bad_request"} except Exception as exc: return {"error": f"Audio decoding failed: {exc}"} try: out = self.pipe( {"array": audio, "sampling_rate": SAMPLE_RATE}, chunk_length_s=CHUNK_SEC, stride_length_s=STRIDE_SEC, ) except Exception as exc: return {"error": f"Transcription failed: {exc}"} text = _strip_lang_tag((out or {}).get("text") if isinstance(out, dict) else "") return {"text": text}