""" HuggingFace Inference Endpoint handler for Kurdish/Persian Whisper ASR. Accepts audio (binary, base64, or filepath) and returns transcribed text. Default model: whisper-largev3 full fine-tune. """ import base64 import gc import io import logging from pathlib import Path import numpy as np import torch import torchaudio from transformers import WhisperForConditionalGeneration, WhisperProcessor log = logging.getLogger(__name__) SAMPLE_RATE = 16_000 CHUNK_SECONDS = 30 CHUNK_SAMPLES = CHUNK_SECONDS * SAMPLE_RATE MODELS = { "small": Path(__file__).parent / "models" / "whisper-small-peft-kurdish-on-persian-converted", "full": Path(__file__).parent / "models" / "whisper-largev3-on-persian-centralkurdish-full", } DEFAULT_MODEL = "full" # --------------------------------------------------------------------------- # Audio helpers # --------------------------------------------------------------------------- def _audio_bytes_to_numpy(raw: bytes) -> np.ndarray: """Convert raw audio bytes to float32 mono 16 kHz numpy array. Uses torchaudio (in-memory) instead of shelling out to ffmpeg. """ buf = io.BytesIO(raw) waveform, sr = torchaudio.load(buf) # (channels, samples) # Mix to mono. if waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True) # Resample if needed. if sr != SAMPLE_RATE: waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE) return waveform.squeeze(0).numpy() def _chunk(audio: np.ndarray) -> list[np.ndarray]: if len(audio) <= CHUNK_SAMPLES: return [audio] return [audio[i : i + CHUNK_SAMPLES] for i in range(0, len(audio), CHUNK_SAMPLES)] # --------------------------------------------------------------------------- # Endpoint handler # --------------------------------------------------------------------------- class EndpointHandler: """ HuggingFace Inference Endpoint handler. Request format: { "inputs": , "parameters": { "model": "full" | "small", # default: "full" "language": "fa" # default: "fa" } } Response format: {"text": "transcribed text here"} """ def __init__(self, path: str = ""): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self._model: WhisperForConditionalGeneration | None = None self._processor: WhisperProcessor | None = None self._loaded_name: str | None = None self._dtype = torch.float32 # If HF Inference Endpoint provides a path with model files, use it. if path and (Path(path) / "config.json").exists(): MODELS["full"] = Path(path) self._load(DEFAULT_MODEL) def __call__(self, data: dict) -> dict: inputs = data.get("inputs") params = data.get("parameters", {}) or {} model_name = params.get("model", DEFAULT_MODEL) language = params.get("language", "fa") if not inputs: return {"error": "No audio provided in 'inputs'."} if model_name != self._loaded_name: self._load(model_name) audio = self._resolve_audio(inputs) text = self._transcribe(audio, language) return {"text": text} # ------------------------------------------------------------------ # Model lifecycle # ------------------------------------------------------------------ def _load(self, name: str): if name not in MODELS: raise ValueError(f"Unknown model '{name}'. Choose from: {list(MODELS.keys())}") if name == self._loaded_name: return self._unload() model_path = str(MODELS[name]) is_cuda = self.device.type == "cuda" self._processor = WhisperProcessor.from_pretrained(model_path) # type: ignore[assignment] # Try optimal load: flash attention 2 + float16 on CUDA. 
        model = self._load_model(model_path, is_cuda)
        model.config.use_cache = True
        model.generation_config.forced_decoder_ids = None

        if not is_cuda and next(model.parameters()).device.type != "cpu":
            model.to(self.device)  # type: ignore[arg-type]
        model.eval()

        # BetterTransformer fallback when Flash Attention is unavailable.
        if is_cuda and getattr(model.config, "_attn_implementation", None) != "flash_attention_2":
            try:
                model = model.to_bettertransformer()  # type: ignore[assignment]
                log.info("Using BetterTransformer (SDPA kernels).")
            except Exception:
                log.info("BetterTransformer unavailable, using default attention.")

        # torch.compile for graph-level optimization (warmup on first call).
        if is_cuda and hasattr(torch, "compile"):
            try:
                model = torch.compile(model, mode="reduce-overhead")  # type: ignore[assignment]
                log.info("Model compiled with torch.compile (reduce-overhead).")
            except Exception:
                log.info("torch.compile unavailable, skipping.")

        self._model = model
        self._dtype = torch.float16 if is_cuda else torch.float32
        self._loaded_name = name

    def _load_model(
        self,
        model_path: str,
        is_cuda: bool,
    ) -> WhisperForConditionalGeneration:
        """Load model with best available acceleration, falling back gracefully."""
        # Attempt 1: Flash Attention 2 + float16 (requires Ampere / sm_80+).
        can_flash = is_cuda and torch.cuda.get_device_capability()[0] >= 8
        if can_flash:
            try:
                return WhisperForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    device_map="auto",
                )
            except (ImportError, ValueError, RuntimeError) as exc:
                log.info("Flash Attention 2 unavailable (%s), trying standard load.", exc)

        # Attempt 2: Standard CUDA load (float16, auto device map).
        if is_cuda:
            try:
                return WhisperForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                )
            except (ImportError, ValueError, RuntimeError) as exc:
                log.info("Auto device_map failed (%s), falling back to manual.", exc)

        # Attempt 3: Manual load (CPU, or CUDA without device_map).
        dtype = torch.float16 if is_cuda else torch.float32
        model = WhisperForConditionalGeneration.from_pretrained(
            model_path,
            quantization_config=None,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model.to(self.device)  # type: ignore[arg-type]
        return model

    def _unload(self):
        del self._model, self._processor
        self._model = None
        self._processor = None
        self._loaded_name = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # Audio resolution
    # ------------------------------------------------------------------

    def _resolve_audio(self, inputs) -> np.ndarray:
        """Accept a base64-encoded string or raw bytes."""
        if isinstance(inputs, str):
            raw = base64.b64decode(inputs)
        elif isinstance(inputs, bytes):
            raw = inputs
        else:
            raise ValueError("'inputs' must be a base64-encoded string or raw bytes.")
        return _audio_bytes_to_numpy(raw)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def _transcribe(self, audio: np.ndarray, language: str) -> str:
        assert self._model is not None and self._processor is not None
        chunks = _chunk(audio)
        # Batch all chunks into a single forward pass.
        if len(chunks) > 1:
            return self._transcribe_batched(chunks, language)
        return self._transcribe_single(chunks[0], language)

    def _transcribe_single(self, audio: np.ndarray, language: str) -> str:
        assert self._model is not None and self._processor is not None
        features = self._processor(  # type: ignore[operator]
            audio,
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
        )
        input_features = features.input_features.to(self.device, dtype=self._dtype)

        with torch.no_grad(), torch.autocast(
            self.device.type,
            dtype=torch.float16,
            enabled=self.device.type == "cuda",
        ):
            ids = self._model.generate(
                input_features,
                language=language,
                task="transcribe",
                max_new_tokens=440,
            )

        return self._processor.batch_decode(  # type: ignore[union-attr]
            ids,
            skip_special_tokens=True,
        )[0].strip()

    def _transcribe_batched(self, chunks: list[np.ndarray], language: str) -> str:
        assert self._model is not None and self._processor is not None

        # Pad shorter chunks to 30 s so mel features align for stacking.
        padded = []
        for c in chunks:
            if len(c) < CHUNK_SAMPLES:
                c = np.pad(c, (0, CHUNK_SAMPLES - len(c)))
            padded.append(c)

        features = self._processor(  # type: ignore[operator]
            padded,
            sampling_rate=SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        input_features = features.input_features.to(self.device, dtype=self._dtype)

        with torch.no_grad(), torch.autocast(
            self.device.type,
            dtype=torch.float16,
            enabled=self.device.type == "cuda",
        ):
            ids = self._model.generate(
                input_features,
                language=language,
                task="transcribe",
                max_new_tokens=440,
            )

        texts = self._processor.batch_decode(  # type: ignore[union-attr]
            ids,
            skip_special_tokens=True,
        )
        return " ".join(t.strip() for t in texts if t.strip())
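

# ---------------------------------------------------------------------------
# Local usage sketch
# ---------------------------------------------------------------------------
# Minimal, illustrative example of how a client payload is built and how the
# handler is invoked outside a HuggingFace Inference Endpoint. The file name
# "sample.wav" is a hypothetical placeholder (any audio file torchaudio can
# read), and this assumes the model directories listed in MODELS exist locally.
if __name__ == "__main__":
    import json
    import sys

    wav_path = sys.argv[1] if len(sys.argv) > 1 else "sample.wav"  # placeholder file name
    handler = EndpointHandler()

    # Encode the audio exactly as a client would for the "inputs" field.
    payload = {
        "inputs": base64.b64encode(Path(wav_path).read_bytes()).decode("utf-8"),
        "parameters": {"model": "full", "language": "fa"},
    }
    print(json.dumps(handler(payload), ensure_ascii=False))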