File size: 10,442 Bytes
3c5114a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 |
"""
HuggingFace Inference Endpoint handler for Kurdish/Persian Whisper ASR.
Accepts audio as raw bytes or as a base64-encoded string and returns the
transcribed text.
Default model: whisper-large-v3 full fine-tune.
"""
import base64
import gc
import io
import logging
from pathlib import Path
import numpy as np
import torch
import torchaudio
from transformers import WhisperForConditionalGeneration, WhisperProcessor
log = logging.getLogger(__name__)

# Incoming audio is resampled to SAMPLE_RATE and transcribed in windows of
# CHUNK_SECONDS (CHUNK_SAMPLES samples each).
SAMPLE_RATE = 16_000
CHUNK_SECONDS = 30
CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS

# Locally bundled checkpoints, keyed by the "model" request parameter.
_MODELS_DIR = Path(__file__).parent / "models"
MODELS = {
    "small": _MODELS_DIR / "whisper-small-peft-kurdish-on-persian-converted",
    "full": _MODELS_DIR / "whisper-largev3-on-persian-centralkurdish-full",
}
DEFAULT_MODEL = "full"
# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def _audio_bytes_to_numpy(raw: bytes) -> np.ndarray:
    """Decode an in-memory audio file into a mono float32 array at SAMPLE_RATE.

    The bytes are read by ``torchaudio.load`` from a BytesIO buffer, so no
    temporary file or ffmpeg subprocess is involved.

    Args:
        raw: Encoded audio file contents.

    Returns:
        1-D numpy array of samples at ``SAMPLE_RATE`` Hz.
    """
    signal, source_rate = torchaudio.load(io.BytesIO(raw))  # (channels, samples)

    channel_count = signal.shape[0]
    if channel_count > 1:
        # Average all channels into a single mono track.
        signal = signal.mean(dim=0, keepdim=True)

    if source_rate != SAMPLE_RATE:
        signal = torchaudio.functional.resample(
            signal, orig_freq=source_rate, new_freq=SAMPLE_RATE,
        )

    mono = signal.squeeze(dim=0)
    return mono.numpy()
def _chunk(audio: np.ndarray) -> list[np.ndarray]:
if len(audio) <= CHUNK_SAMPLES:
return [audio]
return [audio[i : i + CHUNK_SAMPLES] for i in range(0, len(audio), CHUNK_SAMPLES)]
# ---------------------------------------------------------------------------
# Endpoint handler
# ---------------------------------------------------------------------------
class EndpointHandler:
    """
    HuggingFace Inference Endpoint handler.

    Request format:
    {
        "inputs": <base64-encoded audio (optionally a data: URI) OR raw bytes>,
        "parameters": {
            "model": "full" | "small",  # default: "full"
            "language": "fa"  # default: "fa"
        }
    }

    Response format:
    {"text": "transcribed text here"}
    or, when "inputs" is missing/empty:
    {"error": "No audio provided in 'inputs'."}
    """

    # Upper bound on how many 30 s chunks are sent through one generate()
    # call. Long recordings are processed in several batches so peak
    # accelerator memory does not grow with audio length.
    max_batch_size = 16

    def __init__(self, path: str = ""):
        """Select the device, resolve the model registry and load DEFAULT_MODEL.

        Args:
            path: Model directory supplied by the Inference Endpoint runtime.
                If it contains a ``config.json`` it replaces the bundled
                "full" checkpoint for this handler only.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model: WhisperForConditionalGeneration | None = None
        self._processor: WhisperProcessor | None = None
        self._loaded_name: str | None = None
        self._dtype = torch.float32
        # Per-instance copy: overriding "full" must not mutate the shared
        # module-level MODELS dict (which would leak into other handlers).
        self._models: dict[str, Path] = dict(MODELS)
        if path and (Path(path) / "config.json").exists():
            self._models["full"] = Path(path)
        self._load(DEFAULT_MODEL)

    def __call__(self, data: dict) -> dict:
        """Transcribe the audio carried in an endpoint request.

        Args:
            data: Request payload (see the class docstring).

        Returns:
            ``{"text": ...}`` on success, or ``{"error": ...}`` when
            ``inputs`` is missing or empty.

        Raises:
            ValueError: unknown ``parameters.model`` or unsupported
                ``inputs`` type.
        """
        inputs = data.get("inputs")
        params = data.get("parameters", {}) or {}
        model_name = params.get("model", DEFAULT_MODEL)
        language = params.get("language", "fa")
        if not inputs:
            return {"error": "No audio provided in 'inputs'."}
        if model_name != self._loaded_name:
            self._load(model_name)
        audio = self._resolve_audio(inputs)
        text = self._transcribe(audio, language)
        return {"text": text}

    # ------------------------------------------------------------------
    # Model lifecycle
    # ------------------------------------------------------------------
    def _load(self, name: str):
        """Make ``name`` the active model, unloading any other model first.

        Raises:
            ValueError: if ``name`` is not a key of the model registry. This
                is checked before anything is unloaded.
        """
        if name not in self._models:
            raise ValueError(f"Unknown model '{name}'. Choose from: {list(self._models.keys())}")
        if name == self._loaded_name:
            return
        self._unload()
        model_path = str(self._models[name])
        is_cuda = self.device.type == "cuda"
        self._processor = WhisperProcessor.from_pretrained(model_path)  # type: ignore[assignment]
        # Best available acceleration (flash attention 2 / fp16 on CUDA).
        # _load_model already places the model on its device.
        model = self._load_model(model_path, is_cuda)
        model.config.use_cache = True
        # Cleared so the decoder prompt is driven by the language/task
        # arguments passed to generate() rather than the checkpoint config.
        model.generation_config.forced_decoder_ids = None
        model.eval()
        # BetterTransformer fallback when Flash Attention is unavailable.
        if is_cuda and getattr(model.config, "_attn_implementation", None) != "flash_attention_2":
            try:
                model = model.to_bettertransformer()  # type: ignore[assignment]
                log.info("Using BetterTransformer (SDPA kernels).")
            except Exception:
                log.info("BetterTransformer unavailable, using default attention.")
        # torch.compile(model) is deliberately not applied: the returned
        # OptimizedModule forwards attribute lookups such as `.generate` to the
        # original module, so generation would never run the compiled graph.
        # Compilation is also lazy, so a try/except here cannot catch real
        # compile failures.
        self._model = model
        self._dtype = torch.float16 if is_cuda else torch.float32
        self._loaded_name = name

    def _load_model(
        self, model_path: str, is_cuda: bool,
    ) -> WhisperForConditionalGeneration:
        """Load the model with the best available acceleration, falling back gracefully.

        Tries, in order:
        1. flash attention 2 + float16 (CUDA, compute capability >= 8);
        2. float16 with ``device_map="auto"`` (CUDA);
        3. a plain load in float16 (CUDA) or float32 (CPU), moved to ``self.device``.
        """
        can_flash = (
            is_cuda
            and torch.cuda.get_device_capability()[0] >= 8
        )
        if can_flash:
            try:
                return WhisperForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    device_map="auto",
                )
            except (ImportError, ValueError, RuntimeError) as exc:
                log.info("Flash Attention 2 unavailable (%s), trying standard load.", exc)
        if is_cuda:
            try:
                return WhisperForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                )
            except (ImportError, ValueError, RuntimeError) as exc:
                log.info("Auto device_map failed (%s), falling back to manual.", exc)
        dtype = torch.float16 if is_cuda else torch.float32
        model = WhisperForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model.to(self.device)  # type: ignore[arg-type]
        return model

    def _unload(self):
        """Drop the active model/processor and release cached accelerator memory."""
        del self._model, self._processor
        self._model = None
        self._processor = None
        self._loaded_name = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # Audio resolution
    # ------------------------------------------------------------------
    def _resolve_audio(self, inputs) -> np.ndarray:
        """Decode request ``inputs`` into a mono float32 sample array.

        Accepts a base64 string (optionally wrapped in a
        ``data:<mime>;base64,`` URI) or raw ``bytes``/``bytearray``/``memoryview``.

        Raises:
            ValueError: for any other input type, or (via ``binascii.Error``)
                for malformed base64.
        """
        if isinstance(inputs, str):
            # ':' is not in the base64 alphabet, so a "data:" prefix can only
            # be a data URI; keep just the payload after the first comma.
            if inputs.startswith("data:") and "," in inputs:
                inputs = inputs.split(",", 1)[1]
            raw = base64.b64decode(inputs)
        elif isinstance(inputs, (bytes, bytearray, memoryview)):
            raw = bytes(inputs)
        else:
            raise ValueError("'inputs' must be base64-encoded string or raw bytes.")
        return _audio_bytes_to_numpy(raw)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def _transcribe(self, audio: np.ndarray, language: str) -> str:
        """Transcribe SAMPLE_RATE audio of any length into a single string."""
        assert self._model is not None and self._processor is not None
        if audio.size == 0:
            # Nothing to decode; the model would only see zero padding.
            return ""
        chunks = _chunk(audio)
        if len(chunks) > 1:
            return self._transcribe_batched(chunks, language)
        return self._transcribe_single(chunks[0], language)

    def _transcribe_single(self, audio: np.ndarray, language: str) -> str:
        """Transcribe one chunk of at most CHUNK_SAMPLES samples."""
        return self._generate([audio], language)[0]

    def _transcribe_batched(self, chunks: list[np.ndarray], language: str) -> str:
        """Transcribe several chunks in bounded batches; join non-empty texts with spaces."""
        step = max(1, self.max_batch_size)
        texts: list[str] = []
        for start in range(0, len(chunks), step):
            texts.extend(self._generate(chunks[start:start + step], language))
        return " ".join(t for t in texts if t)

    def _generate(self, chunks: list[np.ndarray], language: str) -> list[str]:
        """Run one generate() call over ``chunks``; return one stripped text per chunk, in order."""
        assert self._model is not None and self._processor is not None
        # Zero-pad short chunks to the full window so every row of the batch
        # produces mel features of the same length.
        padded = [
            c if len(c) >= CHUNK_SAMPLES else np.pad(c, (0, CHUNK_SAMPLES - len(c)))
            for c in chunks
        ]
        features = self._processor(  # type: ignore[operator]
            padded, sampling_rate=SAMPLE_RATE, return_tensors="pt",
        )
        input_features = features.input_features.to(self.device, dtype=self._dtype)
        with torch.no_grad(), torch.autocast(
            self.device.type, dtype=torch.float16, enabled=self.device.type == "cuda",
        ):
            ids = self._model.generate(
                input_features,
                language=language,
                task="transcribe",
                # NOTE(review): presumably chosen to stay under Whisper's
                # decoder length limit once prompt tokens are added — confirm
                # before raising it.
                max_new_tokens=440,
            )
        texts = self._processor.batch_decode(  # type: ignore[union-attr]
            ids, skip_special_tokens=True,
        )
        return [t.strip() for t in texts]
|