Za6na committed on
Commit
27b824a
·
verified ·
1 Parent(s): cbc21a0

Upload demo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo.py +289 -0
demo.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio demo for fine-tuned Whisper models — Kurdish Sorani & Persian transcription.
3
+ """
4
+
5
+ import gc
6
+ import time
7
+ from pathlib import Path
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ import torch
12
+ from transformers import WhisperForConditionalGeneration, WhisperProcessor
13
+
14
# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------
# Maps the display name shown in the UI dropdown to the local checkpoint
# directory ("path") and the OpenAI base checkpoint it was fine-tuned from
# ("base").  Checkpoints are expected under ./models next to this file.
MODELS = {
    "Small (whisper-small, PEFT-merged)": {
        "path": Path(__file__).parent / "models" / "whisper-small-peft-kurdish-on-persian-converted",
        "base": "openai/whisper-small",
    },
    "Large-v3 (full fine-tune)": {
        "path": Path(__file__).parent / "models" / "whisper-largev3-on-persian-centralkurdish-full",
        "base": "openai/whisper-large-v3",
    },
}

# UI language label -> Whisper language code passed to generate().
# Both entries map to "fa" on purpose (see inline note).
LANGUAGES = {
    "Kurdish Sorani (کوردی سۆرانی)": "fa",  # no native <|ku|>; models trained with <|fa|>
    "Persian (فارسی)": "fa",
}

# Whisper consumes 16 kHz mono audio; longer recordings are split into
# 30-second chunks (Whisper's native window) and transcribed chunk by chunk.
SAMPLE_RATE = 16_000
CHUNK_SECONDS = 30
CHUNK_SAMPLES = CHUNK_SECONDS * SAMPLE_RATE
36
+
37
+
38
# ---------------------------------------------------------------------------
# ModelManager — lazy loading, one model in memory at a time
# ---------------------------------------------------------------------------
class ModelManager:
    """Owns at most one (processor, model) pair at a time.

    Switching models via :meth:`load` frees the previous one first so that
    GPU/CPU memory is bounded by a single checkpoint.
    """

    def __init__(self):
        # Nothing is loaded eagerly; load() populates these on first use.
        self.processor: WhisperProcessor | None = None
        self.model: WhisperForConditionalGeneration | None = None
        self.current_name: str | None = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- public -----------------------------------------------------------

    def load(self, name: str) -> str:
        """Load *name*, unloading any previously loaded model first.

        Returns a human-readable status string (model, device, GPU memory).
        Re-loading the already-current model is a no-op.
        """
        if name == self.current_name:
            return self._status()
        self._unload()

        cfg = MODELS[name]
        model_path = str(cfg["path"])

        self.processor = WhisperProcessor.from_pretrained(model_path)

        # The small PEFT model was saved with load_in_8bit in its config.
        # bitsandbytes doesn't work on Windows / CPU, so we catch and
        # fall back to float16 (or float32 on CPU).
        try:
            self.model = WhisperForConditionalGeneration.from_pretrained(
                model_path,
                # device_map="auto" lets accelerate place the weights on GPU.
                device_map="auto" if self.device.type == "cuda" else None,
            )
        except (ImportError, ValueError, RuntimeError):
            # Quantisation failed — reload without it.
            dtype = torch.float16 if self.device.type == "cuda" else torch.float32
            self.model = WhisperForConditionalGeneration.from_pretrained(
                model_path,
                quantization_config=None,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            # NOTE(review): indentation reconstructed — this .to() appears to
            # belong to the fallback path only (a device_map="auto" model must
            # not be moved with .to()); confirm against the original file.
            self.model.to(self.device)

        # Ensure generate uses KV-cache regardless of saved config.
        self.model.config.use_cache = True

        # Clear stale forced_decoder_ids so they don't conflict
        # with the language/task kwargs we pass to generate().
        self.model.generation_config.forced_decoder_ids = None

        # Safety net: on CPU-only hosts make sure the weights really live on CPU.
        if self.device.type != "cuda" and next(self.model.parameters()).device.type != "cpu":
            self.model.to(self.device)

        self.model.eval()
        # Cache the parameter dtype so generate() can cast features to match.
        self._dtype = next(self.model.parameters()).dtype
        self.current_name = name
        return self._status()

    def generate(self, audio: np.ndarray, language_code: str) -> str:
        """Run inference on a float32 mono 16 kHz numpy array.

        Long audio is split into 30 s chunks; chunk transcripts are joined
        with spaces.  Raises RuntimeError if load() was never called.
        """
        if self.model is None or self.processor is None:
            raise RuntimeError("No model loaded.")

        chunks = self._chunk(audio)
        parts: list[str] = []

        for chunk in chunks:
            inputs = self.processor(
                chunk, sampling_rate=SAMPLE_RATE, return_tensors="pt",
            )
            # Cast to the model's dtype (fp16 on GPU) to avoid dtype mismatch.
            input_features = inputs.input_features.to(self.device, dtype=self._dtype)

            with torch.no_grad():
                predicted_ids = self.model.generate(
                    input_features,
                    language=language_code,
                    task="transcribe",
                    # 440 keeps prompt + new tokens under Whisper's decoder
                    # limit — presumably chosen to stay below 448; confirm.
                    max_new_tokens=440,
                )

            text = self.processor.batch_decode(
                predicted_ids, skip_special_tokens=True,
            )[0].strip()
            if text:
                # Skip empty chunk transcripts (e.g. pure silence).
                parts.append(text)

        return " ".join(parts)

    # --- private ----------------------------------------------------------

    def _unload(self):
        """Drop the current model/processor and release GPU memory."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None
        self.current_name = None
        gc.collect()
        if torch.cuda.is_available():
            # empty_cache() returns cached blocks to the driver so the freed
            # memory is visible to other processes / the next load().
            torch.cuda.empty_cache()

    def _status(self) -> str:
        """Build the status line shown in the UI (model • device • GPU mem)."""
        mem = ""
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3
            mem = f" | GPU memory: {allocated:.1f} GB"
        return f"{self.current_name} • {self.device}{mem}"

    @staticmethod
    def _chunk(audio: np.ndarray) -> list[np.ndarray]:
        """Split *audio* into CHUNK_SAMPLES-sized pieces (last may be shorter)."""
        if len(audio) <= CHUNK_SAMPLES:
            return [audio]
        return [audio[i : i + CHUNK_SAMPLES] for i in range(0, len(audio), CHUNK_SAMPLES)]
151
+
152
+
153
# ---------------------------------------------------------------------------
# Audio normalisation helper
# ---------------------------------------------------------------------------
def prepare_audio(audio) -> np.ndarray:
    """Accept a filepath from Gradio and return float32 mono 16 kHz numpy array.

    The input is transcoded with ffmpeg so any container/codec ffmpeg knows
    (ogg, webm, mp3, flac, m4a, opus, …) is accepted.

    Raises:
        gr.Error: if no audio was provided, the file is missing or empty,
            ffmpeg is not installed, or the conversion fails.
    """
    import subprocess
    import tempfile

    if not audio:
        raise gr.Error("No audio provided — please record or upload a file first.")

    audio_path = Path(audio)
    if not audio_path.exists():
        raise gr.Error(f"Audio file not found: {audio}")

    # Convert any format to 16 kHz mono WAV via ffmpeg, then load the raw PCM.
    # The temp file is created closed (delete=False) so ffmpeg can write to it
    # on Windows too; we unlink it ourselves in the finally block.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        wav_path = tmp.name

    try:
        try:
            subprocess.run(
                [
                    "ffmpeg", "-y", "-i", str(audio_path),
                    "-ar", str(SAMPLE_RATE),
                    "-ac", "1",
                    "-c:a", "pcm_s16le",
                    wav_path,
                ],
                capture_output=True,
                check=True,
            )
        except FileNotFoundError:
            # ffmpeg binary not on PATH — surface a clear UI error instead of
            # a raw traceback.
            raise gr.Error(
                "ffmpeg is not installed or not on PATH — it is required "
                "to convert uploaded audio."
            )
        except subprocess.CalledProcessError as exc:
            # Conversion failed (corrupt file, unsupported stream, …).
            detail = (exc.stderr or b"").decode("utf-8", errors="replace").strip()
            raise gr.Error(
                f"ffmpeg could not convert the audio: {detail[-500:] or 'unknown error'}"
            )

        import soundfile as sf
        data, _ = sf.read(wav_path, dtype="float32")
    finally:
        Path(wav_path).unlink(missing_ok=True)

    if data.size == 0:
        raise gr.Error("The audio file appears to contain no samples.")

    return data
191
+
192
+
193
# ---------------------------------------------------------------------------
# Gradio callback
# ---------------------------------------------------------------------------
manager = ModelManager()


def transcribe(audio_mic, audio_file, model_name: str, language: str):
    """Gradio click handler: returns (transcription text, status line)."""
    if model_name not in MODELS:
        raise gr.Error("Please select a model.")

    # An uploaded file takes precedence over a microphone recording.
    chosen = audio_mic if audio_file is None else audio_file

    status = manager.load(model_name)
    lang_code = LANGUAGES[language]

    started = time.perf_counter()
    transcript = manager.generate(prepare_audio(chosen), lang_code)
    took = time.perf_counter() - started

    return transcript, f"{status} | {took:.1f}s"
215
+
216
+
217
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Styling for the transcription box (elem_id="output-box"): right-to-left
# flow and an Arabic-script-friendly font stack for Kurdish/Persian text.
# NOTE(review): this constant only takes effect if handed to Gradio via a
# css= argument — verify it is actually wired into gr.Blocks/launch.
RTL_CSS = """
#output-box textarea {
    direction: rtl;
    text-align: right;
    font-family: 'Vazirmatn', 'Noto Sans Arabic', Tahoma, sans-serif;
    font-size: 1.15rem;
    line-height: 1.9;
}
"""
229
+
230
+
231
def build_ui() -> gr.Blocks:
    """Assemble and return the Gradio Blocks application.

    Layout: model/language dropdowns, microphone + file-upload inputs, a
    Transcribe button, the RTL transcription box, and a status line.
    """
    # theme= and css= are gr.Blocks(...) arguments — Blocks.launch() does not
    # accept them — so they are applied here rather than at launch time.
    with gr.Blocks(
        title="Whisper Kurdish & Persian",
        theme=gr.themes.Soft(),
        css=RTL_CSS,
    ) as app:
        gr.Markdown("## Whisper — Kurdish Sorani & Persian Transcription")

        with gr.Row():
            model_dd = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )
            lang_dd = gr.Dropdown(
                choices=list(LANGUAGES.keys()),
                value=list(LANGUAGES.keys())[0],
                label="Language",
            )

        with gr.Row():
            audio_mic = gr.Audio(
                label="Record from microphone",
                sources=["microphone"],
                type="filepath",
            )
            audio_file = gr.File(
                label="Or upload audio file (wav, ogg, mp3, flac, m4a, opus …)",
                file_types=[".wav", ".ogg", ".oga", ".mp3", ".flac", ".m4a",
                            ".opus", ".webm", ".wma", ".aac", ".amr"],
            )

        btn = gr.Button("Transcribe", variant="primary")

        output = gr.Textbox(
            label="Transcription",
            lines=6,
            # `buttons=` is not a gr.Textbox parameter; the supported way to
            # get a copy button is show_copy_button=True.
            show_copy_button=True,
            elem_id="output-box",  # targeted by RTL_CSS
            rtl=True,
        )
        status = gr.Textbox(label="Status", interactive=False, lines=1)

        btn.click(
            fn=transcribe,
            inputs=[audio_mic, audio_file, model_dd, lang_dd],
            outputs=[output, status],
        )

    return app
277
+
278
+
279
# ---------------------------------------------------------------------------
# Entry
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # NOTE: theme= and css= are gr.Blocks(...) arguments; Blocks.launch()
    # does not accept them and would raise TypeError, so they must not be
    # passed here (wire them into gr.Blocks inside build_ui instead).
    build_ui().launch(
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7865,
        show_error=True,  # surface Python errors in the browser UI
    )