"""
HuggingFace Inference Endpoint handler for Kurdish/Persian Whisper ASR.

Accepts audio (raw bytes or base64) and returns the transcribed text.

Default model: the whisper-largev3 full fine-tune.
"""

import base64
import gc
import io
import logging
from pathlib import Path

import numpy as np
import torch
import torchaudio
from transformers import WhisperForConditionalGeneration, WhisperProcessor

log = logging.getLogger(__name__)

SAMPLE_RATE = 16_000
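# Whisper consumes fixed 30-second windows; longer audio is split into
# 30 s x 16 kHz = 480_000-sample chunks and transcribed as a batch.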
CHUNK_SECONDS = 30
CHUNK_SAMPLES = CHUNK_SECONDS * SAMPLE_RATE

MODELS = {
    "small": Path(__file__).parent / "models" / "whisper-small-peft-kurdish-on-persian-converted",
    "full": Path(__file__).parent / "models" / "whisper-largev3-on-persian-centralkurdish-full",
}

DEFAULT_MODEL = "full"


def _audio_bytes_to_numpy(raw: bytes) -> np.ndarray:
    """Convert raw audio bytes to a float32 mono 16 kHz numpy array.

    Uses torchaudio (in-memory) instead of shelling out to ffmpeg.
    """
    buf = io.BytesIO(raw)
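    # NOTE: in-memory decoding relies on a torchaudio backend that can read
    # file-like objects (e.g. soundfile or ffmpeg).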
    waveform, sr = torchaudio.load(buf)

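    # Downmix multi-channel audio to mono by averaging channels.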
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

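    # Resample to Whisper's expected 16 kHz if needed.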
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)

    return waveform.squeeze(0).numpy()


def _chunk(audio: np.ndarray) -> list[np.ndarray]:
    if len(audio) <= CHUNK_SAMPLES:
        return [audio]
    return [audio[i : i + CHUNK_SAMPLES] for i in range(0, len(audio), CHUNK_SAMPLES)]


class EndpointHandler:
    """
    HuggingFace Inference Endpoint handler.

    Request format:
        {
            "inputs": <base64-encoded audio OR raw bytes>,
            "parameters": {
                "model": "full" | "small",   # default: "full"
                "language": "fa"             # default: "fa"
            }
        }

    Response format:
        {"text": "transcribed text here"}
    """

    def __init__(self, path: str = ""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model: WhisperForConditionalGeneration | None = None
        self._processor: WhisperProcessor | None = None
        self._loaded_name: str | None = None
        self._dtype = torch.float32

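        # On HF Inference Endpoints, `path` points at the deployed repository;
        # if it holds a checkpoint, let it override the bundled "full" model.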
        if path and (Path(path) / "config.json").exists():
            MODELS["full"] = Path(path)

        self._load(DEFAULT_MODEL)

    def __call__(self, data: dict) -> dict:
        inputs = data.get("inputs")
        params = data.get("parameters", {}) or {}
        model_name = params.get("model", DEFAULT_MODEL)
        language = params.get("language", "fa")

        if not inputs:
            return {"error": "No audio provided in 'inputs'."}

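        # Swap checkpoints lazily: reload only when a different model is requested.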
        if model_name != self._loaded_name:
            self._load(model_name)

        audio = self._resolve_audio(inputs)
        text = self._transcribe(audio, language)

        return {"text": text}

    def _load(self, name: str):
        if name not in MODELS:
            raise ValueError(f"Unknown model '{name}'. Choose from: {list(MODELS.keys())}")

        if name == self._loaded_name:
            return

        self._unload()
        model_path = str(MODELS[name])
        is_cuda = self.device.type == "cuda"

        self._processor = WhisperProcessor.from_pretrained(model_path)

        model = self._load_model(model_path, is_cuda)

        model.config.use_cache = True
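        # Clear the legacy forced_decoder_ids so the language/task arguments
        # passed to generate() take effect without a conflict warning.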
        model.generation_config.forced_decoder_ids = None

        if not is_cuda and next(model.parameters()).device.type != "cpu":
            model.to(self.device)

        model.eval()

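        # Prefer SDPA kernels via BetterTransformer unless Flash Attention 2
        # is already active on this model.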
        if is_cuda and getattr(model.config, "_attn_implementation", None) != "flash_attention_2":
            try:
                model = model.to_bettertransformer()
                log.info("Using BetterTransformer (SDPA kernels).")
            except Exception:
                log.info("BetterTransformer unavailable, using default attention.")

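        # torch.compile gives a steady-state speedup at the cost of a slow
        # first call while kernels are compiled.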
        if is_cuda and hasattr(torch, "compile"):
            try:
                model = torch.compile(model, mode="reduce-overhead")
                log.info("Model compiled with torch.compile (reduce-overhead).")
            except Exception:
                log.info("torch.compile unavailable, skipping.")

        self._model = model
        self._dtype = torch.float16 if is_cuda else torch.float32
        self._loaded_name = name

    def _load_model(self, model_path: str, is_cuda: bool) -> WhisperForConditionalGeneration:
        """Load model with best available acceleration, falling back gracefully."""
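        # Tier 1: fp16 + Flash Attention 2 (requires an Ampere-or-newer GPU,
        # compute capability >= 8, and the flash-attn package).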
        can_flash = is_cuda and torch.cuda.get_device_capability()[0] >= 8
        if can_flash:
            try:
                return WhisperForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    device_map="auto",
                )
            except (ImportError, ValueError, RuntimeError) as exc:
                log.info("Flash Attention 2 unavailable (%s), trying standard load.", exc)

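        # Tier 2: fp16 with accelerate's automatic device placement.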
        if is_cuda:
            try:
                return WhisperForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                )
            except (ImportError, ValueError, RuntimeError) as exc:
                log.info("Auto device_map failed (%s), falling back to manual.", exc)

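        # Tier 3: plain load, then move to the target device manually.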
        dtype = torch.float16 if is_cuda else torch.float32
        model = WhisperForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model.to(self.device)
        return model

    def _unload(self):
        del self._model, self._processor
        self._model = None
        self._processor = None
        self._loaded_name = None
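        # Drop references and flush the CUDA caching allocator so the next
        # checkpoint can fit in GPU memory.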
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _resolve_audio(self, inputs) -> np.ndarray:
        """Accept base64 string or raw bytes."""
        if isinstance(inputs, str):
            raw = base64.b64decode(inputs)
        elif isinstance(inputs, bytes):
            raw = inputs
        else:
            raise ValueError("'inputs' must be base64-encoded string or raw bytes.")

        return _audio_bytes_to_numpy(raw)

    def _transcribe(self, audio: np.ndarray, language: str) -> str:
        assert self._model is not None and self._processor is not None

        chunks = _chunk(audio)

        if len(chunks) > 1:
            return self._transcribe_batched(chunks, language)

        return self._transcribe_single(chunks[0], language)

    def _transcribe_single(self, audio: np.ndarray, language: str) -> str:
        assert self._model is not None and self._processor is not None

        features = self._processor(
            audio, sampling_rate=SAMPLE_RATE, return_tensors="pt",
        )
        input_features = features.input_features.to(self.device, dtype=self._dtype)

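        # fp16 autocast is enabled only on CUDA; on CPU this is a no-op.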
        with torch.no_grad(), torch.autocast(
            self.device.type, dtype=torch.float16, enabled=self.device.type == "cuda",
        ):
            ids = self._model.generate(
                input_features,
                language=language,
                task="transcribe",
                max_new_tokens=440,
            )

        return self._processor.batch_decode(
            ids, skip_special_tokens=True,
        )[0].strip()

    def _transcribe_batched(self, chunks: list[np.ndarray], language: str) -> str:
        assert self._model is not None and self._processor is not None

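        # Right-pad the last chunk with silence so every chunk has the same
        # length and can be stacked into one batch.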
        padded = []
        for c in chunks:
            if len(c) < CHUNK_SAMPLES:
                c = np.pad(c, (0, CHUNK_SAMPLES - len(c)))
            padded.append(c)

        features = self._processor(
            padded, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True,
        )
        input_features = features.input_features.to(self.device, dtype=self._dtype)

        with torch.no_grad(), torch.autocast(
            self.device.type, dtype=torch.float16, enabled=self.device.type == "cuda",
        ):
            ids = self._model.generate(
                input_features,
                language=language,
                task="transcribe",
                max_new_tokens=440,
            )

        texts = self._processor.batch_decode(
            ids, skip_special_tokens=True,
        )

        return " ".join(t.strip() for t in texts if t.strip())
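

if __name__ == "__main__":
    # Local smoke test -- a minimal sketch, not part of the endpoint contract.
    # The default audio path below is hypothetical; pass your own file as the
    # first CLI argument. Requires the bundled model directories to exist.
    import sys

    audio_path = sys.argv[1] if len(sys.argv) > 1 else "sample.wav"
    payload = base64.b64encode(Path(audio_path).read_bytes()).decode("ascii")

    handler = EndpointHandler()
    result = handler({"inputs": payload, "parameters": {"language": "fa"}})
    print(result)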