Za6na committed
Commit 3c5114a · verified · 1 Parent(s): 634882b

Upload handler.py with huggingface_hub

Files changed (1)
  1. handler.py +293 -0
handler.py ADDED
@@ -0,0 +1,293 @@
"""
HuggingFace Inference Endpoint handler for Kurdish/Persian Whisper ASR.

Accepts audio (raw bytes or base64) and returns the transcribed text.
Default model: the whisper-largev3 full fine-tune.
"""

import base64
import gc
import io
import logging
from pathlib import Path

import numpy as np
import torch
import torchaudio
from transformers import WhisperForConditionalGeneration, WhisperProcessor

log = logging.getLogger(__name__)

SAMPLE_RATE = 16_000
CHUNK_SECONDS = 30
CHUNK_SAMPLES = CHUNK_SECONDS * SAMPLE_RATE

MODELS = {
    "small": Path(__file__).parent / "models" / "whisper-small-peft-kurdish-on-persian-converted",
    "full": Path(__file__).parent / "models" / "whisper-largev3-on-persian-centralkurdish-full",
}

DEFAULT_MODEL = "full"


# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------

def _audio_bytes_to_numpy(raw: bytes) -> np.ndarray:
    """Convert raw audio bytes to a float32 mono 16 kHz numpy array.

    Uses torchaudio (in-memory) instead of shelling out to ffmpeg.
    """
    buf = io.BytesIO(raw)
    waveform, sr = torchaudio.load(buf)  # (channels, samples)

    # Mix down to mono.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample if needed.
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)

    return waveform.squeeze(0).numpy()


def _chunk(audio: np.ndarray) -> list[np.ndarray]:
    """Split audio into 30-second chunks; the last chunk may be shorter."""
    if len(audio) <= CHUNK_SAMPLES:
        return [audio]
    return [audio[i : i + CHUNK_SAMPLES] for i in range(0, len(audio), CHUNK_SAMPLES)]


# ---------------------------------------------------------------------------
# Endpoint handler
# ---------------------------------------------------------------------------

class EndpointHandler:
    """
    HuggingFace Inference Endpoint handler.

    Request format:
        {
            "inputs": <base64-encoded audio OR raw bytes>,
            "parameters": {
                "model": "full" | "small",  # default: "full"
                "language": "fa"            # default: "fa"
            }
        }

    Response format:
        {"text": "transcribed text here"}
    """

    def __init__(self, path: str = ""):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._model: WhisperForConditionalGeneration | None = None
        self._processor: WhisperProcessor | None = None
        self._loaded_name: str | None = None
        self._dtype = torch.float32

        # If the HF Inference Endpoint provides a path with model files, use it.
        if path and (Path(path) / "config.json").exists():
            MODELS["full"] = Path(path)

        self._load(DEFAULT_MODEL)

    def __call__(self, data: dict) -> dict:
        inputs = data.get("inputs")
        params = data.get("parameters", {}) or {}
        model_name = params.get("model", DEFAULT_MODEL)
        language = params.get("language", "fa")

        if not inputs:
            return {"error": "No audio provided in 'inputs'."}

        if model_name != self._loaded_name:
            self._load(model_name)

        audio = self._resolve_audio(inputs)
        text = self._transcribe(audio, language)

        return {"text": text}

    # ------------------------------------------------------------------
    # Model lifecycle
    # ------------------------------------------------------------------

    def _load(self, name: str):
        if name not in MODELS:
            raise ValueError(f"Unknown model '{name}'. Choose from: {list(MODELS.keys())}")

        if name == self._loaded_name:
            return

        self._unload()
        model_path = str(MODELS[name])
        is_cuda = self.device.type == "cuda"

        self._processor = WhisperProcessor.from_pretrained(model_path)  # type: ignore[assignment]

        # Try the optimal load first: Flash Attention 2 + float16 on CUDA.
        model = self._load_model(model_path, is_cuda)

        model.config.use_cache = True
        model.generation_config.forced_decoder_ids = None

        if not is_cuda and next(model.parameters()).device.type != "cpu":
            model.to(self.device)  # type: ignore[arg-type]

        model.eval()

        # BetterTransformer fallback when Flash Attention 2 is unavailable.
        if is_cuda and getattr(model.config, "_attn_implementation", None) != "flash_attention_2":
            try:
                model = model.to_bettertransformer()  # type: ignore[assignment]
                log.info("Using BetterTransformer (SDPA kernels).")
            except Exception:
                log.info("BetterTransformer unavailable, using default attention.")

        # torch.compile for graph-level optimization (warmup on first call).
        if is_cuda and hasattr(torch, "compile"):
            try:
                model = torch.compile(model, mode="reduce-overhead")  # type: ignore[assignment]
                log.info("Model compiled with torch.compile (reduce-overhead).")
            except Exception:
                log.info("torch.compile unavailable, skipping.")

        self._model = model
        self._dtype = torch.float16 if is_cuda else torch.float32
        self._loaded_name = name

    def _load_model(
        self, model_path: str, is_cuda: bool,
    ) -> WhisperForConditionalGeneration:
        """Load the model with the best available acceleration, falling back gracefully."""
        # Attempt 1: Flash Attention 2 + float16 (requires Ampere / sm_80+).
        can_flash = (
            is_cuda
            and torch.cuda.get_device_capability()[0] >= 8
        )
        if can_flash:
            try:
                return WhisperForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    attn_implementation="flash_attention_2",
                    device_map="auto",
                )
            except (ImportError, ValueError, RuntimeError) as exc:
                log.info("Flash Attention 2 unavailable (%s), trying standard load.", exc)

        # Attempt 2: standard CUDA load (float16, auto device map).
        if is_cuda:
            try:
                return WhisperForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                )
            except (ImportError, ValueError, RuntimeError) as exc:
                log.info("Auto device_map failed (%s), falling back to manual load.", exc)

        # Attempt 3: manual load (CPU, or CUDA without device_map).
        dtype = torch.float16 if is_cuda else torch.float32
        model = WhisperForConditionalGeneration.from_pretrained(
            model_path,
            quantization_config=None,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model.to(self.device)  # type: ignore[arg-type]
        return model

    def _unload(self):
        # Drop references so the old model can be garbage-collected.
        del self._model, self._processor
        self._model = None
        self._processor = None
        self._loaded_name = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # Audio resolution
    # ------------------------------------------------------------------

    def _resolve_audio(self, inputs) -> np.ndarray:
        """Accept a base64-encoded string or raw bytes."""
        if isinstance(inputs, str):
            raw = base64.b64decode(inputs)
        elif isinstance(inputs, bytes):
            raw = inputs
        else:
            raise ValueError("'inputs' must be a base64-encoded string or raw bytes.")

        return _audio_bytes_to_numpy(raw)

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def _transcribe(self, audio: np.ndarray, language: str) -> str:
        assert self._model is not None and self._processor is not None

        chunks = _chunk(audio)

        # Batch all chunks into a single forward pass.
        if len(chunks) > 1:
            return self._transcribe_batched(chunks, language)

        return self._transcribe_single(chunks[0], language)

    def _transcribe_single(self, audio: np.ndarray, language: str) -> str:
        assert self._model is not None and self._processor is not None

        features = self._processor(  # type: ignore[operator]
            audio, sampling_rate=SAMPLE_RATE, return_tensors="pt",
        )
        input_features = features.input_features.to(self.device, dtype=self._dtype)

        with torch.no_grad(), torch.autocast(
            self.device.type, dtype=torch.float16, enabled=self.device.type == "cuda",
        ):
            ids = self._model.generate(
                input_features,
                language=language,
                task="transcribe",
                max_new_tokens=440,
            )

        return self._processor.batch_decode(  # type: ignore[union-attr]
            ids, skip_special_tokens=True,
        )[0].strip()

    def _transcribe_batched(self, chunks: list[np.ndarray], language: str) -> str:
        assert self._model is not None and self._processor is not None

        # Pad shorter chunks to 30 s so the mel features align for stacking.
        padded = []
        for c in chunks:
            if len(c) < CHUNK_SAMPLES:
                c = np.pad(c, (0, CHUNK_SAMPLES - len(c)))
            padded.append(c)

        features = self._processor(  # type: ignore[operator]
            padded, sampling_rate=SAMPLE_RATE, return_tensors="pt", padding=True,
        )
        input_features = features.input_features.to(self.device, dtype=self._dtype)

        with torch.no_grad(), torch.autocast(
            self.device.type, dtype=torch.float16, enabled=self.device.type == "cuda",
        ):
            ids = self._model.generate(
                input_features,
                language=language,
                task="transcribe",
                max_new_tokens=440,
            )

        texts = self._processor.batch_decode(  # type: ignore[union-attr]
            ids, skip_special_tokens=True,
        )

        return " ".join(t.strip() for t in texts if t.strip())