MuhammadHijazii committed on
Commit
4fffe95
·
verified ·
1 Parent(s): e4b6c5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -12
app.py CHANGED
@@ -286,9 +286,21 @@ def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
286
  return 'ASR error (semantic/phonetic)'
287
  return 'Memorization error'
288
 
289
- def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
290
- bert_thresh=0.75, max_bert=0.85,
291
- asr_token_conf=None, low_high=None):
 
 
 
 
 
 
 
 
 
 
 
 
292
  if low_high is None:
293
  if asr_token_conf:
294
  probs = [v["prob"] for v in asr_token_conf.values() if v["prob"] is not None]
@@ -303,29 +315,33 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
303
  low_t, high_t = low_high
304
 
305
  results, corrected_words = [], []
 
306
 
307
  for entry in aligned:
308
  tag = entry['type']
309
- i1, i2 = entry.get('ref_idx', (None,None))
310
- j1, j2 = entry.get('hyp_idx', (None,None))
311
 
312
  if tag == 'equal':
313
  for ref_w, hyp_w in zip(entry['ref'], entry['hyp']):
314
  results.append({'ASR_word': hyp_w, 'GT_word': ref_w, 'status': 'Correct', 'reason': ''})
315
  corrected_words.append(hyp_w)
 
316
  elif tag in ['replace', 'delete', 'insert']:
317
  max_len = max(len(entry['ref']), len(entry['hyp']))
318
  for k in range(max_len):
319
  ref_w = entry['ref'][k] if k < len(entry['ref']) else ''
320
  hyp_w = entry['hyp'][k] if k < len(entry['hyp']) else ''
321
- if not ref_w and not hyp_w:
322
  continue
323
 
 
324
  phon_sim = phonetic_similarity(ref_w, hyp_w) if ref_w and hyp_w else False
325
  lev1 = is_levenshtein_1(ref_w, hyp_w) if ref_w and hyp_w else False
326
  bert_scores = multi_bert_similarity(ref_w, hyp_w) if ref_w and hyp_w else {"sbert":0,"marbert":0,"max":0,"avg":0}
327
  short_word = bool(ref_w and hyp_w and max(len(ref_w), len(hyp_w)) <= 6)
328
 
 
329
  if ref_w and hyp_w:
330
  base_status = classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
331
  bert_thresh, max_bert)
@@ -336,6 +352,7 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
336
  else:
337
  base_status = 'Undefined Case'
338
 
 
339
  word_prob = None; word_dur = None
340
  if (j1 is not None) and (j2 is not None):
341
  hyp_abs_idx = j1 + k
@@ -353,14 +370,30 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
353
  low_t=low_t, high_t=high_t, sbert_lo=0.60
354
  )
355
 
 
356
  used = hyp_w
 
357
  if ref_w and hyp_w:
358
- used = ref_w if final_status.startswith("ASR error") else hyp_w
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  elif hyp_w == '':
360
- used = ''
361
  elif ref_w == '':
362
- used = hyp_w
363
 
 
364
  reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
365
  f'SBERT={bert_scores["sbert"]:.2f}, '
366
  f'MARBERT={bert_scores["marbert"]:.2f}, '
@@ -370,8 +403,15 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
370
  f'dur_ms={None if word_dur is None else int(word_dur)}, '
371
  f'low_t={round(low_t,2)}, high_t={round(high_t,2)}')
372
 
373
- results.append({'ASR_word': hyp_w, 'GT_word': ref_w,
374
- 'status': final_status, 'reason': reason, 'used': used})
 
 
 
 
 
 
 
375
  if used:
376
  corrected_words.append(used)
377
 
@@ -424,9 +464,105 @@ def ensure_audio_path(audio):
424
  return tmp.name
425
  raise ValueError("Unsupported audio input format")
426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  # =========================
428
  # Pipeline (robust errors + logs)
429
  # =========================
 
430
  def transcribe_and_evaluate(audio, original_text, whisper_size=None,
431
  compute_type=None, vad=True, use_marbert=True):
432
  try:
@@ -462,6 +598,12 @@ def transcribe_and_evaluate(audio, original_text, whisper_size=None,
462
  hyp_tokens = simple_tokenize(asr_text)
463
  aligned = align_texts(ref_tokens, hyp_tokens)
464
 
 
 
 
 
 
 
465
  df_words = extract_word_conf_table(segments)
466
  asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)
467
  print(f"[CONF] low_t={low_t:.3f}, high_t={high_t:.3f}", flush=True)
@@ -469,7 +611,9 @@ def transcribe_and_evaluate(audio, original_text, whisper_size=None,
469
  results, corrected_text = classify_alignment_optimized(
470
  aligned, ref_tokens, hyp_tokens,
471
  bert_thresh=0.75, max_bert=0.85,
472
- asr_token_conf=asr_token_conf, low_high=(low_t, high_t)
 
 
473
  )
474
 
475
  lit = literal_similarity(original_text, corrected_text)
@@ -480,6 +624,7 @@ def transcribe_and_evaluate(audio, original_text, whisper_size=None,
480
  report = {
481
  "requested": {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert},
482
  "effective": {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert},
 
483
  "original_text": original_text,
484
  "asr_text": asr_text,
485
  "corrected_text": corrected_text,
 
286
  return 'ASR error (semantic/phonetic)'
287
  return 'Memorization error'
288
 
289
+ def classify_alignment_optimized(
290
+ aligned, ref_tokens, hyp_tokens,
291
+ bert_thresh=0.75, max_bert=0.85,
292
+ asr_token_conf=None, low_high=None,
293
+ replace_budget_tokens=None, # NEW: سقف الاستبدال (int أو None)
294
+ guard_note=None # NEW: وسم حر (مثلاً: "off-topic" أو "ok")
295
+ ):
296
+ """
297
+ مصنّف المحاذاة مع دعم 'سقف الاستبدال'.
298
+ - إذا replace_budget_tokens=None → لا يوجد سقف.
299
+ - إذا replace_budget_tokens=0 → لا يتم أي استبدال حتى لو كانت الحالة ASR error.
300
+ - عند بلوغ السقف نحتفظ بكلمة الطالب ونضيف "[guard: budget reached]" على الحالة.
301
+ - guard_note (اختياري) يُضاف للـ reason لتوثيق قرار الحارس العالمي.
302
+ """
303
+ # --- thresholds من احتمالات الكلمات ---
304
  if low_high is None:
305
  if asr_token_conf:
306
  probs = [v["prob"] for v in asr_token_conf.values() if v["prob"] is not None]
 
315
  low_t, high_t = low_high
316
 
317
  results, corrected_words = [], []
318
+ replaced_count = 0 # NEW: عدّاد الاستبدالات الفعلية
319
 
320
  for entry in aligned:
321
  tag = entry['type']
322
+ i1, i2 = entry.get('ref_idx', (None, None))
323
+ j1, j2 = entry.get('hyp_idx', (None, None))
324
 
325
  if tag == 'equal':
326
  for ref_w, hyp_w in zip(entry['ref'], entry['hyp']):
327
  results.append({'ASR_word': hyp_w, 'GT_word': ref_w, 'status': 'Correct', 'reason': ''})
328
  corrected_words.append(hyp_w)
329
+
330
  elif tag in ['replace', 'delete', 'insert']:
331
  max_len = max(len(entry['ref']), len(entry['hyp']))
332
  for k in range(max_len):
333
  ref_w = entry['ref'][k] if k < len(entry['ref']) else ''
334
  hyp_w = entry['hyp'][k] if k < len(entry['hyp']) else ''
335
+ if not ref_w and not hyp_w:
336
  continue
337
 
338
+ # --- similarities ---
339
  phon_sim = phonetic_similarity(ref_w, hyp_w) if ref_w and hyp_w else False
340
  lev1 = is_levenshtein_1(ref_w, hyp_w) if ref_w and hyp_w else False
341
  bert_scores = multi_bert_similarity(ref_w, hyp_w) if ref_w and hyp_w else {"sbert":0,"marbert":0,"max":0,"avg":0}
342
  short_word = bool(ref_w and hyp_w and max(len(ref_w), len(hyp_w)) <= 6)
343
 
344
+ # --- base status ---
345
  if ref_w and hyp_w:
346
  base_status = classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
347
  bert_thresh, max_bert)
 
352
  else:
353
  base_status = 'Undefined Case'
354
 
355
+ # --- word-level confidence gate ---
356
  word_prob = None; word_dur = None
357
  if (j1 is not None) and (j2 is not None):
358
  hyp_abs_idx = j1 + k
 
370
  low_t=low_t, high_t=high_t, sbert_lo=0.60
371
  )
372
 
373
+ # --- choose token to use (with budget) ---
374
  used = hyp_w
375
+ budget_info = ""
376
  if ref_w and hyp_w:
377
+ if final_status.startswith("ASR error"):
378
+ # نتحقق من السقف
379
+ if (replace_budget_tokens is None) or (replaced_count < replace_budget_tokens):
380
+ used = ref_w
381
+ replaced_count += 1
382
+ if replace_budget_tokens is not None:
383
+ budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
384
+ else:
385
+ # تجاوز السقف → لا نستبدل
386
+ used = hyp_w
387
+ final_status += " [guard: budget reached]"
388
+ budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
389
+ else:
390
+ used = hyp_w
391
  elif hyp_w == '':
392
+ used = '' # حذف
393
  elif ref_w == '':
394
+ used = hyp_w # إدراج
395
 
396
+ # --- reason string ---
397
  reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
398
  f'SBERT={bert_scores["sbert"]:.2f}, '
399
  f'MARBERT={bert_scores["marbert"]:.2f}, '
 
403
  f'dur_ms={None if word_dur is None else int(word_dur)}, '
404
  f'low_t={round(low_t,2)}, high_t={round(high_t,2)}')
405
 
406
+ if guard_note:
407
+ reason += f", guard='{guard_note}'"
408
+ if budget_info:
409
+ reason += budget_info
410
+
411
+ results.append({
412
+ 'ASR_word': hyp_w, 'GT_word': ref_w,
413
+ 'status': final_status, 'reason': reason, 'used': used
414
+ })
415
  if used:
416
  corrected_words.append(used)
417
 
 
464
  return tmp.name
465
  raise ValueError("Unsupported audio input format")
466
 
467
+
468
+ # =========================
469
+ #
470
+ # =========================
471
+
472
+
473
def lcs_len(a, b):
    """Return the length of the Longest Common Subsequence at token level.

    Args:
        a: first token sequence (list/tuple of comparable items).
        b: second token sequence.

    Returns:
        int: LCS length; 0 when either sequence is empty.
    """
    m, n = len(a), len(b)
    if m == 0 or n == 0:
        return 0
    # Rolling two-row DP: same O(m*n) time as the full table, but only
    # O(n) memory — the full (m+1)x(n+1) table is never needed because
    # each row depends only on the previous one.
    prev = [0] * (n + 1)
    for i in range(1, m + 1):
        curr = [0] * (n + 1)
        ai = a[i - 1]
        for j in range(1, n + 1):
            if ai == b[j - 1]:
                curr[j] = prev[j - 1] + 1
            else:
                curr[j] = prev[j] if prev[j] >= curr[j - 1] else curr[j - 1]
        prev = curr
    return prev[n]
485
+
486
def rouge_l_f1_tokens(ref_tokens, hyp_tokens, beta=1.2):
    """Token-level approximation of ROUGE-L.

    Returns a ``(f1, precision, recall)`` tuple. All three are 0.0 when
    either token list is empty or the sequences share no common
    subsequence.
    """
    if not ref_tokens or not hyp_tokens:
        return 0.0, 0.0, 0.0
    lcs = lcs_len(ref_tokens, hyp_tokens)
    precision = lcs / len(hyp_tokens)
    recall = lcs / len(ref_tokens)
    if not (precision or recall):
        return 0.0, 0.0, 0.0
    # Weighted harmonic mean; the epsilon keeps the division safe.
    b2 = beta ** 2
    f1 = ((1 + b2) * precision * recall) / (recall + b2 * precision + 1e-12)
    return float(f1), float(precision), float(recall)
497
+
498
def compute_wer_like(aligned, ref_tokens_len):
    """Simplified WER from alignment opcodes: (S + D + I) / N."""
    errors = 0
    for op in aligned:
        kind = op['type']
        if kind == 'replace':
            # A replace span of unequal length counts as substitutions
            # plus the leftover insertions/deletions, i.e. max(ref, hyp).
            errors += max(len(op['ref']), len(op['hyp']))
        elif kind == 'delete':
            errors += len(op['ref'])
        elif kind == 'insert':
            errors += len(op['hyp'])
    # Guard against division by zero on an empty reference.
    return errors / max(ref_tokens_len, 1)
510
+
511
def global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, sbert_model):
    """Global off-topic detector and replacement-budget policy.

    Returns a dict with:
        off_topic:     bool -- True when the hypothesis looks unrelated to the text
        budget_tokens: int  -- max number of tokens allowed to be replaced by GT
        metrics:       dict -- all scores, for inclusion in the report
    """
    # Whole-text SBERT cosine similarity.
    ref_emb = sbert_model.encode(original_text, convert_to_tensor=True)
    hyp_emb = sbert_model.encode(asr_text, convert_to_tensor=True)
    sbert_sim_text = float(util.pytorch_cos_sim(ref_emb, hyp_emb))

    # Token-level ROUGE-L (F1/precision/recall).
    rouge_f1, rouge_p, rouge_r = rouge_l_f1_tokens(ref_tokens, hyp_tokens)

    # Share of reference tokens matched exactly ('equal') by the alignment.
    equal_tokens = sum(len(op['ref']) for op in aligned if op['type'] == 'equal')
    equal_ratio = equal_tokens / max(len(ref_tokens), 1)

    # Simplified WER over the same alignment.
    wer = compute_wer_like(aligned, len(ref_tokens))

    # Conservative off-topic rule: declare off-topic when
    # SBERT < 0.70 AND ROUGE_F1 < 0.45 AND equal_ratio < 0.25, or WER > 0.65.
    off_topic = (sbert_sim_text < 0.70 and rouge_f1 < 0.45 and equal_ratio < 0.25) or (wer > 0.65)

    # Replacement budget (upper bound on hyp tokens that may be swapped for GT):
    #   off-topic          -> 0
    #   medium similarity  -> 15% of hyp length
    #   high similarity    -> 40% of hyp length
    hyp_len = len(hyp_tokens)
    if off_topic:
        budget = 0
    elif sbert_sim_text < 0.80 or rouge_f1 < 0.55:
        budget = int(0.15 * hyp_len)
    else:
        budget = int(0.40 * hyp_len)

    metrics = {
        "sbert_sim_text": round(sbert_sim_text, 3),
        "rougeL_f1": round(rouge_f1, 3),
        "rougeL_prec": round(rouge_p, 3),
        "rougeL_rec": round(rouge_r, 3),
        "equal_ratio": round(equal_ratio, 3),
        "wer_like": round(wer, 3),
    }
    print(f"[GUARD] off_topic={off_topic}, budget={budget}, metrics={metrics}", flush=True)
    return {"off_topic": off_topic, "budget_tokens": budget, "metrics": metrics}
560
+
561
+
562
  # =========================
563
  # Pipeline (robust errors + logs)
564
  # =========================
565
+
566
  def transcribe_and_evaluate(audio, original_text, whisper_size=None,
567
  compute_type=None, vad=True, use_marbert=True):
568
  try:
 
598
  hyp_tokens = simple_tokenize(asr_text)
599
  aligned = align_texts(ref_tokens, hyp_tokens)
600
 
601
+ # --- Global guard ---
602
+ guard = global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, _SBERT)
603
+ off_topic = guard["off_topic"]
604
+ budget_tokens = guard["budget_tokens"]
605
+ guard_metrics = guard["metrics"]
606
+
607
  df_words = extract_word_conf_table(segments)
608
  asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)
609
  print(f"[CONF] low_t={low_t:.3f}, high_t={high_t:.3f}", flush=True)
 
611
  results, corrected_text = classify_alignment_optimized(
612
  aligned, ref_tokens, hyp_tokens,
613
  bert_thresh=0.75, max_bert=0.85,
614
+ asr_token_conf=asr_token_conf, low_high=(low_t, high_t),
615
+ replace_budget_tokens=budget_tokens, # ← عدد استبدالات أقصى
616
+ guard_note=("off-topic" if off_topic else "ok")
617
  )
618
 
619
  lit = literal_similarity(original_text, corrected_text)
 
624
  report = {
625
  "requested": {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert},
626
  "effective": {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert},
627
+ "guard": {"off_topic": off_topic,"budget_tokens": int(budget_tokens),**guard_metrics},
628
  "original_text": original_text,
629
  "asr_text": asr_text,
630
  "corrected_text": corrected_text,