sbompolas committed on
Commit
2bfc660
·
verified ·
1 Parent(s): 57f7f52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -458
app.py CHANGED
@@ -1,8 +1,8 @@
1
  import gradio as gr
2
  import torch
3
- import logging
4
  import gc
5
  import time
 
6
  from transformers import (
7
  pipeline,
8
  AutoProcessor,
@@ -12,548 +12,243 @@ from transformers import (
12
  WhisperProcessor,
13
  )
14
 
15
- # Try to import flash attention capability (only relevant for some seq2seq models)
16
- try:
17
- from transformers.utils import is_flash_attn_2_available
18
- FLASH_ATTN_AVAILABLE = True
19
- except Exception:
20
- FLASH_ATTN_AVAILABLE = False
21
- def is_flash_attn_2_available():
22
- return False
23
-
24
  logging.basicConfig(level=logging.INFO)
25
  logger = logging.getLogger(__name__)
26
 
27
 
28
  class MultiASRApp:
29
- """
30
- Supports BOTH:
31
- - Whisper / seq2seq ASR (openai/whisper-*, fine-tuned whisper)
32
- - XLS-R / Wav2Vec2 CTC ASR (e.g., ilsp/xls-r-greek-cretan)
33
- """
34
-
35
  def __init__(self):
36
  self.pipe = None
37
  self.current_model = None
38
  self.current_kind = None # "whisper" | "ctc"
39
 
40
  self.available_models = [
41
- "openai/whisper-tiny",
42
- "openai/whisper-base",
43
  "openai/whisper-small",
44
  "openai/whisper-medium",
45
- "openai/whisper-large-v2",
46
- "openai/whisper-large-v3",
47
  "ilsp/whisper_greek_dialect_of_lesbos",
48
  "ilsp/xls-r-greek-cretan",
49
  ]
50
 
51
- # ----------------------------
52
- # Model classification
53
- # ----------------------------
54
- def detect_model_kind(self, model_name: str) -> str:
55
- """
56
- Decide which loading path to use.
57
- - Whisper models -> seq2seq
58
- - XLS-R / wav2vec2 CTC -> ctc
59
- """
60
  name = model_name.lower()
61
-
62
- # Your known XLS-R model:
63
  if "xls-r" in name or "xlsr" in name:
64
  return "ctc"
65
-
66
- # Heuristic: Whisper is usually named whisper
67
- if "whisper" in name:
68
- return "whisper"
69
-
70
- # Fallback: try whisper first (safer for your list), else ctc
71
  return "whisper"
72
 
73
- def is_fine_tuned_whisper(self, model_name: str) -> bool:
74
- """
75
- Fine-tuned whisper models may need conservative settings.
76
- (This is NOT for XLS-R.)
77
- """
78
- n = model_name.lower()
79
- indicators = ["ilsp/", "dialect", "fine", "custom"]
80
- return any(x in n for x in indicators) and ("whisper" in n)
81
 
82
- # ----------------------------
83
- # Pipeline creation
84
- # ----------------------------
85
- def _pick_device_and_dtype(self, kind: str, conservative: bool):
86
  if torch.cuda.is_available():
87
- device = "cuda:0"
88
- # For CTC, fp16 can work, but fp32 is often safer across community models.
89
- if kind == "ctc":
90
- torch_dtype = torch.float32
91
- else:
92
- torch_dtype = torch.float32 if conservative else torch.float16
93
- else:
94
- device = "cpu"
95
- torch_dtype = torch.float32
96
- return device, torch_dtype
97
 
98
- def create_whisper_pipe(self, model_name: str, use_flash_attention: bool = True):
 
 
 
99
  conservative = self.is_fine_tuned_whisper(model_name)
100
- device, torch_dtype = self._pick_device_and_dtype("whisper", conservative)
101
-
102
- logger.info(f"[WHISPER] Loading {model_name} on {device} dtype={torch_dtype} conservative={conservative}")
103
-
104
- # Flash attention is only meaningful for some GPU seq2seq configs
105
- attn_implementation = "eager"
106
- if (
107
- use_flash_attention
108
- and not conservative
109
- and FLASH_ATTN_AVAILABLE
110
- and is_flash_attn_2_available()
111
- and torch.cuda.is_available()
112
- ):
113
- attn_implementation = "flash_attention_2"
114
- logger.info("[WHISPER] Using flash_attention_2")
115
-
116
- # Some fine-tuned repos are saved as WhisperForConditionalGeneration; others as generic SpeechSeq2Seq
117
  try:
118
  model = WhisperForConditionalGeneration.from_pretrained(
119
  model_name,
120
- torch_dtype=torch_dtype,
121
  low_cpu_mem_usage=True,
122
- cache_dir="./cache",
123
  )
124
- processor = WhisperProcessor.from_pretrained(model_name, cache_dir="./cache")
125
- except Exception as e:
126
- logger.info(f"[WHISPER] WhisperForConditionalGeneration load failed ({e}); trying AutoModelForSpeechSeq2Seq")
127
  model = AutoModelForSpeechSeq2Seq.from_pretrained(
128
  model_name,
129
- torch_dtype=torch_dtype,
130
  low_cpu_mem_usage=True,
131
- use_safetensors=not conservative,
132
- attn_implementation=attn_implementation,
133
- cache_dir="./cache",
134
  )
135
- processor = AutoProcessor.from_pretrained(model_name, cache_dir="./cache")
136
 
137
  model.to(device)
138
 
139
  return pipeline(
140
- task="automatic-speech-recognition",
141
  model=model,
142
  tokenizer=processor.tokenizer,
143
  feature_extractor=processor.feature_extractor,
144
  device=device,
145
- torch_dtype=torch_dtype,
146
- # chunking is supported for whisper
147
- chunk_length_s=30 if conservative else None,
148
  )
149
 
150
- def create_ctc_pipe(self, model_name: str):
151
- """
152
- XLS-R / Wav2Vec2 CTC path.
153
- Key differences:
154
- - AutoModelForCTC
155
- - No generate_kwargs (CTC decoding)
156
- - Timestamps are typically NOT supported in the same way as Whisper chunks
157
- """
158
- device, torch_dtype = self._pick_device_and_dtype("ctc", conservative=True)
159
-
160
- logger.info(f"[CTC] Loading {model_name} on {device} dtype={torch_dtype}")
161
 
162
- processor = AutoProcessor.from_pretrained(model_name, cache_dir="./cache")
163
  model = AutoModelForCTC.from_pretrained(
164
  model_name,
165
- torch_dtype=torch_dtype,
166
  low_cpu_mem_usage=True,
167
- cache_dir="./cache",
168
  )
169
  model.to(device)
170
 
171
- # Pipeline can take tokenizer + feature_extractor if present
172
- tokenizer = getattr(processor, "tokenizer", None)
173
- feature_extractor = getattr(processor, "feature_extractor", None)
174
-
175
  return pipeline(
176
- task="automatic-speech-recognition",
177
  model=model,
178
- tokenizer=tokenizer,
179
- feature_extractor=feature_extractor,
180
  device=device,
181
- torch_dtype=torch_dtype,
182
- # For long audio, CTC pipelines can also chunk; keep conservative defaults.
183
  chunk_length_s=20,
184
- stride_length_s=(4, 2), # helps continuity between chunks
185
  )
186
 
187
- def create_pipe(self, model_name: str, use_flash_attention: bool = True):
188
- kind = self.detect_model_kind(model_name)
189
- if kind == "ctc":
190
- return self.create_ctc_pipe(model_name), "ctc"
191
- else:
192
- # Disable flash attention automatically for fine-tuned whisper
193
- if self.is_fine_tuned_whisper(model_name):
194
- use_flash_attention = False
195
- return self.create_whisper_pipe(model_name, use_flash_attention=use_flash_attention), "whisper"
196
-
197
- # ----------------------------
198
- # Load / unload
199
- # ----------------------------
200
- def clear_model(self):
201
- if self.pipe is not None:
202
- try:
203
- del self.pipe
204
- except Exception:
205
- pass
206
- self.pipe = None
207
- self.current_model = None
208
- self.current_kind = None
209
-
210
- if torch.cuda.is_available():
211
- torch.cuda.empty_cache()
212
- gc.collect()
213
-
214
- def load_model(self, model_name: str, use_flash_attention: bool = True) -> bool:
215
  if self.current_model == model_name and self.pipe is not None:
216
- logger.info("Model already loaded")
217
  return True
218
 
219
- logger.info(f"Loading new model: {model_name}")
220
  self.clear_model()
 
221
 
222
  try:
223
- pipe, kind = self.create_pipe(model_name, use_flash_attention=use_flash_attention)
224
- self.pipe = pipe
 
 
 
225
  self.current_model = model_name
226
  self.current_kind = kind
227
- logger.info(f"Loaded {model_name} as {kind}")
228
  return True
229
  except Exception as e:
230
- logger.error(f"Error loading model {model_name}: {e}", exc_info=True)
231
  self.clear_model()
232
  return False
233
 
234
- # ----------------------------
235
- # Transcription
236
- # ----------------------------
237
- def transcribe_audio(
238
- self,
239
- audio_file,
240
- model_name="openai/whisper-small",
241
- language="Automatic Detection",
242
- task="transcribe",
243
- chunk_length_s=30,
244
- batch_size=4,
245
- use_flash_attention=False,
246
- return_timestamps=True,
247
- ):
248
- if audio_file is None:
249
- return "Please upload an audio file", "", ""
250
-
251
- start_time = time.time()
252
-
253
- ok = self.load_model(model_name, use_flash_attention=use_flash_attention)
254
- if not ok:
255
- return "Failed to load model", "", "Failed to load model"
256
-
257
- kind = self.current_kind or self.detect_model_kind(model_name)
258
 
259
- try:
260
- if kind == "ctc":
261
- # XLS-R / CTC: no generate_kwargs; timestamps usually not available as chunks.
262
- if return_timestamps:
263
- ts_note = (
264
- "=== TIMESTAMPS ===\n"
265
- "Timestamps are not provided for XLS-R/CTC models in this demo.\n"
266
- )
267
- else:
268
- ts_note = "=== TIMESTAMPS ===\nDisabled.\n"
269
-
270
- # For CTC, use chunk settings already set in the pipeline; batch_size works but keep conservative
271
- out = self.pipe(
272
- audio_file,
273
- batch_size=min(int(batch_size), 4),
274
- )
275
- text = out.get("text", "") if isinstance(out, dict) else str(out)
276
-
277
- total = time.time() - start_time
278
- details = self._format_detailed_output(
279
- transcription=text,
280
- model_name=model_name,
281
- language=language,
282
- task=task,
283
- transcription_time=total,
284
- chunk_length_s=chunk_length_s,
285
- batch_size=batch_size,
286
- use_flash_attention=False,
287
- num_chunks=0,
288
- model_kind="XLS-R / CTC",
289
- timestamps_supported=False,
290
- )
291
- return text.strip(), ts_note, details
292
-
293
- # ---------------- Whisper / seq2seq ----------------
294
- generate_kwargs = {}
295
-
296
- if language != "Automatic Detection" and not model_name.endswith(".en"):
297
- language_map = {
298
- "Greek": "greek",
299
- "English": "english",
300
- "Spanish": "spanish",
301
- "French": "french",
302
- "German": "german",
303
- "Italian": "italian",
304
- }
305
- generate_kwargs["language"] = language_map.get(language, language.lower())
306
-
307
- if not model_name.endswith(".en"):
308
- generate_kwargs["task"] = task
309
-
310
- # Fine-tuned whisper: more conservative runtime params
311
- conservative = self.is_fine_tuned_whisper(model_name)
312
- if conservative:
313
- chunk_length_s = min(int(chunk_length_s), 30)
314
- batch_size = min(int(batch_size), 2)
315
- # more deterministic defaults
316
- generate_kwargs.update({
317
- "do_sample": False,
318
- "num_beams": 1,
319
- "max_length": 448,
320
- })
321
-
322
- out = self.pipe(
323
- audio_file,
324
- chunk_length_s=int(chunk_length_s),
325
- batch_size=int(batch_size),
326
- generate_kwargs=generate_kwargs,
327
- return_timestamps=bool(return_timestamps),
328
  )
329
 
330
- total = time.time() - start_time
331
- text = out.get("text", "") if isinstance(out, dict) else str(out)
332
- chunks = out.get("chunks", []) if isinstance(out, dict) else []
333
-
334
- ts_text = ""
335
- if return_timestamps:
336
- ts_text = self._format_timestamps(chunks) if chunks else "=== TIMESTAMPS ===\nNo chunks returned.\n"
337
- else:
338
- ts_text = "=== TIMESTAMPS ===\nDisabled.\n"
339
-
340
- details = self._format_detailed_output(
341
- transcription=text,
342
- model_name=model_name,
343
- language=language,
344
- task=task,
345
- transcription_time=total,
346
- chunk_length_s=chunk_length_s,
347
- batch_size=batch_size,
348
- use_flash_attention=use_flash_attention and not conservative,
349
- num_chunks=len(chunks),
350
- model_kind="Whisper / Seq2Seq" + (" (fine-tuned)" if conservative else ""),
351
- timestamps_supported=True,
352
  )
 
 
353
 
354
- return text.strip(), ts_text, details
 
 
 
355
 
356
- except Exception as e:
357
- logger.error(f"Transcription error: {e}", exc_info=True)
358
- msg = f"Transcription error: {str(e)}"
359
- return msg, "", msg
360
-
361
- # ----------------------------
362
- # Formatting helpers
363
- # ----------------------------
364
- def _format_timestamps(self, chunks):
365
- txt = "=== TIMESTAMPS ===\n"
366
- for i, ch in enumerate(chunks or []):
367
- try:
368
- ts = ch.get("timestamp", None)
369
- t = ch.get("text", "")
370
- if isinstance(ts, (list, tuple)) and len(ts) >= 2 and ts[0] is not None and ts[1] is not None:
371
- txt += f"[{float(ts[0]):.1f}s - {float(ts[1]):.1f}s]: {t}\n"
372
- else:
373
- txt += f"[Chunk {i}]: {t}\n"
374
- except Exception as e:
375
- txt += f"[Chunk {i} error]: {e}\n"
376
- return txt
377
-
378
- def _format_detailed_output(
379
- self,
380
- transcription,
381
- model_name,
382
- language,
383
- task,
384
- transcription_time,
385
- chunk_length_s,
386
- batch_size,
387
- use_flash_attention,
388
- num_chunks,
389
- model_kind,
390
- timestamps_supported,
391
- ):
392
- out = "=== TRANSCRIPTION ===\n"
393
- out += f"{transcription}\n\n"
394
-
395
- out += "=== MODEL INFORMATION ===\n"
396
- out += f"Model: {model_name}\n"
397
- out += f"Kind: {model_kind}\n"
398
- out += f"Language setting: {language}\n"
399
- out += f"Task: {task}\n"
400
- out += f"Processing time: {transcription_time:.2f} seconds\n"
401
- out += f"Chunks: {num_chunks}\n"
402
- out += f"Timestamps supported: {'Yes' if timestamps_supported else 'No'}\n"
403
-
404
- out += "\n=== SETTINGS ===\n"
405
- out += f"Chunk length (UI): {chunk_length_s} seconds\n"
406
- out += f"Batch size (UI): {batch_size}\n"
407
- out += f"Flash Attention: {'Enabled' if use_flash_attention else 'Disabled'}\n"
408
  return out
409
 
410
- def get_model_info(self):
411
- if self.pipe is None:
412
- return "No model loaded"
413
- try:
414
- device = next(self.pipe.model.parameters()).device
415
- dtype = next(self.pipe.model.parameters()).dtype
416
- return f"✅ {self.current_model} ({self.current_kind}) - {device} ({dtype})"
417
- except Exception:
418
- return f"✅ {self.current_model} ({self.current_kind}) loaded"
419
-
420
-
421
- # ----------------------------
422
- # Gradio wiring
423
- # ----------------------------
424
- logger.info("Initializing Multi ASR App...")
425
- asr_app = MultiASRApp()
426
-
427
-
428
- def transcribe_wrapper(audio, model_name, language, task, chunk_length_s,
429
- batch_size, use_flash_attention, return_timestamps):
430
- return asr_app.transcribe_audio(
431
- audio_file=audio,
432
- model_name=model_name,
433
- language=language,
434
- task=task,
435
- chunk_length_s=chunk_length_s,
436
- batch_size=batch_size,
437
- use_flash_attention=use_flash_attention,
438
- return_timestamps=return_timestamps,
 
439
  )
440
 
 
441
 
442
- def get_model_status():
443
- return asr_app.get_model_info()
444
-
445
-
446
- def update_settings_for_model(model_name):
447
- kind = asr_app.detect_model_kind(model_name)
448
- if kind == "ctc":
449
- # XLS-R recommendations
450
- return {
451
- "batch_size": gr.update(value=1, maximum=4),
452
- "use_flash_attention": gr.update(value=False),
453
- "chunk_length_s": gr.update(value=20),
454
- "return_timestamps": gr.update(value=False),
455
- }
456
- else:
457
- # Whisper recommendations (fine-tuned whisper: conservative)
458
- conservative = asr_app.is_fine_tuned_whisper(model_name)
459
- return {
460
- "batch_size": gr.update(value=1 if conservative else 4, maximum=2 if conservative else 16),
461
- "use_flash_attention": gr.update(value=False),
462
- "chunk_length_s": gr.update(value=30),
463
- "return_timestamps": gr.update(value=True),
464
- }
465
-
466
-
467
- def create_interface():
468
- with gr.Blocks(title="Multi-ASR (Whisper + XLS-R)", theme=gr.themes.Soft()) as interface:
469
- gr.Markdown(
470
- """
471
- # 🚀 Multi-ASR Demo (Whisper + XLS-R)
472
-
473
- This app supports:
474
- - **Whisper** models (seq2seq) incl. fine-tuned dialect Whisper
475
- - **XLS-R** models (CTC) e.g. **ilsp/xls-r-greek-cretan**
476
-
477
- Notes:
478
- - Whisper can return chunk timestamps.
479
- - XLS-R/CTC typically **does not** return timestamps in this pipeline setup.
480
- """
481
- )
482
 
483
- model_status = gr.Textbox(value=get_model_status(), label="🔧 Current Model Status", interactive=False)
484
-
485
- with gr.Row():
486
- with gr.Column():
487
- audio_input = gr.Audio(label="🎵 Upload Audio File", type="filepath")
488
-
489
- model_dropdown = gr.Dropdown(
490
- choices=asr_app.available_models,
491
- value="openai/whisper-small",
492
- label="Model",
493
- info="Automatically switches loading path (Whisper vs XLS-R/CTC).",
494
- )
495
-
496
- with gr.Row():
497
- language_dropdown = gr.Dropdown(
498
- choices=["Automatic Detection", "Greek", "English", "Spanish", "French", "German", "Italian"],
499
- value="Automatic Detection",
500
- label="Language (Whisper only)",
501
- )
502
- task_dropdown = gr.Dropdown(
503
- choices=["transcribe", "translate"],
504
- value="transcribe",
505
- label="Task (Whisper only)",
506
- )
507
-
508
- with gr.Accordion("Advanced Settings", open=False):
509
- chunk_length_s = gr.Slider(10, 60, value=30, step=5, label="Chunk Length (seconds)")
510
- batch_size = gr.Slider(1, 16, value=4, step=1, label="Batch Size")
511
- use_flash_attention = gr.Checkbox(label="Flash Attention 2 (Whisper only)", value=False)
512
- return_timestamps = gr.Checkbox(label="Return Timestamps (Whisper only)", value=True)
513
-
514
- transcribe_btn = gr.Button("🚀 Transcribe", variant="primary", size="lg")
515
-
516
- with gr.Column():
517
- transcription_output = gr.Textbox(label="Transcription", lines=8, show_copy_button=True)
518
-
519
- with gr.Accordion("Timestamps", open=False):
520
- timestamps_output = gr.Textbox(label="Timestamp Information", lines=10, show_copy_button=True)
521
-
522
- with gr.Accordion("Detailed Information", open=False):
523
- detailed_output = gr.Textbox(label="Processing Details & Model Info", lines=15, show_copy_button=True)
524
-
525
- transcribe_btn.click(
526
- fn=transcribe_wrapper,
527
- inputs=[
528
- audio_input,
529
- model_dropdown,
530
- language_dropdown,
531
- task_dropdown,
532
- chunk_length_s,
533
- batch_size,
534
- use_flash_attention,
535
- return_timestamps,
536
- ],
537
- outputs=[transcription_output, timestamps_output, detailed_output],
538
- show_progress=True,
539
- )
540
 
541
- # When model changes, auto-tune UI controls
542
- def on_model_change(m):
543
- rec = update_settings_for_model(m)
544
- kind = asr_app.detect_model_kind(m)
545
- status = f"Model will load on next transcription ({'XLS-R/CTC' if kind=='ctc' else 'Whisper'})"
546
- return status, rec["batch_size"], rec["use_flash_attention"], rec["chunk_length_s"], rec["return_timestamps"]
547
-
548
- model_dropdown.change(
549
- fn=on_model_change,
550
- inputs=[model_dropdown],
551
- outputs=[model_status, batch_size, use_flash_attention, chunk_length_s, return_timestamps],
552
- )
553
 
554
- return interface
555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556
 
557
  if __name__ == "__main__":
558
- interface = create_interface()
559
- interface.launch(share=True)
 
1
  import gradio as gr
2
  import torch
 
3
  import gc
4
  import time
5
+ import logging
6
  from transformers import (
7
  pipeline,
8
  AutoProcessor,
 
12
  WhisperProcessor,
13
  )
14
 
 
 
 
 
 
 
 
 
 
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
 
19
class MultiASRApp:
    """Speech-to-text demo supporting two model families behind one interface:
    Whisper/seq2seq checkpoints and XLS-R/Wav2Vec2 CTC checkpoints.

    A single transformers ASR pipeline is kept loaded at a time and swapped
    lazily when the user selects a different model.
    """

    def __init__(self):
        self.pipe = None           # lazily-created transformers ASR pipeline
        self.current_model = None  # repo id of the currently loaded model
        self.current_kind = None   # "whisper" | "ctc"

        self.available_models = [
            "openai/whisper-small",
            "openai/whisper-medium",
            "ilsp/whisper_greek_dialect_of_lesbos",
            "ilsp/xls-r-greek-cretan",
        ]

    # ------------------------
    # Model detection
    # ------------------------
    def detect_model_kind(self, model_name):
        """Classify a model repo id as "ctc" (XLS-R/Wav2Vec2) or "whisper"."""
        name = model_name.lower()
        if "xls-r" in name or "xlsr" in name:
            return "ctc"
        return "whisper"

    def is_fine_tuned_whisper(self, model_name):
        """True for ilsp fine-tuned Whisper checkpoints, which get conservative settings."""
        name = model_name.lower()
        return "ilsp/" in name and "whisper" in name

    # ------------------------
    # Device & dtype
    # ------------------------
    def pick_device(self, conservative=True):
        """Return a (device, dtype) pair; fp16 only on CUDA when not conservative."""
        if torch.cuda.is_available():
            return "cuda:0", torch.float32 if conservative else torch.float16
        return "cpu", torch.float32

    # ------------------------
    # Pipeline creation
    # ------------------------
    def create_whisper_pipe(self, model_name):
        """Build an ASR pipeline for a Whisper/seq2seq checkpoint."""
        conservative = self.is_fine_tuned_whisper(model_name)
        device, dtype = self.pick_device(conservative)

        # Some repos load as WhisperForConditionalGeneration; fall back to the
        # generic seq2seq auto-classes for everything else.
        try:
            model = WhisperForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            processor = WhisperProcessor.from_pretrained(model_name)
        except Exception:
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_name,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
            )
            processor = AutoProcessor.from_pretrained(model_name)

        model.to(device)

        return pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=device,
            torch_dtype=dtype,
            chunk_length_s=30,
        )

    def create_ctc_pipe(self, model_name):
        """Build an ASR pipeline for an XLS-R/Wav2Vec2 CTC checkpoint."""
        device, dtype = self.pick_device(conservative=True)

        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForCTC.from_pretrained(
            model_name,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        model.to(device)

        return pipeline(
            "automatic-speech-recognition",
            model=model,
            # CTC processors may expose only one of these attributes; pass what exists.
            tokenizer=getattr(processor, "tokenizer", None),
            feature_extractor=getattr(processor, "feature_extractor", None),
            device=device,
            torch_dtype=dtype,
            chunk_length_s=20,
            stride_length_s=(4, 2),  # overlap keeps continuity between chunks
        )

    def load_model(self, model_name):
        """Load ``model_name`` unless it is already active. Returns True on success."""
        if self.current_model == model_name and self.pipe is not None:
            return True

        self.clear_model()
        kind = self.detect_model_kind(model_name)

        try:
            if kind == "ctc":
                self.pipe = self.create_ctc_pipe(model_name)
            else:
                self.pipe = self.create_whisper_pipe(model_name)

            self.current_model = model_name
            self.current_kind = kind
            return True
        except Exception:
            # BUGFIX: logger.error(e) dropped the traceback and all context;
            # logger.exception records both.
            logger.exception("Failed to load model %s", model_name)
            self.clear_model()
            return False

    def clear_model(self):
        """Drop the current pipeline and release GPU/CPU memory."""
        if self.pipe is not None:
            del self.pipe
        self.pipe = None
        self.current_model = None
        self.current_kind = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    # ------------------------
    # Transcription
    # ------------------------
    def transcribe(self, audio, model_name, return_timestamps):
        """Transcribe an uploaded audio file.

        Returns a (text, timestamps, details) triple of strings for the UI.
        """
        if audio is None:
            return "Ανέβασε ένα ηχητικό αρχείο.", "", ""

        start = time.time()
        if not self.load_model(model_name):
            return "Σφάλμα φόρτωσης μοντέλου.", "", ""

        if self.current_kind == "ctc":
            result = self.pipe(audio)
            # BUGFIX: the pipeline can return a plain string for some configs;
            # calling .get() on it raised AttributeError.
            text = result.get("text", "") if isinstance(result, dict) else str(result)
            timestamps = (
                "Οι χρονικές σημάνσεις δεν υποστηρίζονται για αυτό το μοντέλο."
                if return_timestamps
                else ""
            )
        else:
            result = self.pipe(
                audio,
                return_timestamps=return_timestamps,
            )
            text = result.get("text", "") if isinstance(result, dict) else str(result)
            chunks = result.get("chunks", []) if isinstance(result, dict) else []
            timestamps = self.format_timestamps(chunks)

        details = (
            f"Μοντέλο: {model_name}\n"
            f"Χρόνος επεξεργασίας: {time.time() - start:.2f} δευτ."
        )

        return text.strip(), timestamps, details

    def format_timestamps(self, chunks):
        """Render Whisper chunk timestamps as "[start–end] text" lines."""
        if not chunks:
            return ""
        out = ""
        for c in chunks:
            ts = c.get("timestamp")
            # Skip open-ended chunks (None start/end) rather than printing garbage.
            if ts and ts[0] is not None and ts[1] is not None:
                out += f"[{ts[0]:.1f}–{ts[1]:.1f}] {c.get('text','')}\n"
        return out

    def status(self):
        """Human-readable status line for the UI."""
        if not self.current_model:
            return "Δεν έχει φορτωθεί μοντέλο"
        return f"✔ {self.current_model}"
189
+
190
+
191
# ------------------------
# App
# ------------------------
app = MultiASRApp()


def run(audio, model, timestamps):
    """Gradio callback: delegate to the shared MultiASRApp instance."""
    return app.transcribe(audio, model, timestamps)


def status():
    """Current model status string for the UI."""
    return app.status()


with gr.Blocks(title="Ίντα λαλείς;", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Ίντα λαλείς;
        ## Η Τεχνητή Νοημοσύνη μαθαίνει ελληνικές διαλέκτους

        🎧 Ανέβασε ένα ηχητικό αρχείο και δες πώς η Τεχνητή Νοημοσύνη
        αναγνωρίζει την ελληνική γλώσσα και τις διαλέκτους της.

        📍 Athens Science Festival 2025
        🏛 Ωδείο Αθηνών | 18–21 Δεκεμβρίου 2025
        """
    )

    model_status = gr.Textbox(label="Κατάσταση μοντέλου", value=status(), interactive=False)

    with gr.Row():
        with gr.Column():
            audio = gr.Audio(label="🎵 Ανέβασε ηχητικό αρχείο", type="filepath")

            model = gr.Dropdown(
                choices=app.available_models,
                value="openai/whisper-small",
                label="Μοντέλο αναγνώρισης ομιλίας",
            )

            timestamps = gr.Checkbox(label="Χρονικές σημάνσεις", value=True)

            btn = gr.Button("🗣️ Μετατροπή ομιλίας σε κείμενο", variant="primary")

        with gr.Column():
            text_out = gr.Textbox(label="📄 Κείμενο", lines=8, show_copy_button=True)
            ts_out = gr.Textbox(label="Χρονικές σημάνσεις", lines=8)
            info_out = gr.Textbox(label="Πληροφορίες", lines=4)

    btn.click(
        run,
        inputs=[audio, model, timestamps],
        outputs=[text_out, ts_out, info_out],
    )

    # BUGFIX: the original `model.change(lambda _: status(), outputs=model_status)`
    # supplied no `inputs`, so Gradio invoked the lambda with zero arguments and
    # the required `_` parameter raised a TypeError on every model change.
    model.change(lambda _m: status(), inputs=model, outputs=model_status)

    gr.Markdown(
        """
        🔬 Έρευνα & τεχνολογία για τη γλωσσική ποικιλία
        🎙️ Η φωνή ως πολιτιστική κληρονομιά
        """
    )

if __name__ == "__main__":
    demo.launch()