MuhammadHijazii committed on
Commit
be3e6cf
·
verified ·
1 Parent(s): c113304

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -121
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, re, math, json, tempfile
2
  import numpy as np
3
  import pandas as pd
4
  import torch
@@ -9,12 +9,22 @@ from faster_whisper import WhisperModel
9
 
10
  from sentence_transformers import SentenceTransformer, util
11
  from transformers import AutoTokenizer, AutoModel
 
12
 
13
  # =========================
14
- # Device & Lazy-loaded models
15
  # =========================
16
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
17
 
 
 
 
18
  _SBERT = None
19
  _MARBERT_TOK = None
20
  _MARBERT = None
@@ -26,13 +36,23 @@ def load_models(
26
  whisper_name="small",
27
  whisper_compute="int8"
28
  ):
 
29
  global _SBERT, _MARBERT_TOK, _MARBERT, _WHISPER
 
 
 
 
 
 
30
  if _SBERT is None:
31
  _SBERT = SentenceTransformer(sbert_name, device=DEVICE)
32
- if _MARBERT is None:
 
 
33
  _MARBERT_TOK = AutoTokenizer.from_pretrained(marbert_name)
34
  _MARBERT = AutoModel.from_pretrained(marbert_name).to(DEVICE)
35
  _MARBERT.eval()
 
36
  if _WHISPER is None:
37
  _WHISPER = WhisperModel(whisper_name, device=DEVICE, compute_type=whisper_compute)
38
 
@@ -48,12 +68,17 @@ def normalize_ar_orth(text: str) -> str:
48
  return text
49
 
50
  def simple_tokenize(text: str):
51
- import nltk
 
52
  try:
53
- nltk.data.find('tokenizers/punkt')
54
- except LookupError:
55
- nltk.download('punkt')
56
- return nltk.word_tokenize(normalize_ar_orth(text))
 
 
 
 
57
 
58
  def align_texts(ref_tokens, hyp_tokens):
59
  import difflib
@@ -95,7 +120,7 @@ def is_levenshtein_1(w1, w2):
95
  return textdistance.levenshtein(w1, w2) == 1
96
 
97
  # =========================
98
- # Numbers (digits & word-numbers)
99
  # =========================
100
  AR_DIGITS = str.maketrans("٠١٢٣٤٥٦٧٨٩", "0123456789")
101
  UNITS = {"صفر":0,"واحد":1,"واحدة":1,"اثنان":2,"اثنين":2,"اثنتان":2,"اثنتين":2,
@@ -136,10 +161,12 @@ def to_numeric_value(token: str):
136
  return words_to_number(toks)
137
 
138
  # =========================
139
- # Semantic similarities (SBERT + MARBERT CLS)
140
  # =========================
141
  def marbert_cls_similarity(a: str, b: str) -> float:
142
  if not a or not b: return 0.0
 
 
143
  with torch.no_grad():
144
  ta = _MARBERT_TOK(a, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
145
  tb = _MARBERT_TOK(b, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
@@ -231,20 +258,17 @@ def gate_by_word_conf(base_decision: str, prob: float, sbert_sim: float,
231
  return base_decision
232
 
233
  # =========================
234
- # Pair classifier + main alignment classifier
235
  # =========================
236
  def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
237
  bert_thresh=0.75, max_bert=0.85):
238
- # 1) numbers
239
  ref_num = to_numeric_value(ref_w)
240
  hyp_num = to_numeric_value(hyp_w)
241
  if (ref_num is not None) or (hyp_num is not None):
242
  if (ref_num is not None) and (hyp_num is not None) and (ref_num == hyp_num):
243
  return 'ASR error (numbers equal)'
244
- # 2) short + lev1
245
  if short_word and lev1:
246
  return 'ASR error (short+lev1)'
247
- # 3) semantic
248
  avg_ok = bert_scores["avg"] >= bert_thresh
249
  max_ok = bert_scores["max"] > max_bert
250
  if ((phon_sim or lev1) and avg_ok) or max_ok:
@@ -254,7 +278,6 @@ def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
254
  def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
255
  bert_thresh=0.75, max_bert=0.85,
256
  asr_token_conf=None, low_high=None):
257
- # thresholds
258
  if low_high is None:
259
  if asr_token_conf:
260
  probs = [v["prob"] for v in asr_token_conf.values() if v["prob"] is not None]
@@ -268,8 +291,7 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
268
  else:
269
  low_t, high_t = low_high
270
 
271
- results = []
272
- corrected_words = []
273
 
274
  for entry in aligned:
275
  tag = entry['type']
@@ -285,7 +307,8 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
285
  for k in range(max_len):
286
  ref_w = entry['ref'][k] if k < len(entry['ref']) else ''
287
  hyp_w = entry['hyp'][k] if k < len(entry['hyp']) else ''
288
- if not ref_w and not hyp_w: continue
 
289
 
290
  phon_sim = phonetic_similarity(ref_w, hyp_w) if ref_w and hyp_w else False
291
  lev1 = is_levenshtein_1(ref_w, hyp_w) if ref_w and hyp_w else False
@@ -302,7 +325,6 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
302
  else:
303
  base_status = 'Undefined Case'
304
 
305
- # word-level confidence
306
  word_prob = None; word_dur = None
307
  if (j1 is not None) and (j2 is not None):
308
  hyp_abs_idx = j1 + k
@@ -320,14 +342,13 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
320
  low_t=low_t, high_t=high_t, sbert_lo=0.60
321
  )
322
 
 
323
  if ref_w and hyp_w:
324
  used = ref_w if final_status.startswith("ASR error") else hyp_w
325
  elif hyp_w == '':
326
  used = ''
327
  elif ref_w == '':
328
  used = hyp_w
329
- else:
330
- used = hyp_w
331
 
332
  reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
333
  f'SBERT={bert_scores["sbert"]:.2f}, '
@@ -347,15 +368,9 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
347
  return results, corrected_text
348
 
349
  # =========================
350
- # Literal / Semantic final scores
351
  # =========================
352
  def literal_similarity(original, recited):
353
- import nltk
354
- try:
355
- nltk.data.find('tokenizers/punkt')
356
- except LookupError:
357
- nltk.download('punkt')
358
-
359
  def norm(t):
360
  t = re.sub(r'[ًٌٍَُِّْـ]', '', t)
361
  t = re.sub(r'[“”",:؛؟.!()\[\]{}،\-–—_]', ' ', t)
@@ -363,41 +378,33 @@ def literal_similarity(original, recited):
363
  return t
364
  o = norm(original); r = norm(recited)
365
  lev = textdistance.levenshtein.normalized_similarity(o, r)
366
- ot = nltk.word_tokenize(o); rt = nltk.word_tokenize(r)
367
  common = sum(1 for w1, w2 in zip(ot, rt) if w1 == w2)
368
  word_overlap = common / max(len(ot), 1)
369
- import nltk.translate.bleu_score as bleu
370
- bleu1 = bleu.sentence_bleu([ot], rt, weights=(1,0,0,0)) if (ot and rt) else 0.0
 
 
 
371
  final_score = 0.5*lev + 0.3*word_overlap + 0.2*bleu1
372
  return {"levenshtein": round(lev,3), "word_overlap": round(word_overlap,3),
373
  "bleu1": round(bleu1,3), "literal_score": round(final_score,3)}
374
 
375
- def semantic_similarity(original, recited):
376
  sbert_sim = float(util.pytorch_cos_sim(_SBERT.encode(original, convert_to_tensor=True),
377
  _SBERT.encode(recited, convert_to_tensor=True)))
378
- with torch.no_grad():
379
- ta = _MARBERT_TOK(original, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
380
- tb = _MARBERT_TOK(recited, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
381
- ea = _MARBERT(**ta).last_hidden_state[:,0,:]
382
- eb = _MARBERT(**tb).last_hidden_state[:,0,:]
383
- sim = util.cos_sim(ea, eb).item()
384
- marbert_sim = (sim + 1)/2
385
  return {"sbert_sim": round(sbert_sim,3), "marbert_sim": round(marbert_sim,3),
386
  "semantic_score": round(max(sbert_sim, marbert_sim),3)}
387
 
388
  # =========================
389
- # Audio input helper (filepath or numpy)
390
  # =========================
391
- import soundfile as sf
392
-
393
  def ensure_audio_path(audio):
394
- """
395
- Accepts:
396
- - str (filepath)
397
- - tuple (numpy_array, sample_rate) if Gradio Audio type='numpy'
398
- Returns a filepath suitable for faster-whisper.
399
- """
400
  if isinstance(audio, str):
 
 
401
  return audio
402
  if isinstance(audio, tuple) and len(audio) == 2:
403
  data, sr = audio
@@ -408,79 +415,87 @@ def ensure_audio_path(audio):
408
  raise ValueError("Unsupported audio input format")
409
 
410
  # =========================
411
- # Transcribe + Evaluate
412
  # =========================
413
- def transcribe_and_evaluate(audio, original_text, whisper_size="small",
414
- compute_type=("float16" if DEVICE=="cuda" else "int8"),
415
- vad=True, use_marbert=True):
416
- # Load models lazily
417
- load_models(whisper_name=whisper_size, whisper_compute=compute_type)
418
-
419
- # Transcribe (word_timestamps=True for word-level probs)
420
- audio_path = ensure_audio_path(audio)
421
- segments, info = _WHISPER.transcribe(
422
- audio_path, word_timestamps=True,
423
- vad_filter=vad, vad_parameters={"min_silence_duration_ms": 200}
424
- )
425
- segments = list(segments)
 
426
 
427
- # ASR text from words (cleaned)
428
- words = []
429
- for seg in segments:
430
- for w in (seg.words or []):
431
- tok = clean_ar_token(w.word)
432
- if tok: words.append(tok)
433
- asr_text = " ".join(words)
434
-
435
- # Tokens + align
436
- ref_tokens = simple_tokenize(original_text)
437
- hyp_tokens = simple_tokenize(asr_text)
438
- aligned = align_texts(ref_tokens, hyp_tokens)
439
-
440
- # Word confidence map
441
- df_words = extract_word_conf_table(segments)
442
- asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)
443
-
444
- # Classify + corrected text
445
- results, corrected_text = classify_alignment_optimized(
446
- aligned, ref_tokens, hyp_tokens,
447
- bert_thresh=0.75, max_bert=0.85,
448
- asr_token_conf=asr_token_conf, low_high=(low_t, high_t)
449
- )
450
 
451
- # Scores
452
- lit = literal_similarity(original_text, corrected_text)
453
- if use_marbert:
454
- sem = semantic_similarity(original_text, corrected_text)
455
- else:
456
- sbert_sim = float(util.pytorch_cos_sim(_SBERT.encode(original_text, convert_to_tensor=True),
457
- _SBERT.encode(corrected_text, convert_to_tensor=True)))
458
- sem = {"sbert_sim": round(sbert_sim,3), "semantic_score": round(sbert_sim,3)}
459
-
460
- df = pd.DataFrame(results)
461
-
462
- report = {
463
- "whisper_model": whisper_size,
464
- "compute_type": compute_type,
465
- "original_text": original_text,
466
- "asr_text": asr_text,
467
- "corrected_text": corrected_text,
468
- "literal": lit,
469
- "semantic": sem,
470
- "low_t": low_t, "high_t": high_t,
471
- }
472
- return corrected_text, asr_text, json.dumps(report, ensure_ascii=False, indent=2), df
 
 
 
 
473
 
474
- # =========================
475
- # JSON-only API wrapper (optional)
476
- # =========================
477
- def api_predict(audio, original_text, whisper_size="small",
478
- compute_type=("float16" if DEVICE=="cuda" else "int8"),
479
- vad=True, use_marbert=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
  corrected_text, asr_text, report_json, df = transcribe_and_evaluate(
481
  audio, original_text, whisper_size, compute_type, vad, use_marbert
482
  )
483
- return json.loads(report_json)
 
 
 
484
 
485
  # =========================
486
  # Gradio UI
@@ -488,23 +503,25 @@ def api_predict(audio, original_text, whisper_size="small",
488
  def build_ui():
489
  with gr.Blocks(title="Samaali ASR Post-Processing", theme=gr.themes.Soft()) as demo:
490
  gr.Markdown("## Samaali — ASR Post-Processing (Whisper + Alignment + Confidence + Semantics)")
 
491
  with gr.Row():
 
492
  audio = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio")
493
  original = gr.Textbox(lines=8, label="Original Text (Ground Truth)")
494
 
495
  with gr.Row():
496
  whisper_size = gr.Dropdown(
497
  choices=["tiny","base","small","medium","large-v3"],
498
- value=("large-v3" if DEVICE=="cuda" else "small"),
499
  label="Whisper model size"
500
  )
501
  compute_type = gr.Dropdown(
502
  choices=["int8", "int8_float16", "float16", "float32"],
503
- value=("float16" if DEVICE=="cuda" else "int8"),
504
  label="compute_type"
505
  )
506
  vad = gr.Checkbox(value=True, label="VAD filter")
507
- use_marbert = gr.Checkbox(value=(DEVICE=="cuda"), label="Use MARBERT (semantic)")
508
 
509
  btn = gr.Button("Transcribe & Evaluate", variant="primary")
510
 
@@ -515,20 +532,18 @@ def build_ui():
515
  table = gr.Dataframe(headers=["ASR_word","GT_word","status","reason","used"],
516
  label="Token-level Decisions", wrap=True)
517
 
518
- # UI action + API endpoint
519
  btn.click(
520
  fn=transcribe_and_evaluate,
521
  inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
522
  outputs=[corrected, asr_out, report, table],
523
- api_name="evaluate" # ← Inference API endpoint
524
  )
525
 
526
- # JSON-only endpoint (hidden button)
527
  gr.Button(visible=False).click(
528
  fn=api_predict,
529
  inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
530
  outputs=gr.JSON(),
531
- api_name="predict" # ← Inference API endpoint (JSON only)
532
  )
533
 
534
  return demo
 
1
+ import os, re, json, math, tempfile, traceback
2
  import numpy as np
3
  import pandas as pd
4
  import torch
 
9
 
10
  from sentence_transformers import SentenceTransformer, util
11
  from transformers import AutoTokenizer, AutoModel
12
+ import soundfile as sf
13
 
14
  # =========================
15
+ # Device & global config
16
  # =========================
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+ CPU_MODE = (DEVICE != "cuda")
19
+
20
+ # أمان الذاكرة على CPU
21
+ DEFAULT_WHISPER_CPU = "small"
22
+ DEFAULT_COMPUTE_CPU = "int8"
23
+ DEFAULT_USE_MARBERT_CPU = False
24
 
25
+ # =========================
26
+ # Lazy models
27
+ # =========================
28
  _SBERT = None
29
  _MARBERT_TOK = None
30
  _MARBERT = None
 
36
  whisper_name="small",
37
  whisper_compute="int8"
38
  ):
39
+ """Load models only once."""
40
  global _SBERT, _MARBERT_TOK, _MARBERT, _WHISPER
41
+
42
+ # حماية على CPU: اجبار نماذج أخف
43
+ if CPU_MODE:
44
+ whisper_name = DEFAULT_WHISPER_CPU
45
+ whisper_compute = DEFAULT_COMPUTE_CPU
46
+
47
  if _SBERT is None:
48
  _SBERT = SentenceTransformer(sbert_name, device=DEVICE)
49
+
50
+ # حمّل MARBERT فقط عند الحاجة (قد يستهلك RAM)
51
+ if _MARBERT is None and (not CPU_MODE):
52
  _MARBERT_TOK = AutoTokenizer.from_pretrained(marbert_name)
53
  _MARBERT = AutoModel.from_pretrained(marbert_name).to(DEVICE)
54
  _MARBERT.eval()
55
+
56
  if _WHISPER is None:
57
  _WHISPER = WhisperModel(whisper_name, device=DEVICE, compute_type=whisper_compute)
58
 
 
68
  return text
69
 
70
def simple_tokenize(text: str):
    """Tokenize normalized Arabic text.

    Prefers NLTK's punkt word tokenizer, downloading it quietly on first
    use; if importing NLTK or fetching punkt fails for any reason, falls
    back to a plain whitespace split so tokenization never raises.
    """
    normalized = normalize_ar_orth(text)
    try:
        import nltk
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt', quiet=True)
        return nltk.word_tokenize(normalized)
    except Exception:
        # Degrade gracefully: whitespace split keeps the pipeline alive.
        return normalized.split()
82
 
83
  def align_texts(ref_tokens, hyp_tokens):
84
  import difflib
 
120
  return textdistance.levenshtein(w1, w2) == 1
121
 
122
  # =========================
123
+ # Numbers
124
  # =========================
125
  AR_DIGITS = str.maketrans("٠١٢٣٤٥٦٧٨٩", "0123456789")
126
  UNITS = {"صفر":0,"واحد":1,"واحدة":1,"اثنان":2,"اثنين":2,"اثنتان":2,"اثنتين":2,
 
161
  return words_to_number(toks)
162
 
163
  # =========================
164
+ # Semantic similarities
165
  # =========================
166
  def marbert_cls_similarity(a: str, b: str) -> float:
167
  if not a or not b: return 0.0
168
+ if _MARBERT is None:
169
+ return 0.0
170
  with torch.no_grad():
171
  ta = _MARBERT_TOK(a, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
172
  tb = _MARBERT_TOK(b, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
 
258
  return base_decision
259
 
260
  # =========================
261
+ # Pair + main classifiers
262
  # =========================
263
  def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
264
  bert_thresh=0.75, max_bert=0.85):
 
265
  ref_num = to_numeric_value(ref_w)
266
  hyp_num = to_numeric_value(hyp_w)
267
  if (ref_num is not None) or (hyp_num is not None):
268
  if (ref_num is not None) and (hyp_num is not None) and (ref_num == hyp_num):
269
  return 'ASR error (numbers equal)'
 
270
  if short_word and lev1:
271
  return 'ASR error (short+lev1)'
 
272
  avg_ok = bert_scores["avg"] >= bert_thresh
273
  max_ok = bert_scores["max"] > max_bert
274
  if ((phon_sim or lev1) and avg_ok) or max_ok:
 
278
  def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
279
  bert_thresh=0.75, max_bert=0.85,
280
  asr_token_conf=None, low_high=None):
 
281
  if low_high is None:
282
  if asr_token_conf:
283
  probs = [v["prob"] for v in asr_token_conf.values() if v["prob"] is not None]
 
291
  else:
292
  low_t, high_t = low_high
293
 
294
+ results, corrected_words = [], []
 
295
 
296
  for entry in aligned:
297
  tag = entry['type']
 
307
  for k in range(max_len):
308
  ref_w = entry['ref'][k] if k < len(entry['ref']) else ''
309
  hyp_w = entry['hyp'][k] if k < len(entry['hyp']) else ''
310
+ if not ref_w and not hyp_w:
311
+ continue
312
 
313
  phon_sim = phonetic_similarity(ref_w, hyp_w) if ref_w and hyp_w else False
314
  lev1 = is_levenshtein_1(ref_w, hyp_w) if ref_w and hyp_w else False
 
325
  else:
326
  base_status = 'Undefined Case'
327
 
 
328
  word_prob = None; word_dur = None
329
  if (j1 is not None) and (j2 is not None):
330
  hyp_abs_idx = j1 + k
 
342
  low_t=low_t, high_t=high_t, sbert_lo=0.60
343
  )
344
 
345
+ used = hyp_w
346
  if ref_w and hyp_w:
347
  used = ref_w if final_status.startswith("ASR error") else hyp_w
348
  elif hyp_w == '':
349
  used = ''
350
  elif ref_w == '':
351
  used = hyp_w
 
 
352
 
353
  reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
354
  f'SBERT={bert_scores["sbert"]:.2f}, '
 
368
  return results, corrected_text
369
 
370
  # =========================
371
+ # Scores
372
  # =========================
373
  def literal_similarity(original, recited):
 
 
 
 
 
 
374
  def norm(t):
375
  t = re.sub(r'[ًٌٍَُِّْـ]', '', t)
376
  t = re.sub(r'[“”",:؛؟.!()\[\]{}،\-–—_]', ' ', t)
 
378
  return t
379
  o = norm(original); r = norm(recited)
380
  lev = textdistance.levenshtein.normalized_similarity(o, r)
381
+ ot = simple_tokenize(o); rt = simple_tokenize(r)
382
  common = sum(1 for w1, w2 in zip(ot, rt) if w1 == w2)
383
  word_overlap = common / max(len(ot), 1)
384
+ try:
385
+ import nltk.translate.bleu_score as bleu
386
+ bleu1 = bleu.sentence_bleu([ot], rt, weights=(1,0,0,0)) if (ot and rt) else 0.0
387
+ except Exception:
388
+ bleu1 = 0.0
389
  final_score = 0.5*lev + 0.3*word_overlap + 0.2*bleu1
390
  return {"levenshtein": round(lev,3), "word_overlap": round(word_overlap,3),
391
  "bleu1": round(bleu1,3), "literal_score": round(final_score,3)}
392
 
393
def semantic_similarity(original, recited, use_marbert=True):
    """Compute sentence-level semantic similarity between two texts.

    Uses SBERT cosine similarity always; MARBERT CLS similarity only when
    use_marbert is True (otherwise it is reported as 0.0). The final
    semantic_score is the maximum of the two, all rounded to 3 decimals.
    """
    emb_orig = _SBERT.encode(original, convert_to_tensor=True)
    emb_rec = _SBERT.encode(recited, convert_to_tensor=True)
    sbert_sim = float(util.pytorch_cos_sim(emb_orig, emb_rec))
    if use_marbert:
        marbert_sim = marbert_cls_similarity(original, recited)
    else:
        marbert_sim = 0.0
    return {"sbert_sim": round(sbert_sim, 3),
            "marbert_sim": round(marbert_sim, 3),
            "semantic_score": round(max(sbert_sim, marbert_sim), 3)}
399
 
400
  # =========================
401
+ # Audio input helper
402
  # =========================
 
 
403
  def ensure_audio_path(audio):
404
+ """Accepts filepath (str) OR (numpy_array, sr). Returns a valid filepath."""
 
 
 
 
 
405
  if isinstance(audio, str):
406
+ if not os.path.exists(audio):
407
+ raise FileNotFoundError(f"Audio path not found: {audio}")
408
  return audio
409
  if isinstance(audio, tuple) and len(audio) == 2:
410
  data, sr = audio
 
415
  raise ValueError("Unsupported audio input format")
416
 
417
  # =========================
418
+ # Pipeline (with robust error reporting)
419
  # =========================
420
def transcribe_and_evaluate(audio, original_text, whisper_size=None,
                            compute_type=None, vad=True, use_marbert=True):
    """Full pipeline: transcribe audio with Whisper, align against the
    ground-truth text, classify token-level differences, and score.

    Parameters
    ----------
    audio : str or tuple
        Filepath, or (numpy_array, sample_rate) — resolved by ensure_audio_path.
    original_text : str
        Ground-truth reference text; must be non-empty.
    whisper_size, compute_type : str or None
        Whisper model name and compute type; on CPU these are forced to the
        lightweight CPU defaults, on GPU they default to large-v3 / float16.
    vad : bool
        Enable faster-whisper's VAD filter.
    use_marbert : bool
        Include MARBERT in semantic scoring (disabled in CPU mode).

    Returns
    -------
    (corrected_text, asr_text, report_json, df) — on any error, returns
    empty strings, an error-report JSON, and a one-row ERROR DataFrame
    instead of raising, so the Gradio UI never crashes.
    """
    try:
        if not original_text or not original_text.strip():
            raise ValueError("Original text is empty.")

        # Defaults per device: CPU mode overrides the caller's choices to
        # protect memory; GPU mode only fills in missing values.
        if CPU_MODE:
            whisper_size = DEFAULT_WHISPER_CPU
            compute_type = DEFAULT_COMPUTE_CPU
            use_marbert = DEFAULT_USE_MARBERT_CPU
        else:
            whisper_size = whisper_size or "large-v3"
            compute_type = compute_type or "float16"

        # Lazy-load models (no-op if already loaded).
        load_models(whisper_name=whisper_size, whisper_compute=compute_type)

        # Transcribe with word timestamps so per-word probabilities are available.
        audio_path = ensure_audio_path(audio)
        segments, info = _WHISPER.transcribe(
            audio_path, word_timestamps=True,
            vad_filter=vad, vad_parameters={"min_silence_duration_ms": 200}
        )
        segments = list(segments)

        # Rebuild the ASR text from cleaned word tokens.
        words = []
        for seg in segments:
            for w in (seg.words or []):
                tok = clean_ar_token(w.word)
                if tok: words.append(tok)
        asr_text = " ".join(words)

        # Tokenize both sides and align.
        ref_tokens = simple_tokenize(original_text)
        hyp_tokens = simple_tokenize(asr_text)
        aligned = align_texts(ref_tokens, hyp_tokens)

        # Per-token confidence map plus low/high probability thresholds.
        df_words = extract_word_conf_table(segments)
        asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)

        # Classify each aligned pair and produce the corrected text.
        results, corrected_text = classify_alignment_optimized(
            aligned, ref_tokens, hyp_tokens,
            bert_thresh=0.75, max_bert=0.85,
            asr_token_conf=asr_token_conf, low_high=(low_t, high_t)
        )

        # Literal and semantic scores (MARBERT is skipped entirely on CPU).
        lit = literal_similarity(original_text, corrected_text)
        sem = semantic_similarity(original_text, corrected_text, use_marbert=(use_marbert and not CPU_MODE))

        df = pd.DataFrame(results)

        report = {
            "whisper_model": whisper_size,
            "compute_type": compute_type,
            "original_text": original_text,
            "asr_text": asr_text,
            "corrected_text": corrected_text,
            "literal": lit,
            "semantic": sem,
            "low_t": low_t, "high_t": high_t,
        }
        return corrected_text, asr_text, json.dumps(report, ensure_ascii=False, indent=2), df

    except Exception as e:
        tb = traceback.format_exc()
        print("ERROR in transcribe_and_evaluate:\n", tb, flush=True)
        # Return an error JSON instead of crashing the UI.
        empty_df = pd.DataFrame([{"ASR_word":"","GT_word":"","status":"ERROR","reason":str(e),"used":""}])
        err_json = json.dumps({"error": str(e), "traceback": tb}, ensure_ascii=False, indent=2)
        gr.Warning(str(e))
        return "", "", err_json, empty_df
489
+
490
def api_predict(audio, original_text, whisper_size=None, compute_type=None, vad=True, use_marbert=True):
    """JSON-only API wrapper around transcribe_and_evaluate.

    Runs the same pipeline but returns only the report, parsed back into a
    dict; on a parse failure an error dict is returned instead of raising.
    """
    outputs = transcribe_and_evaluate(
        audio, original_text, whisper_size, compute_type, vad, use_marbert
    )
    report_json = outputs[2]
    try:
        parsed = json.loads(report_json)
    except Exception:
        return {"error": "Failed to parse report_json."}
    return parsed
499
 
500
  # =========================
501
  # Gradio UI
 
503
  def build_ui():
504
  with gr.Blocks(title="Samaali ASR Post-Processing", theme=gr.themes.Soft()) as demo:
505
  gr.Markdown("## Samaali — ASR Post-Processing (Whisper + Alignment + Confidence + Semantics)")
506
+
507
  with gr.Row():
508
+ # filepath أسلم للـ Spaces
509
  audio = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio")
510
  original = gr.Textbox(lines=8, label="Original Text (Ground Truth)")
511
 
512
  with gr.Row():
513
  whisper_size = gr.Dropdown(
514
  choices=["tiny","base","small","medium","large-v3"],
515
+ value=("large-v3" if not CPU_MODE else DEFAULT_WHISPER_CPU),
516
  label="Whisper model size"
517
  )
518
  compute_type = gr.Dropdown(
519
  choices=["int8", "int8_float16", "float16", "float32"],
520
+ value=("float16" if not CPU_MODE else DEFAULT_COMPUTE_CPU),
521
  label="compute_type"
522
  )
523
  vad = gr.Checkbox(value=True, label="VAD filter")
524
+ use_marbert = gr.Checkbox(value=(not CPU_MODE), label="Use MARBERT (semantic)")
525
 
526
  btn = gr.Button("Transcribe & Evaluate", variant="primary")
527
 
 
532
  table = gr.Dataframe(headers=["ASR_word","GT_word","status","reason","used"],
533
  label="Token-level Decisions", wrap=True)
534
 
 
535
  btn.click(
536
  fn=transcribe_and_evaluate,
537
  inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
538
  outputs=[corrected, asr_out, report, table],
539
+ api_name="evaluate"
540
  )
541
 
 
542
  gr.Button(visible=False).click(
543
  fn=api_predict,
544
  inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
545
  outputs=gr.JSON(),
546
+ api_name="predict"
547
  )
548
 
549
  return demo