MuhammadHijazii committed on
Commit
5d519e9
·
verified ·
1 Parent(s): 4fffe95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -128
app.py CHANGED
@@ -14,19 +14,28 @@ import soundfile as sf
14
  # =========================
15
  # Global config (forced per your request)
16
  # =========================
17
- # نثبّت الإعدادات المطلوبة على CPU
18
  FORCE_WHISPER_NAME = "large-v3"
19
  FORCE_COMPUTE_TYPE = "int8"
20
  FORCE_USE_MARBERT = True
21
 
22
- # خيارات تفريغ ثابتة لتقليل الفروقات مع النوتبوك
 
 
 
 
 
 
 
 
 
 
23
  ASR_OPTS = dict(
24
  word_timestamps=True,
25
  vad_filter=True,
26
  vad_parameters={"min_silence_duration_ms": 200},
27
  beam_size=5,
28
  best_of=5,
29
- temperature=0.0, # جعل فك التشفير حتمي قدر الإمكان
30
  )
31
 
32
  # =========================
@@ -57,7 +66,6 @@ def load_models(
57
  _SBERT = SentenceTransformer(sbert_name, device=("cuda" if DEVICE=="cuda" else "cpu"))
58
  print(f"[LOAD] SBERT: {sbert_name}", flush=True)
59
 
60
- # مفعّل على CPU حسب رغبتك
61
  if _MARBERT is None and use_marbert:
62
  _MARBERT_TOK = AutoTokenizer.from_pretrained(marbert_name)
63
  _MARBERT = AutoModel.from_pretrained(marbert_name).to(("cuda" if DEVICE=="cuda" else "cpu"))
@@ -73,11 +81,23 @@ def load_models(
73
  # Normalization / Tokenization / Alignment
74
  # =========================
75
  def normalize_ar_orth(text: str) -> str:
 
76
  text = re.sub(r"[ًٌٍَُِّْـ]", "", text)
77
  text = re.sub(r"[“”\"',:؛؟.!()\[\]{}،\-–—_]", " ", text)
 
 
78
  text = re.sub(r"\s+", " ", text).strip()
79
  return text
80
 
 
 
 
 
 
 
 
 
 
81
  def simple_tokenize(text: str):
82
  t = normalize_ar_orth(text)
83
  try:
@@ -118,8 +138,7 @@ def arabic_soundex(word):
118
  for ch in w:
119
  for rep, chars in groups.items():
120
  if ch in chars:
121
- code.append(rep)
122
- break
123
  return "".join(code)
124
 
125
  def phonetic_similarity(w1, w2):
@@ -172,28 +191,60 @@ def to_numeric_value(token: str):
172
  return words_to_number(toks)
173
 
174
  # =========================
175
- # Semantic similarities
176
  # =========================
 
 
 
 
 
 
177
  def marbert_cls_similarity(a: str, b: str) -> float:
178
- if not a or not b: return 0.0
179
- if _MARBERT is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  return 0.0
 
181
  with torch.no_grad():
182
- ta = _MARBERT_TOK(a, return_tensors='pt', truncation=True, padding=True).to(("cuda" if DEVICE=="cuda" else "cpu"))
183
- tb = _MARBERT_TOK(b, return_tensors='pt', truncation=True, padding=True).to(("cuda" if DEVICE=="cuda" else "cpu"))
184
- ea = _MARBERT(**ta).last_hidden_state[:,0,:]
185
- eb = _MARBERT(**tb).last_hidden_state[:,0,:]
186
- sim = util.cos_sim(ea, eb).item()
187
- return (sim + 1) / 2
188
 
189
  def multi_bert_similarity(a: str, b: str):
190
  if not a or not b:
191
- return {"sbert":0.0, "marbert":0.0, "max":0.0, "avg":0.0}
192
- sbert_sim = float(util.pytorch_cos_sim(_SBERT.encode(a, convert_to_tensor=True),
193
- _SBERT.encode(b, convert_to_tensor=True)))
194
- marbert_sim = marbert_cls_similarity(a, b)
 
 
 
 
 
 
 
 
 
195
  vals = [sbert_sim, marbert_sim]
196
- return {"sbert": sbert_sim, "marbert": marbert_sim, "max": max(vals), "avg": sum(vals)/len(vals)}
 
197
 
198
  # =========================
199
  # Faster-Whisper helpers
@@ -269,38 +320,48 @@ def gate_by_word_conf(base_decision: str, prob: float, sbert_sim: float,
269
  return base_decision
270
 
271
  # =========================
272
- # Pair + main classifiers
273
  # =========================
274
  def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
275
  bert_thresh=0.75, max_bert=0.85):
 
276
  ref_num = to_numeric_value(ref_w)
277
  hyp_num = to_numeric_value(hyp_w)
278
  if (ref_num is not None) or (hyp_num is not None):
279
  if (ref_num is not None) and (hyp_num is not None) and (ref_num == hyp_num):
280
  return 'ASR error (numbers equal)'
 
 
281
  if short_word and lev1:
282
  return 'ASR error (short+lev1)'
283
- avg_ok = bert_scores["avg"] >= bert_thresh
284
- max_ok = bert_scores["max"] > max_bert
285
- if ((phon_sim or lev1) and avg_ok) or max_ok:
286
- return 'ASR error (semantic/phonetic)'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  return 'Memorization error'
288
 
289
  def classify_alignment_optimized(
290
  aligned, ref_tokens, hyp_tokens,
291
  bert_thresh=0.75, max_bert=0.85,
292
  asr_token_conf=None, low_high=None,
293
- replace_budget_tokens=None, # NEW: سقف الاستبدال (int أو None)
294
- guard_note=None # NEW: وسم حر (مثلاً: "off-topic" أو "ok")
295
  ):
296
- """
297
- مصنّف المحاذاة مع دعم 'سقف الاستبدال'.
298
- - إذا replace_budget_tokens=None → لا يوجد سقف.
299
- - إذا replace_budget_tokens=0 → لا يتم أي استبدال حتى لو كانت الحالة ASR error.
300
- - عند بلوغ السقف نحتفظ بكلمة الطالب ونضيف "[guard: budget reached]" على الحالة.
301
- - guard_note (اختياري) يُضاف للـ reason لتوثيق قرار الحارس العالمي.
302
- """
303
- # --- thresholds من احتمالات الكلمات ---
304
  if low_high is None:
305
  if asr_token_conf:
306
  probs = [v["prob"] for v in asr_token_conf.values() if v["prob"] is not None]
@@ -315,7 +376,7 @@ def classify_alignment_optimized(
315
  low_t, high_t = low_high
316
 
317
  results, corrected_words = [], []
318
- replaced_count = 0 # NEW: عدّاد الاستبدالات الفعلية
319
 
320
  for entry in aligned:
321
  tag = entry['type']
@@ -324,7 +385,7 @@ def classify_alignment_optimized(
324
 
325
  if tag == 'equal':
326
  for ref_w, hyp_w in zip(entry['ref'], entry['hyp']):
327
- results.append({'ASR_word': hyp_w, 'GT_word': ref_w, 'status': 'Correct', 'reason': ''})
328
  corrected_words.append(hyp_w)
329
 
330
  elif tag in ['replace', 'delete', 'insert']:
@@ -335,13 +396,13 @@ def classify_alignment_optimized(
335
  if not ref_w and not hyp_w:
336
  continue
337
 
338
- # --- similarities ---
339
  phon_sim = phonetic_similarity(ref_w, hyp_w) if ref_w and hyp_w else False
340
  lev1 = is_levenshtein_1(ref_w, hyp_w) if ref_w and hyp_w else False
341
  bert_scores = multi_bert_similarity(ref_w, hyp_w) if ref_w and hyp_w else {"sbert":0,"marbert":0,"max":0,"avg":0}
342
  short_word = bool(ref_w and hyp_w and max(len(ref_w), len(hyp_w)) <= 6)
343
 
344
- # --- base status ---
345
  if ref_w and hyp_w:
346
  base_status = classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
347
  bert_thresh, max_bert)
@@ -352,7 +413,7 @@ def classify_alignment_optimized(
352
  else:
353
  base_status = 'Undefined Case'
354
 
355
- # --- word-level confidence gate ---
356
  word_prob = None; word_dur = None
357
  if (j1 is not None) and (j2 is not None):
358
  hyp_abs_idx = j1 + k
@@ -370,30 +431,27 @@ def classify_alignment_optimized(
370
  low_t=low_t, high_t=high_t, sbert_lo=0.60
371
  )
372
 
373
- # --- choose token to use (with budget) ---
374
  used = hyp_w
375
  budget_info = ""
376
  if ref_w and hyp_w:
377
  if final_status.startswith("ASR error"):
378
- # نتحقق من السقف
379
  if (replace_budget_tokens is None) or (replaced_count < replace_budget_tokens):
380
  used = ref_w
381
  replaced_count += 1
382
  if replace_budget_tokens is not None:
383
  budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
384
  else:
385
- # تجاوز السقف → لا نستبدل
386
  used = hyp_w
387
  final_status += " [guard: budget reached]"
388
  budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
389
  else:
390
  used = hyp_w
391
  elif hyp_w == '':
392
- used = '' # حذف
393
  elif ref_w == '':
394
- used = hyp_w # إدراج
395
 
396
- # --- reason string ---
397
  reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
398
  f'SBERT={bert_scores["sbert"]:.2f}, '
399
  f'MARBERT={bert_scores["marbert"]:.2f}, '
@@ -403,6 +461,8 @@ def classify_alignment_optimized(
403
  f'dur_ms={None if word_dur is None else int(word_dur)}, '
404
  f'low_t={round(low_t,2)}, high_t={round(high_t,2)}')
405
 
 
 
406
  if guard_note:
407
  reason += f", guard='{guard_note}'"
408
  if budget_info:
@@ -416,62 +476,26 @@ def classify_alignment_optimized(
416
  corrected_words.append(used)
417
 
418
  corrected_text = " ".join([w for w in corrected_words if w])
419
- return results, corrected_text
420
-
421
- # =========================
422
- # Scores
423
- # =========================
424
- def literal_similarity(original, recited):
425
- def norm(t):
426
- t = re.sub(r'[ًٌٍَُِّْـ]', '', t)
427
- t = re.sub(r'[“”",:؛؟.!()\[\]{}،\-–—_]', ' ', t)
428
- t = re.sub(r'\s+', ' ', t).strip()
429
- return t
430
- o = norm(original); r = norm(recited)
431
- lev = textdistance.levenshtein.normalized_similarity(o, r)
432
- ot = simple_tokenize(o); rt = simple_tokenize(r)
433
- common = sum(1 for w1, w2 in zip(ot, rt) if w1 == w2)
434
- word_overlap = common / max(len(ot), 1)
435
- try:
436
- import nltk.translate.bleu_score as bleu
437
- bleu1 = bleu.sentence_bleu([ot], rt, weights=(1,0,0,0)) if (ot and rt) else 0.0
438
- except Exception:
439
- bleu1 = 0.0
440
- final_score = 0.5*lev + 0.3*word_overlap + 0.2*bleu1
441
- return {"levenshtein": round(lev,3), "word_overlap": round(word_overlap,3),
442
- "bleu1": round(bleu1,3), "literal_score": round(final_score,3)}
443
 
444
- def semantic_similarity(original, recited, use_marbert=FORCE_USE_MARBERT):
445
- sbert_sim = float(util.pytorch_cos_sim(_SBERT.encode(original, convert_to_tensor=True),
446
- _SBERT.encode(recited, convert_to_tensor=True)))
447
- marbert_sim = marbert_cls_similarity(original, recited) if use_marbert else 0.0
448
- return {"sbert_sim": round(sbert_sim,3), "marbert_sim": round(marbert_sim,3),
449
- "semantic_score": round(max(sbert_sim, marbert_sim),3)}
450
-
451
- # =========================
452
- # Audio helper
453
- # =========================
454
- def ensure_audio_path(audio):
455
- if isinstance(audio, str):
456
- if not os.path.exists(audio):
457
- raise FileNotFoundError(f"Audio path not found: {audio}")
458
- return audio
459
- if isinstance(audio, tuple) and len(audio) == 2:
460
- data, sr = audio
461
- if isinstance(data, np.ndarray):
462
- tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
463
- sf.write(tmp.name, data, sr)
464
- return tmp.name
465
- raise ValueError("Unsupported audio input format")
466
 
 
467
 
468
  # =========================
469
- #
470
  # =========================
471
-
472
-
473
  def lcs_len(a, b):
474
- """Longest Common Subsequence length على مستوى التوكنات."""
475
  m, n = len(a), len(b)
476
  dp = [[0]*(n+1) for _ in range(m+1)]
477
  for i in range(1, m+1):
@@ -484,7 +508,6 @@ def lcs_len(a, b):
484
  return dp[m][n]
485
 
486
  def rouge_l_f1_tokens(ref_tokens, hyp_tokens, beta=1.2):
487
- """تقريب ROUGE-L F1 على مستوى التوكنات."""
488
  if not ref_tokens or not hyp_tokens:
489
  return 0.0, 0.0, 0.0
490
  lcs = lcs_len(ref_tokens, hyp_tokens)
@@ -496,7 +519,6 @@ def rouge_l_f1_tokens(ref_tokens, hyp_tokens, beta=1.2):
496
  return float(f1), float(prec), float(rec)
497
 
498
  def compute_wer_like(aligned, ref_tokens_len):
499
- """WER مبسط من opcodes: (S+D+I)/N."""
500
  S = D = I = 0
501
  for op in aligned:
502
  if op['type'] == 'replace':
@@ -509,36 +531,18 @@ def compute_wer_like(aligned, ref_tokens_len):
509
  return (S + D + I) / N
510
 
511
  def global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, sbert_model):
512
- """
513
- يعيد dict يحوي:
514
- off_topic: bool
515
- budget_tokens: int (سقف الاستبدالات المسموح)
516
- metrics: كل المقاييس للتقرير
517
- """
518
- # SBERT للنص الكامل
519
  sbert_sim_text = float(util.pytorch_cos_sim(
520
- sbert_model.encode(original_text, convert_to_tensor=True),
521
- sbert_model.encode(asr_text, convert_to_tensor=True)
522
  ))
523
 
524
- # ROUGE-L(F1) و LCS بنسخة توكنات
525
  rouge_f1, rouge_p, rouge_r = rouge_l_f1_tokens(ref_tokens, hyp_tokens)
526
-
527
- # نسبة التطابق المباشر (equal) من المحاذاة
528
  equal_tokens = sum(len(op['ref']) for op in aligned if op['type'] == 'equal')
529
  equal_ratio = equal_tokens / max(len(ref_tokens), 1)
530
-
531
- # WER مبسّط
532
  wer = compute_wer_like(aligned, len(ref_tokens))
533
 
534
- # قاعدة قرار Off-topic (حذرين)
535
- # نعتبر خارج النص إذا: SBERT<0.70 و ROUGE_F1<0.45 و equal_ratio<0.25 أو WER>0.65
536
  off_topic = ((sbert_sim_text < 0.70 and rouge_f1 < 0.45 and equal_ratio < 0.25) or (wer > 0.65))
537
 
538
- # ميزانية الاستبدال (عدد الكلمات كحد أقصى يُسمح باستبدالها بـ GT)
539
- # - خارج النص: 0
540
- # - تشابه متوسط: 15% من طول Hyp
541
- # - تشابه مرتفع: 40% من طول Hyp
542
  L = len(hyp_tokens)
543
  if off_topic:
544
  budget = 0
@@ -558,18 +562,64 @@ def global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, align
558
  print(f"[GUARD] off_topic={off_topic}, budget={budget}, metrics={metrics}", flush=True)
559
  return {"off_topic": off_topic, "budget_tokens": budget, "metrics": metrics}
560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
 
562
  # =========================
563
- # Pipeline (robust errors + logs)
564
  # =========================
 
 
 
 
 
 
 
 
 
 
 
 
565
 
 
 
 
566
  def transcribe_and_evaluate(audio, original_text, whisper_size=None,
567
  compute_type=None, vad=True, use_marbert=True):
568
  try:
569
  if not original_text or not original_text.strip():
570
  raise ValueError("Original text is empty.")
571
 
572
- # ู†ูู‡ู…ู„ ุงุฎุชูŠุงุฑุงุช ุงู„ูˆุงุฌู‡ุฉ ูˆู†ูุฑุถ ุฅุนุฏุงุฏุงุชูƒ
573
  whisper_size = FORCE_WHISPER_NAME
574
  compute_type = FORCE_COMPUTE_TYPE
575
  use_marbert = FORCE_USE_MARBERT
@@ -585,7 +635,7 @@ def transcribe_and_evaluate(audio, original_text, whisper_size=None,
585
  segments = list(segments)
586
  print(f"[ASR] segments={len(segments)}", flush=True)
587
 
588
- # Build ASR text from words (more control)
589
  words = []
590
  for seg in segments:
591
  for w in (seg.words or []):
@@ -594,37 +644,66 @@ def transcribe_and_evaluate(audio, original_text, whisper_size=None,
594
  words.append(tok)
595
  asr_text = " ".join(words)
596
 
 
597
  ref_tokens = simple_tokenize(original_text)
598
  hyp_tokens = simple_tokenize(asr_text)
599
  aligned = align_texts(ref_tokens, hyp_tokens)
600
 
601
- # --- Global guard ---
602
  guard = global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, _SBERT)
603
  off_topic = guard["off_topic"]
604
- budget_tokens = guard["budget_tokens"]
605
  guard_metrics = guard["metrics"]
606
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607
  df_words = extract_word_conf_table(segments)
608
  asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)
609
  print(f"[CONF] low_t={low_t:.3f}, high_t={high_t:.3f}", flush=True)
610
 
611
- results, corrected_text = classify_alignment_optimized(
 
612
  aligned, ref_tokens, hyp_tokens,
613
  bert_thresh=0.75, max_bert=0.85,
614
  asr_token_conf=asr_token_conf, low_high=(low_t, high_t),
615
- replace_budget_tokens=budget_tokens, # ← عدد استبدالات أقصى
616
- guard_note=("off-topic" if off_topic else "ok")
617
  )
618
 
 
619
  lit = literal_similarity(original_text, corrected_text)
620
  sem = semantic_similarity(original_text, corrected_text, use_marbert=use_marbert)
621
 
 
 
 
 
 
 
 
 
 
622
  df = pd.DataFrame(results)
623
 
624
  report = {
625
  "requested": {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert},
626
  "effective": {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert},
627
- "guard": {"off_topic": off_topic,"budget_tokens": int(budget_tokens),**guard_metrics},
 
 
628
  "original_text": original_text,
629
  "asr_text": asr_text,
630
  "corrected_text": corrected_text,
@@ -663,7 +742,6 @@ def build_ui():
663
  original = gr.Textbox(lines=8, label="Original Text (Ground Truth)")
664
 
665
  with gr.Row():
666
- # واجهة ثابتة حسب طلبك (تُهمل في الدالة لكن نعرضها)
667
  whisper_size = gr.Dropdown(choices=["large-v3"], value="large-v3", label="Whisper model size (forced)")
668
  compute_type = gr.Dropdown(choices=["int8"], value="int8", label="compute_type (forced)")
669
  vad = gr.Checkbox(value=True, label="VAD filter")
 
14
  # =========================
15
  # Global config (forced per your request)
16
  # =========================
 
17
  FORCE_WHISPER_NAME = "large-v3"
18
  FORCE_COMPUTE_TYPE = "int8"
19
  FORCE_USE_MARBERT = True
20
 
21
+ # ======= Budget Config =======
22
+ # "auto": يعتمد على الحارس العالمي (SBERT/ROUGE/WER)
23
+ # "fixed": عدد ثابت من الاستبدالات (0 يعني عدم استبدال مطلقًا)
24
+ # "ratio": نسبة من طول النص المنطوق
25
+ # "off": بدون سقف (سلوك قديم)
26
+ FORCE_BUDGET_MODE = "auto" # "auto" | "fixed" | "ratio" | "off"
27
+ FIXED_BUDGET_TOKENS = 0
28
+ BUDGET_RATIO = 0.15
29
+ # =============================
30
+
31
+ # خيارات تفريغ ثابتة لتقليل الفروقات
32
  ASR_OPTS = dict(
33
  word_timestamps=True,
34
  vad_filter=True,
35
  vad_parameters={"min_silence_duration_ms": 200},
36
  beam_size=5,
37
  best_of=5,
38
+ temperature=0.0,
39
  )
40
 
41
  # =========================
 
66
  _SBERT = SentenceTransformer(sbert_name, device=("cuda" if DEVICE=="cuda" else "cpu"))
67
  print(f"[LOAD] SBERT: {sbert_name}", flush=True)
68
 
 
69
  if _MARBERT is None and use_marbert:
70
  _MARBERT_TOK = AutoTokenizer.from_pretrained(marbert_name)
71
  _MARBERT = AutoModel.from_pretrained(marbert_name).to(("cuda" if DEVICE=="cuda" else "cpu"))
 
81
  # Normalization / Tokenization / Alignment
82
  # =========================
83
  def normalize_ar_orth(text: str) -> str:
84
+ # تطبيع عام للمحاذاة
85
  text = re.sub(r"[ًٌٍَُِّْـ]", "", text)
86
  text = re.sub(r"[“”\"',:؛؟.!()\[\]{}،\-–—_]", " ", text)
87
+ text = re.sub(r"[إأٱآا]", "ا", text)
88
+ text = text.replace("ة", "ه").replace("ى", "ي")
89
  text = re.sub(r"\s+", " ", text).strip()
90
  return text
91
 
92
+ def _normalize_for_models(s: str) -> str:
93
+ # تطبيع خاص لمدخلات SBERT/MARBERT
94
+ s = re.sub(r"[ًٌٍَُِّْـ]", "", s)
95
+ s = re.sub(r"[“”\"',:؛؟.!()\[\]{}،\-–—_]", " ", s)
96
+ s = re.sub(r"[إأٱآا]", "ا", s)
97
+ s = s.replace("ة", "ه").replace("ى", "ي")
98
+ s = re.sub(r"\s+", " ", s).strip()
99
+ return s
100
+
101
  def simple_tokenize(text: str):
102
  t = normalize_ar_orth(text)
103
  try:
 
138
  for ch in w:
139
  for rep, chars in groups.items():
140
  if ch in chars:
141
+ code.append(rep); break
 
142
  return "".join(code)
143
 
144
  def phonetic_similarity(w1, w2):
 
191
  return words_to_number(toks)
192
 
193
  # =========================
194
+ # Semantic similarities (MARBERT fixed)
195
  # =========================
196
+ def _mean_pool(last_hidden_state, attention_mask):
197
+ mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
198
+ summed = (last_hidden_state * mask).sum(dim=1)
199
+ counts = mask.sum(dim=1).clamp(min=1e-9)
200
+ return summed / counts
201
+
202
  def marbert_cls_similarity(a: str, b: str) -> float:
203
+ """Return 0 when [UNK] dominates; use mean pooling instead of CLS only."""
204
+ if not a or not b or _MARBERT is None:
205
+ return 0.0
206
+
207
+ a_n = _normalize_for_models(a)
208
+ b_n = _normalize_for_models(b)
209
+
210
+ # UNK ratio check
211
+ ids_a = _MARBERT_TOK(a_n, add_special_tokens=False).input_ids
212
+ ids_b = _MARBERT_TOK(b_n, add_special_tokens=False).input_ids
213
+ unk_id = _MARBERT_TOK.unk_token_id
214
+ if len(ids_a) == 0 or len(ids_b) == 0:
215
+ return 0.0
216
+ unk_ratio_a = (ids_a.count(unk_id) / len(ids_a)) if unk_id is not None else 0.0
217
+ unk_ratio_b = (ids_b.count(unk_id) / len(ids_b)) if unk_id is not None else 0.0
218
+ if max(unk_ratio_a, unk_ratio_b) > 0.5:
219
+ # too many unknowns → ignore MARBERT
220
  return 0.0
221
+
222
  with torch.no_grad():
223
+ ta = _MARBERT_TOK(a_n, return_tensors='pt', truncation=True, padding=True).to(("cuda" if DEVICE=="cuda" else "cpu"))
224
+ tb = _MARBERT_TOK(b_n, return_tensors='pt', truncation=True, padding=True).to(("cuda" if DEVICE=="cuda" else "cpu"))
225
+ ea = _mean_pool(_MARBERT(**ta).last_hidden_state, ta["attention_mask"])
226
+ eb = _mean_pool(_MARBERT(**tb).last_hidden_state, tb["attention_mask"])
227
+ sim = util.cos_sim(ea, eb).item() # -1..1
228
+ return (sim + 1) / 2 # 0..1
229
 
230
  def multi_bert_similarity(a: str, b: str):
231
  if not a or not b:
232
+ return {"sbert":0.0, "marbert":0.0, "max":0.0, "avg":0.0, "note":"empty"}
233
+
234
+ a_n = _normalize_for_models(a); b_n = _normalize_for_models(b)
235
+ sbert_sim = float(util.pytorch_cos_sim(
236
+ _SBERT.encode(a_n, convert_to_tensor=True),
237
+ _SBERT.encode(b_n, convert_to_tensor=True)
238
+ ))
239
+ marbert_sim = marbert_cls_similarity(a_n, b_n)
240
+
241
+ note = None
242
+ if abs(sbert_sim - marbert_sim) > 0.35:
243
+ note = "models_disagree"
244
+
245
  vals = [sbert_sim, marbert_sim]
246
+ return {"sbert": sbert_sim, "marbert": marbert_sim,
247
+ "max": max(vals), "avg": sum(vals)/len(vals), "note": note}
248
 
249
  # =========================
250
  # Faster-Whisper helpers
 
320
  return base_decision
321
 
322
  # =========================
323
+ # Pair + main classifiers (tightened)
324
  # =========================
325
  def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
326
  bert_thresh=0.75, max_bert=0.85):
327
+ # numbers equal
328
  ref_num = to_numeric_value(ref_w)
329
  hyp_num = to_numeric_value(hyp_w)
330
  if (ref_num is not None) or (hyp_num is not None):
331
  if (ref_num is not None) and (hyp_num is not None) and (ref_num == hyp_num):
332
  return 'ASR error (numbers equal)'
333
+
334
+ # short+lev1
335
  if short_word and lev1:
336
  return 'ASR error (short+lev1)'
337
+
338
+ # semantic/phonetic
339
+ sbert_ok = bert_scores["sbert"] >= 0.70
340
+ avg_ok = bert_scores["avg"] >= bert_thresh
341
+ max_ok = (bert_scores["max"] > max_bert) and sbert_ok
342
+ disagree = (bert_scores.get("note") == "models_disagree")
343
+
344
+ if not disagree:
345
+ if ((phon_sim or lev1) and avg_ok) or max_ok:
346
+ return 'ASR error (semantic/phonetic)'
347
+ else:
348
+ if phon_sim or lev1:
349
+ if sbert_ok and avg_ok:
350
+ return 'ASR error (semantic/phonetic)'
351
+ else:
352
+ if bert_scores["sbert"] >= 0.78:
353
+ return 'ASR error (semantic)'
354
+
355
  return 'Memorization error'
356
 
357
  def classify_alignment_optimized(
358
  aligned, ref_tokens, hyp_tokens,
359
  bert_thresh=0.75, max_bert=0.85,
360
  asr_token_conf=None, low_high=None,
361
+ replace_budget_tokens=None, # سقف الاستبدال
362
+ guard_note=None # وسم مثل "off-topic"/"ok"/"budget_off"
363
  ):
364
+ # thresholds من احتمالات الكلمات
 
 
 
 
 
 
 
365
  if low_high is None:
366
  if asr_token_conf:
367
  probs = [v["prob"] for v in asr_token_conf.values() if v["prob"] is not None]
 
376
  low_t, high_t = low_high
377
 
378
  results, corrected_words = [], []
379
+ replaced_count = 0
380
 
381
  for entry in aligned:
382
  tag = entry['type']
 
385
 
386
  if tag == 'equal':
387
  for ref_w, hyp_w in zip(entry['ref'], entry['hyp']):
388
+ results.append({'ASR_word': hyp_w, 'GT_word': ref_w, 'status': 'Correct', 'reason': '', 'used': hyp_w})
389
  corrected_words.append(hyp_w)
390
 
391
  elif tag in ['replace', 'delete', 'insert']:
 
396
  if not ref_w and not hyp_w:
397
  continue
398
 
399
+ # similarities
400
  phon_sim = phonetic_similarity(ref_w, hyp_w) if ref_w and hyp_w else False
401
  lev1 = is_levenshtein_1(ref_w, hyp_w) if ref_w and hyp_w else False
402
  bert_scores = multi_bert_similarity(ref_w, hyp_w) if ref_w and hyp_w else {"sbert":0,"marbert":0,"max":0,"avg":0}
403
  short_word = bool(ref_w and hyp_w and max(len(ref_w), len(hyp_w)) <= 6)
404
 
405
+ # base status
406
  if ref_w and hyp_w:
407
  base_status = classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
408
  bert_thresh, max_bert)
 
413
  else:
414
  base_status = 'Undefined Case'
415
 
416
+ # word-level confidence gate
417
  word_prob = None; word_dur = None
418
  if (j1 is not None) and (j2 is not None):
419
  hyp_abs_idx = j1 + k
 
431
  low_t=low_t, high_t=high_t, sbert_lo=0.60
432
  )
433
 
434
+ # choose token with budget
435
  used = hyp_w
436
  budget_info = ""
437
  if ref_w and hyp_w:
438
  if final_status.startswith("ASR error"):
 
439
  if (replace_budget_tokens is None) or (replaced_count < replace_budget_tokens):
440
  used = ref_w
441
  replaced_count += 1
442
  if replace_budget_tokens is not None:
443
  budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
444
  else:
 
445
  used = hyp_w
446
  final_status += " [guard: budget reached]"
447
  budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
448
  else:
449
  used = hyp_w
450
  elif hyp_w == '':
451
+ used = ''
452
  elif ref_w == '':
453
+ used = hyp_w
454
 
 
455
  reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
456
  f'SBERT={bert_scores["sbert"]:.2f}, '
457
  f'MARBERT={bert_scores["marbert"]:.2f}, '
 
461
  f'dur_ms={None if word_dur is None else int(word_dur)}, '
462
  f'low_t={round(low_t,2)}, high_t={round(high_t,2)}')
463
 
464
+ if bert_scores.get("note"):
465
+ reason += f", note={bert_scores['note']}"
466
  if guard_note:
467
  reason += f", guard='{guard_note}'"
468
  if budget_info:
 
476
  corrected_words.append(used)
477
 
478
  corrected_text = " ".join([w for w in corrected_words if w])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
 
480
+ # إحصاءات محلية مفيدة للتقرير
481
+ stats = {
482
+ "replacements_made": sum(1 for r in results
483
+ if r.get("used") and r.get("GT_word") and r["used"] == r["GT_word"]
484
+ and r.get("ASR_word") and r["ASR_word"] != r["GT_word"]),
485
+ "budget_reached_count": sum(1 for r in results if isinstance(r.get("status"), str) and "budget reached" in r["status"]),
486
+ "asr_error_count": sum(1 for r in results if isinstance(r.get("status"), str) and r["status"].startswith("ASR error")),
487
+ "memorization_error_count": sum(1 for r in results if r.get("status") == "Memorization error"),
488
+ "missing_count": sum(1 for r in results if r.get("status","").startswith("Missing")),
489
+ "extra_count": sum(1 for r in results if r.get("status","").startswith("Extra")),
490
+ "total_tokens": len(results)
491
+ }
 
 
 
 
 
 
 
 
 
 
492
 
493
+ return results, corrected_text, stats
494
 
495
  # =========================
496
+ # ROUGE-L / WER-like / Guard
497
  # =========================
 
 
498
  def lcs_len(a, b):
 
499
  m, n = len(a), len(b)
500
  dp = [[0]*(n+1) for _ in range(m+1)]
501
  for i in range(1, m+1):
 
508
  return dp[m][n]
509
 
510
  def rouge_l_f1_tokens(ref_tokens, hyp_tokens, beta=1.2):
 
511
  if not ref_tokens or not hyp_tokens:
512
  return 0.0, 0.0, 0.0
513
  lcs = lcs_len(ref_tokens, hyp_tokens)
 
519
  return float(f1), float(prec), float(rec)
520
 
521
  def compute_wer_like(aligned, ref_tokens_len):
 
522
  S = D = I = 0
523
  for op in aligned:
524
  if op['type'] == 'replace':
 
531
  return (S + D + I) / N
532
 
533
  def global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, sbert_model):
 
 
 
 
 
 
 
534
  sbert_sim_text = float(util.pytorch_cos_sim(
535
+ sbert_model.encode(_normalize_for_models(original_text), convert_to_tensor=True),
536
+ sbert_model.encode(_normalize_for_models(asr_text), convert_to_tensor=True)
537
  ))
538
 
 
539
  rouge_f1, rouge_p, rouge_r = rouge_l_f1_tokens(ref_tokens, hyp_tokens)
 
 
540
  equal_tokens = sum(len(op['ref']) for op in aligned if op['type'] == 'equal')
541
  equal_ratio = equal_tokens / max(len(ref_tokens), 1)
 
 
542
  wer = compute_wer_like(aligned, len(ref_tokens))
543
 
 
 
544
  off_topic = ((sbert_sim_text < 0.70 and rouge_f1 < 0.45 and equal_ratio < 0.25) or (wer > 0.65))
545
 
 
 
 
 
546
  L = len(hyp_tokens)
547
  if off_topic:
548
  budget = 0
 
562
  print(f"[GUARD] off_topic={off_topic}, budget={budget}, metrics={metrics}", flush=True)
563
  return {"off_topic": off_topic, "budget_tokens": budget, "metrics": metrics}
564
 
565
+ # =========================
566
+ # Scores
567
+ # =========================
568
+ def literal_similarity(original, recited):
569
+ def norm(t):
570
+ t = re.sub(r'[ًٌٍَُِّْـ]', '', t)
571
+ t = re.sub(r'[“”",:؛؟.!()\[\]{}،\-–—_]', ' ', t)
572
+ t = re.sub(r'\s+', ' ', t).strip()
573
+ return t
574
+ o = norm(original); r = norm(recited)
575
+ lev = textdistance.levenshtein.normalized_similarity(o, r)
576
+ ot = simple_tokenize(o); rt = simple_tokenize(r)
577
+ common = sum(1 for w1, w2 in zip(ot, rt) if w1 == w2)
578
+ word_overlap = common / max(len(ot), 1)
579
+ try:
580
+ import nltk.translate.bleu_score as bleu
581
+ bleu1 = bleu.sentence_bleu([ot], rt, weights=(1,0,0,0)) if (ot and rt) else 0.0
582
+ except Exception:
583
+ bleu1 = 0.0
584
+ final_score = 0.5*lev + 0.3*word_overlap + 0.2*bleu1
585
+ return {"levenshtein": round(lev,3), "word_overlap": round(word_overlap,3),
586
+ "bleu1": round(bleu1,3), "literal_score": round(final_score,3)}
587
+
588
+ def semantic_similarity(original, recited, use_marbert=FORCE_USE_MARBERT):
589
+ sbert_sim = float(util.pytorch_cos_sim(
590
+ _SBERT.encode(_normalize_for_models(original), convert_to_tensor=True),
591
+ _SBERT.encode(_normalize_for_models(recited), convert_to_tensor=True)
592
+ ))
593
+ marbert_sim = marbert_cls_similarity(original, recited) if use_marbert else 0.0
594
+ return {"sbert_sim": round(sbert_sim,3), "marbert_sim": round(marbert_sim,3),
595
+ "semantic_score": round(max(sbert_sim, marbert_sim),3)}
596
 
597
  # =========================
598
+ # Audio helper
599
  # =========================
600
+ def ensure_audio_path(audio):
601
+ if isinstance(audio, str):
602
+ if not os.path.exists(audio):
603
+ raise FileNotFoundError(f"Audio path not found: {audio}")
604
+ return audio
605
+ if isinstance(audio, tuple) and len(audio) == 2:
606
+ data, sr = audio
607
+ if isinstance(data, np.ndarray):
608
+ tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
609
+ sf.write(tmp.name, data, sr)
610
+ return tmp.name
611
+ raise ValueError("Unsupported audio input format")
612
 
613
+ # =========================
614
+ # Pipeline (robust errors + logs)
615
+ # =========================
616
  def transcribe_and_evaluate(audio, original_text, whisper_size=None,
617
  compute_type=None, vad=True, use_marbert=True):
618
  try:
619
  if not original_text or not original_text.strip():
620
  raise ValueError("Original text is empty.")
621
 
622
+ # Forced settings
623
  whisper_size = FORCE_WHISPER_NAME
624
  compute_type = FORCE_COMPUTE_TYPE
625
  use_marbert = FORCE_USE_MARBERT
 
635
  segments = list(segments)
636
  print(f"[ASR] segments={len(segments)}", flush=True)
637
 
638
+ # Build ASR text from words
639
  words = []
640
  for seg in segments:
641
  for w in (seg.words or []):
 
644
  words.append(tok)
645
  asr_text = " ".join(words)
646
 
647
+ # Tokens & alignment
648
  ref_tokens = simple_tokenize(original_text)
649
  hyp_tokens = simple_tokenize(asr_text)
650
  aligned = align_texts(ref_tokens, hyp_tokens)
651
 
652
+ # Guard & budget
653
  guard = global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, _SBERT)
654
  off_topic = guard["off_topic"]
 
655
  guard_metrics = guard["metrics"]
656
 
657
+ if FORCE_BUDGET_MODE == "off":
658
+ budget_tokens = None
659
+ guard_note = "budget_off"
660
+ elif FORCE_BUDGET_MODE == "fixed":
661
+ budget_tokens = int(FIXED_BUDGET_TOKENS)
662
+ guard_note = f"budget_fixed_{budget_tokens}"
663
+ elif FORCE_BUDGET_MODE == "ratio":
664
+ budget_tokens = int(BUDGET_RATIO * len(hyp_tokens))
665
+ guard_note = f"budget_ratio_{BUDGET_RATIO}"
666
+ else:
667
+ budget_tokens = guard["budget_tokens"]
668
+ guard_note = "off-topic" if off_topic else "ok"
669
+
670
+ print(f"[BUDGET] mode={FORCE_BUDGET_MODE}, budget={budget_tokens}, note={guard_note}", flush=True)
671
+
672
+ # Word-level confidences
673
  df_words = extract_word_conf_table(segments)
674
  asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)
675
  print(f"[CONF] low_t={low_t:.3f}, high_t={high_t:.3f}", flush=True)
676
 
677
+ # Classification
678
+ results, corrected_text, local_stats = classify_alignment_optimized(
679
  aligned, ref_tokens, hyp_tokens,
680
  bert_thresh=0.75, max_bert=0.85,
681
  asr_token_conf=asr_token_conf, low_high=(low_t, high_t),
682
+ replace_budget_tokens=budget_tokens,
683
+ guard_note=guard_note
684
  )
685
 
686
+ # Scores
687
  lit = literal_similarity(original_text, corrected_text)
688
  sem = semantic_similarity(original_text, corrected_text, use_marbert=use_marbert)
689
 
690
+ # Extra global metrics for report
691
+ all_probs = df_words["prob"].dropna().tolist()
692
+ conf_summary = {
693
+ "num_words_with_prob": int(len(all_probs)),
694
+ "avg_prob": None if not all_probs else float(np.mean(all_probs)),
695
+ "p15": None if not all_probs else float(np.quantile(all_probs, 0.15)),
696
+ "p70": None if not all_probs else float(np.quantile(all_probs, 0.70)),
697
+ }
698
+
699
  df = pd.DataFrame(results)
700
 
701
  report = {
702
  "requested": {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert},
703
  "effective": {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert},
704
+ "guard": {"mode": FORCE_BUDGET_MODE, "off_topic": off_topic, "budget_tokens": None if budget_tokens is None else int(budget_tokens), **guard_metrics},
705
+ "local_stats": local_stats,
706
+ "confidence_summary": conf_summary,
707
  "original_text": original_text,
708
  "asr_text": asr_text,
709
  "corrected_text": corrected_text,
 
742
  original = gr.Textbox(lines=8, label="Original Text (Ground Truth)")
743
 
744
  with gr.Row():
 
745
  whisper_size = gr.Dropdown(choices=["large-v3"], value="large-v3", label="Whisper model size (forced)")
746
  compute_type = gr.Dropdown(choices=["int8"], value="int8", label="compute_type (forced)")
747
  vad = gr.Checkbox(value=True, label="VAD filter")