"""Samaali — Arabic ASR post-processing.

Pipeline: faster-whisper transcription -> orthographic normalization ->
token alignment against a ground-truth text -> per-token classification
(ASR error vs. memorization error) using phonetic, edit-distance and
semantic (SBERT/MARBERT) signals, gated by word-level ASR confidence and
a global off-topic guard that budgets how many tokens may be restored
from the reference. Exposed through a Gradio UI and an API endpoint.
"""

import os, re, json, math, tempfile, traceback
import numpy as np
import pandas as pd
import torch
import textdistance
import gradio as gr
from faster_whisper import WhisperModel
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModel
import soundfile as sf

# =========================
# Global config (forced per your request)
# =========================
FORCE_WHISPER_NAME = "large-v3"
FORCE_COMPUTE_TYPE = "int8"
FORCE_USE_MARBERT = True

# ======= Budget Config =======
# "auto":  rely on the global guard (SBERT/ROUGE/WER) to choose the budget
# "fixed": fixed number of replacements (0 means never replace)
# "ratio": fraction of the spoken-text length
# "off":   no cap (legacy behavior)
FORCE_BUDGET_MODE = "auto"  # "auto" | "fixed" | "ratio" | "off"
FIXED_BUDGET_TOKENS = 0
BUDGET_RATIO = 0.15
# =============================

# Fixed transcription options to reduce run-to-run variance.
ASR_OPTS = dict(
    word_timestamps=True,
    vad_filter=True,
    vad_parameters={"min_silence_duration_ms": 200},
    beam_size=5,
    best_of=5,
    temperature=0.0,
)

# =========================
# Device
# =========================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INIT] DEVICE={DEVICE}", flush=True)

# =========================
# Lazy models (loaded once by load_models)
# =========================
_SBERT = None
_MARBERT_TOK = None
_MARBERT = None
_WHISPER = None


def load_models(
    sbert_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    marbert_name="UBC-NLP/MARBERT",
    whisper_name=FORCE_WHISPER_NAME,
    whisper_compute=FORCE_COMPUTE_TYPE,
    use_marbert=FORCE_USE_MARBERT
):
    """Load models once; forced config respected even on CPU."""
    global _SBERT, _MARBERT_TOK, _MARBERT, _WHISPER
    if _SBERT is None:
        _SBERT = SentenceTransformer(sbert_name, device=("cuda" if DEVICE == "cuda" else "cpu"))
        print(f"[LOAD] SBERT: {sbert_name}", flush=True)
    if _MARBERT is None and use_marbert:
        _MARBERT_TOK = AutoTokenizer.from_pretrained(marbert_name)
        _MARBERT = AutoModel.from_pretrained(marbert_name).to(("cuda" if DEVICE == "cuda" else "cpu"))
        _MARBERT.eval()
        print(f"[LOAD] MARBERT: {marbert_name} (device={DEVICE})", flush=True)
    if _WHISPER is None:
        _WHISPER = WhisperModel(whisper_name,
                                device=("cuda" if DEVICE == "cuda" else "cpu"),
                                compute_type=whisper_compute)
        print(f"[LOAD] Whisper: {whisper_name} (compute={whisper_compute})", flush=True)


# =========================
# Normalization / Tokenization / Alignment
# =========================
def normalize_ar_orth(text: str) -> str:
    """General orthographic normalization used for alignment:
    strip Arabic diacritics/tatweel, drop punctuation, collapse whitespace."""
    text = re.sub(r"[ًٌٍَُِّْـ]", "", text)
    text = re.sub(r"[“”\"',:؛؟.!()\[\]{}،\-–—_]", " ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text


def _normalize_for_models(s: str) -> str:
    """Normalization applied to SBERT/MARBERT inputs (same rules as above)."""
    s = re.sub(r"[ًٌٍَُِّْـ]", "", s)
    s = re.sub(r"[“”\"',:؛؟.!()\[\]{}،\-–—_]", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s


def simple_tokenize(text: str):
    """Tokenize normalized text with NLTK punkt if available, else whitespace split."""
    t = normalize_ar_orth(text)
    try:
        import nltk
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt', quiet=True)
        return nltk.word_tokenize(t)
    except Exception:
        # NLTK missing or download failed: fall back to a plain split.
        return t.split()


def align_texts(ref_tokens, hyp_tokens):
    """Align reference vs. hypothesis tokens via difflib opcodes.

    Returns a list of dicts with keys: type (equal/replace/delete/insert),
    ref/hyp token slices, and their (start, end) index ranges."""
    import difflib
    sm = difflib.SequenceMatcher(None, ref_tokens, hyp_tokens)
    aligned = []
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        aligned.append({
            'type': tag,
            'ref': ref_tokens[i1:i2],
            'hyp': hyp_tokens[j1:j2],
            'ref_idx': (i1, i2),
            'hyp_idx': (j1, j2)
        })
    return aligned


# =========================
# Phonetic / Levenshtein
# =========================
def arabic_soundex(word):
    """Crude soundex-like code for Arabic: map confusable letter groups to
    one latin symbol each; unmapped characters are dropped."""
    w = normalize_ar_orth(word)
    groups = {
        'b': 'بف', 'j': 'جشص', 'd': 'دض', 't': 'طت', 'q': 'قغ', 'k': 'كخ',
        's': 'سصز', 'z': 'ثذظ', 'h': 'ح', 'a': 'ع', 'm': 'م', 'n': 'ن',
        'l': 'ل', 'r': 'ر', 'w': 'و', 'y': 'ي'
    }
    code = []
    for ch in w:
        for rep, chars in groups.items():
            if ch in chars:
                code.append(rep)
                break
    return "".join(code)


def phonetic_similarity(w1, w2):
    """True when both words share the same arabic_soundex code."""
    if not w1 or not w2:
        return False
    return arabic_soundex(w1) == arabic_soundex(w2)


def is_levenshtein_1(w1, w2):
    """True when edit distance between the two words is exactly 1."""
    if not w1 or not w2:
        return False
    return textdistance.levenshtein(w1, w2) == 1


# =========================
# Numbers
# =========================
AR_DIGITS = str.maketrans("٠١٢٣٤٥٦٧٨٩", "0123456789")
UNITS = {"صفر": 0, "واحد": 1, "واحدة": 1, "اثنان": 2, "اثنين": 2, "اثنتان": 2, "اثنتين": 2,
         "ثلاث": 3, "ثلاثه": 3, "ثلاثة": 3, "اربع": 4, "اربعه": 4, "أربع": 4, "أربعه": 4,
         "خمس": 5, "خمسه": 5, "ست": 6, "سته": 6, "سبع": 7, "سبعه": 7, "ثمان": 8, "ثماني": 8, "ثمانيه": 8,
         "تسع": 9, "تسعه": 9}
TENS = {"عشر": 10, "عشرة": 10, "عشره": 10, "عشرون": 20, "عشرين": 20, "ثلاثون": 30, "ثلاثين": 30,
        "اربعون": 40, "أربعون": 40, "اربعين": 40, "خمسون": 50, "ستون": 60, "سبعون": 70, "ثمانون": 80, "تسعون": 90}
HUND = {"مئه": 100, "مئة": 100, "مائه": 100}
SCALE = {"الف": 1000, "ألف": 1000, "آلاف": 1000, "مليون": 10**6, "مليار": 10**9}


def normalize_digits(s: str) -> str:
    """Convert Eastern-Arabic digits to ASCII digits."""
    return s.translate(AR_DIGITS)


def words_to_number(tokens):
    """Parse a sequence of Arabic number words into an int; None if nothing parsed.

    NOTE(review): hundreds are added (not multiplied), so e.g. "three hundred"
    style phrases resolve to unit+100 — confirm this matches intended usage."""
    total = 0
    current = 0
    for w in tokens:
        if w in UNITS:
            current += UNITS[w]
        elif w in TENS:
            current += TENS[w]
        elif w in HUND:
            current += HUND[w]
        elif w in SCALE:
            current = max(1, current) * SCALE[w]
            total += current
            current = 0
        elif w == "و":  # conjunction "and" — skip
            continue
        else:
            # Non-number word: flush the running group.
            total += current
            current = 0
    total += current
    return total if total != 0 else None


def to_numeric_value(token: str):
    """Best-effort numeric value of a token (digits or number words); None on failure."""
    if not token:
        return None
    t = normalize_ar_orth(token)
    d = normalize_digits(t)
    if re.fullmatch(r"\d+", d):
        return int(d)
    toks = t.split()
    return words_to_number(toks)


# =========================
# Semantic similarities (MARBERT fixed)
# =========================
def _mean_pool(last_hidden_state, attention_mask):
    """Attention-masked mean pooling over the token axis."""
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts


def marbert_cls_similarity(a: str, b: str) -> float:
    """Return 0 when [UNK] dominates; use mean pooling instead of CLS only.

    Output is rescaled from cosine [-1, 1] to [0, 1]."""
    if not a or not b or _MARBERT is None:
        return 0.0
    a_n = _normalize_for_models(a)
    b_n = _normalize_for_models(b)
    # UNK ratio check: if the tokenizer can't represent the words, MARBERT's
    # similarity is meaningless — bail out with 0.
    ids_a = _MARBERT_TOK(a_n, add_special_tokens=False).input_ids
    ids_b = _MARBERT_TOK(b_n, add_special_tokens=False).input_ids
    unk_id = _MARBERT_TOK.unk_token_id
    if len(ids_a) == 0 or len(ids_b) == 0:
        return 0.0
    unk_ratio_a = (ids_a.count(unk_id) / len(ids_a)) if unk_id is not None else 0.0
    unk_ratio_b = (ids_b.count(unk_id) / len(ids_b)) if unk_id is not None else 0.0
    if max(unk_ratio_a, unk_ratio_b) > 0.5:
        # too many unknowns → ignore MARBERT
        return 0.0
    with torch.no_grad():
        ta = _MARBERT_TOK(a_n, return_tensors='pt', truncation=True, padding=True).to(("cuda" if DEVICE == "cuda" else "cpu"))
        tb = _MARBERT_TOK(b_n, return_tensors='pt', truncation=True, padding=True).to(("cuda" if DEVICE == "cuda" else "cpu"))
        ea = _mean_pool(_MARBERT(**ta).last_hidden_state, ta["attention_mask"])
        eb = _mean_pool(_MARBERT(**tb).last_hidden_state, tb["attention_mask"])
        sim = util.cos_sim(ea, eb).item()  # -1..1
    return (sim + 1) / 2  # 0..1


def multi_bert_similarity(a: str, b: str):
    """Combine SBERT and MARBERT similarities for a word pair.

    Returns dict with sbert/marbert/max/avg scores and an optional note
    ("models_disagree" when the two scores differ by more than 0.35)."""
    if not a or not b:
        return {"sbert": 0.0, "marbert": 0.0, "max": 0.0, "avg": 0.0, "note": "empty"}
    a_n = _normalize_for_models(a)
    b_n = _normalize_for_models(b)
    sbert_sim = float(util.pytorch_cos_sim(
        _SBERT.encode(a_n, convert_to_tensor=True),
        _SBERT.encode(b_n, convert_to_tensor=True)
    ))
    marbert_sim = marbert_cls_similarity(a_n, b_n)
    note = None
    if abs(sbert_sim - marbert_sim) > 0.35:
        note = "models_disagree"
    vals = [sbert_sim, marbert_sim]
    return {"sbert": sbert_sim, "marbert": marbert_sim,
            "max": max(vals), "avg": sum(vals) / len(vals), "note": note}


# =========================
# Faster-Whisper helpers
# =========================
def clean_ar_token(t: str) -> str:
    """Strip non-word/non-Arabic edges from a whisper word, then normalize."""
    t = t.strip()
    t = re.sub(r'^[^\w\u0600-\u06FF]+|[^\w\u0600-\u06FF]+$', '', t)
    t = normalize_ar_orth(t)
    return t


def extract_word_conf_table(segments):
    """Flatten whisper segments into a word-level DataFrame with timings and probs.

    Columns are declared explicitly so downstream column access works even
    when no words were produced (e.g. silent audio)."""
    rows = []
    for seg in segments:
        for w in (seg.words or []):
            rows.append({
                "seg_start": float(seg.start),
                "seg_end": float(seg.end),
                "word_start": float(w.start),
                "word_end": float(w.end),
                "word": clean_ar_token(w.word),
                "prob": float(w.probability),
            })
    # Explicit columns prevent a KeyError on df["prob"] when rows is empty.
    return pd.DataFrame(rows, columns=["seg_start", "seg_end", "word_start",
                                       "word_end", "word", "prob"])


def build_asr_token_conf(df_words: pd.DataFrame, hyp_tokens: list):
    """Map each hypothesis token index to its ASR probability and duration (ms).

    Word rows and hyp tokens are matched positionally; extras are truncated and
    shortfalls padded with None. Also derives low/high probability thresholds
    from the 15th/70th quantiles (defaults 0.5/0.85 when no probs exist)."""
    toks_probs, toks_durs = [], []
    for _, row in df_words.iterrows():
        prob = row["prob"]
        dur = (row["word_end"] - row["word_start"]) * 1000.0
        toks_probs.append(prob)
        toks_durs.append(dur)
    L = len(hyp_tokens)
    if len(toks_probs) >= L:
        toks_probs = toks_probs[:L]
        toks_durs = toks_durs[:L]
    else:
        pad = L - len(toks_probs)
        toks_probs += [None] * pad
        toks_durs += [None] * pad
    arr = np.array([p for p in toks_probs if p is not None])
    if arr.size:
        low_t = float(np.quantile(arr, 0.15))
        high_t = float(np.quantile(arr, 0.70))
    else:
        low_t, high_t = 0.5, 0.85
    asr_token_conf = {i: {"prob": toks_probs[i], "duration_ms": toks_durs[i]} for i in range(L)}
    return asr_token_conf, low_t, high_t


# =========================
# Confidence gate
# =========================
def gate_by_word_conf(base_decision: str, prob: float, sbert_sim: float,
                      is_short: bool, lev1: bool, duration_ms: float = None,
                      low_t: float = 0.6, high_t: float = 0.9, sbert_lo=0.60):
    """Adjust a base decision using word-level ASR confidence.

    Low-probability words are reclassified as ASR errors when supported by
    a secondary signal (short+lev1, very short duration, or semantics);
    mid/high bands keep the base decision unchanged."""
    band = "mid"
    if prob is not None:
        if prob <= low_t:
            band = "low"
        elif prob >= high_t:
            band = "high"
    very_short = (duration_ms is not None and duration_ms < 120)
    if band == "low":
        if is_short and lev1:
            return 'ASR error (low p + short+lev1)'
        if very_short:
            return 'ASR error (low p + very short)'
        if sbert_sim >= sbert_lo:
            return 'ASR error (low p + semantic)'
        return 'ASR error (low p)'
    if band == "high":
        return base_decision
    return base_decision


# =========================
# Pair + main classifiers (tightened)
# =========================
def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
                  bert_thresh=0.75, max_bert=0.85):
    """Classify a ref/hyp word pair as an ASR error or a memorization error."""
    # numbers equal: same numeric value → treat as an ASR rendering difference
    ref_num = to_numeric_value(ref_w)
    hyp_num = to_numeric_value(hyp_w)
    if (ref_num is not None) or (hyp_num is not None):
        if (ref_num is not None) and (hyp_num is not None) and (ref_num == hyp_num):
            return 'ASR error (numbers equal)'
    # short word one edit away → almost certainly an ASR slip
    if short_word and lev1:
        return 'ASR error (short+lev1)'
    # semantic/phonetic evidence
    sbert_ok = bert_scores["sbert"] >= 0.80
    avg_ok = bert_scores["avg"] >= bert_thresh
    max_ok = (bert_scores["max"] > max_bert) and sbert_ok
    disagree = (bert_scores.get("note") == "models_disagree")
    if not disagree:
        if ((phon_sim or lev1) and avg_ok) or max_ok:
            return 'ASR error (semantic/phonetic)'
    else:
        # Models disagree: require stronger, SBERT-backed evidence.
        if phon_sim or lev1:
            if sbert_ok and avg_ok:
                return 'ASR error (semantic/phonetic)'
        else:
            if bert_scores["sbert"] >= 0.80:
                return 'ASR error (semantic)'
    return 'Memorization error'


def classify_alignment_optimized(
    aligned, ref_tokens, hyp_tokens,
    bert_thresh=0.75, max_bert=0.85,
    asr_token_conf=None, low_high=None,
    replace_budget_tokens=None,  # replacement cap
    guard_note=None              # tag such as "off-topic"/"ok"/"budget_off"
):
    """Walk the alignment, classify every token, and build the corrected text.

    Replacements of hypothesis words by reference words are capped by
    replace_budget_tokens (None = unlimited). Returns (results, corrected_text,
    stats) where results is a list of per-token decision dicts."""
    # thresholds derived from word probabilities when not supplied
    if low_high is None:
        if asr_token_conf:
            probs = [v["prob"] for v in asr_token_conf.values() if v["prob"] is not None]
            if probs:
                low_t = float(np.quantile(probs, 0.15))
                high_t = float(np.quantile(probs, 0.70))
            else:
                low_t, high_t = 0.5, 0.85
        else:
            low_t, high_t = 0.5, 0.85
    else:
        low_t, high_t = low_high

    results, corrected_words = [], []
    replaced_count = 0
    for entry in aligned:
        tag = entry['type']
        i1, i2 = entry.get('ref_idx', (None, None))
        j1, j2 = entry.get('hyp_idx', (None, None))
        if tag == 'equal':
            for ref_w, hyp_w in zip(entry['ref'], entry['hyp']):
                results.append({'ASR_word': hyp_w, 'GT_word': ref_w,
                                'status': 'Correct', 'reason': '', 'used': hyp_w})
                corrected_words.append(hyp_w)
        elif tag in ['replace', 'delete', 'insert']:
            max_len = max(len(entry['ref']), len(entry['hyp']))
            for k in range(max_len):
                ref_w = entry['ref'][k] if k < len(entry['ref']) else ''
                hyp_w = entry['hyp'][k] if k < len(entry['hyp']) else ''
                if not ref_w and not hyp_w:
                    continue
                # similarities
                phon_sim = phonetic_similarity(ref_w, hyp_w) if ref_w and hyp_w else False
                lev1 = is_levenshtein_1(ref_w, hyp_w) if ref_w and hyp_w else False
                bert_scores = multi_bert_similarity(ref_w, hyp_w) if ref_w and hyp_w else {"sbert": 0, "marbert": 0, "max": 0, "avg": 0}
                short_word = bool(ref_w and hyp_w and max(len(ref_w), len(hyp_w)) <= 6)
                # base status
                if ref_w and hyp_w:
                    base_status = classify_pair(ref_w, hyp_w, bert_scores, phon_sim,
                                                lev1, short_word, bert_thresh, max_bert)
                elif hyp_w == '':
                    base_status = 'Missing (possible omission)'
                elif ref_w == '':
                    base_status = 'Extra (possible ASR insertion)'
                else:
                    base_status = 'Undefined Case'
                # word-level confidence gate
                word_prob = None
                word_dur = None
                if (j1 is not None) and (j2 is not None):
                    hyp_abs_idx = j1 + k
                    if asr_token_conf and hyp_abs_idx in asr_token_conf:
                        word_prob = asr_token_conf[hyp_abs_idx].get("prob")
                        word_dur = asr_token_conf[hyp_abs_idx].get("duration_ms")
                final_status = base_status
                if ref_w and hyp_w:
                    final_status = gate_by_word_conf(
                        base_decision=base_status,
                        prob=word_prob,
                        sbert_sim=bert_scores["sbert"],
                        is_short=short_word,
                        lev1=lev1,
                        duration_ms=word_dur,
                        low_t=low_t, high_t=high_t, sbert_lo=0.60
                    )
                # choose the output token, respecting the replacement budget
                used = hyp_w
                budget_info = ""
                if ref_w and hyp_w:
                    if final_status.startswith("ASR error"):
                        if (replace_budget_tokens is None) or (replaced_count < replace_budget_tokens):
                            used = ref_w
                            replaced_count += 1
                            if replace_budget_tokens is not None:
                                budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
                        else:
                            used = hyp_w
                            final_status += " [guard: budget reached]"
                            budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
                    else:
                        used = hyp_w
                elif hyp_w == '':
                    used = ''
                elif ref_w == '':
                    used = hyp_w
                reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
                          f'SBERT={bert_scores["sbert"]:.2f}, '
                          f'MARBERT={bert_scores["marbert"]:.2f}, '
                          f'MAX={bert_scores["max"]:.2f}, '
                          f'AVG={bert_scores["avg"]:.2f}, short={short_word}, '
                          f'prob={None if word_prob is None else round(word_prob, 2)}, '
                          f'dur_ms={None if word_dur is None else int(word_dur)}, '
                          f'low_t={round(low_t, 2)}, high_t={round(high_t, 2)}')
                if bert_scores.get("note"):
                    reason += f", note={bert_scores['note']}"
                if guard_note:
                    reason += f", guard='{guard_note}'"
                if budget_info:
                    reason += budget_info
                results.append({'ASR_word': hyp_w, 'GT_word': ref_w,
                                'status': final_status, 'reason': reason, 'used': used})
                if used:
                    corrected_words.append(used)
    corrected_text = " ".join([w for w in corrected_words if w])
    # Local statistics, useful for the report.
    stats = {
        "replacements_made": sum(1 for r in results
                                 if r.get("used") and r.get("GT_word")
                                 and r["used"] == r["GT_word"]
                                 and r.get("ASR_word") and r["ASR_word"] != r["GT_word"]),
        "budget_reached_count": sum(1 for r in results
                                    if isinstance(r.get("status"), str)
                                    and "budget reached" in r["status"]),
        "asr_error_count": sum(1 for r in results
                               if isinstance(r.get("status"), str)
                               and r["status"].startswith("ASR error")),
        "memorization_error_count": sum(1 for r in results
                                        if r.get("status") == "Memorization error"),
        "missing_count": sum(1 for r in results
                             if r.get("status", "").startswith("Missing")),
        "extra_count": sum(1 for r in results
                           if r.get("status", "").startswith("Extra")),
        "total_tokens": len(results)
    }
    return results, corrected_text, stats


# =========================
# ROUGE-L / WER-like / Guard
# =========================
def lcs_len(a, b):
    """Length of the longest common subsequence of two token lists (O(m*n) DP)."""
    m, n = len(a), len(b)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(1, m + 1):
        ai = a[i - 1]
        for j in range(1, n + 1):
            if ai == b[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = dp[i - 1][j] if dp[i - 1][j] >= dp[i][j - 1] else dp[i][j - 1]
    return dp[m][n]


def rouge_l_f1_tokens(ref_tokens, hyp_tokens, beta=1.2):
    """ROUGE-L (F1, precision, recall) over token lists; zeros on empty input."""
    if not ref_tokens or not hyp_tokens:
        return 0.0, 0.0, 0.0
    lcs = lcs_len(ref_tokens, hyp_tokens)
    prec = lcs / len(hyp_tokens)
    rec = lcs / len(ref_tokens)
    if prec == 0 and rec == 0:
        return 0.0, 0.0, 0.0
    f1 = ((1 + beta**2) * prec * rec) / (rec + beta**2 * prec + 1e-12)
    return float(f1), float(prec), float(rec)


def compute_wer_like(aligned, ref_tokens_len):
    """WER-style rate (S+D+I)/N from difflib opcodes.

    NOTE(review): 'replace' counts max(len(ref), len(hyp)), so this can
    slightly overestimate substitutions vs. a true edit-distance WER."""
    S = D = I = 0
    for op in aligned:
        if op['type'] == 'replace':
            S += max(len(op['ref']), len(op['hyp']))
        elif op['type'] == 'delete':
            D += len(op['ref'])
        elif op['type'] == 'insert':
            I += len(op['hyp'])
    N = max(ref_tokens_len, 1)
    return (S + D + I) / N


def global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, sbert_model):
    """Decide whether the recitation is off-topic and size the replacement budget.

    Off-topic when sentence-level SBERT, ROUGE-L and equal-token ratio are all
    low, or WER-like exceeds 0.65. Budget: 0 when off-topic, 15% of hyp length
    when borderline, otherwise 40%."""
    sbert_sim_text = float(util.pytorch_cos_sim(
        sbert_model.encode(_normalize_for_models(original_text), convert_to_tensor=True),
        sbert_model.encode(_normalize_for_models(asr_text), convert_to_tensor=True)
    ))
    rouge_f1, rouge_p, rouge_r = rouge_l_f1_tokens(ref_tokens, hyp_tokens)
    equal_tokens = sum(len(op['ref']) for op in aligned if op['type'] == 'equal')
    equal_ratio = equal_tokens / max(len(ref_tokens), 1)
    wer = compute_wer_like(aligned, len(ref_tokens))
    off_topic = ((sbert_sim_text < 0.70 and rouge_f1 < 0.45 and equal_ratio < 0.25)
                 or (wer > 0.65))
    L = len(hyp_tokens)
    if off_topic:
        budget = 0
    elif sbert_sim_text < 0.80 or rouge_f1 < 0.55:
        budget = int(0.15 * L)
    else:
        budget = int(0.40 * L)
    metrics = {
        "sbert_sim_text": round(sbert_sim_text, 3),
        "rougeL_f1": round(rouge_f1, 3),
        "rougeL_prec": round(rouge_p, 3),
        "rougeL_rec": round(rouge_r, 3),
        "equal_ratio": round(equal_ratio, 3),
        "wer_like": round(wer, 3),
    }
    print(f"[GUARD] off_topic={off_topic}, budget={budget}, metrics={metrics}", flush=True)
    return {"off_topic": off_topic, "budget_tokens": budget, "metrics": metrics}


# =========================
# Scores
# =========================
def literal_similarity(original, recited):
    """Literal score: 0.5*normalized Levenshtein + 0.3*word overlap + 0.2*BLEU-1."""
    def norm(t):
        t = re.sub(r'[ًٌٍَُِّْـ]', '', t)
        t = re.sub(r'[“”",:؛؟.!()\[\]{}،\-–—_]', ' ', t)
        t = re.sub(r'\s+', ' ', t).strip()
        return t
    o = norm(original)
    r = norm(recited)
    lev = textdistance.levenshtein.normalized_similarity(o, r)
    ot = simple_tokenize(o)
    rt = simple_tokenize(r)
    # Positional overlap: count matches at the same token index only.
    common = sum(1 for w1, w2 in zip(ot, rt) if w1 == w2)
    word_overlap = common / max(len(ot), 1)
    try:
        import nltk.translate.bleu_score as bleu
        bleu1 = bleu.sentence_bleu([ot], rt, weights=(1, 0, 0, 0)) if (ot and rt) else 0.0
    except Exception:
        bleu1 = 0.0
    final_score = 0.5 * lev + 0.3 * word_overlap + 0.2 * bleu1
    return {"levenshtein": round(lev, 3),
            "word_overlap": round(word_overlap, 3),
            "bleu1": round(bleu1, 3),
            "literal_score": round(final_score, 3)}


def semantic_similarity(original, recited, use_marbert=FORCE_USE_MARBERT):
    """Sentence-level semantic score: max of SBERT and (optionally) MARBERT sims."""
    sbert_sim = float(util.pytorch_cos_sim(
        _SBERT.encode(_normalize_for_models(original), convert_to_tensor=True),
        _SBERT.encode(_normalize_for_models(recited), convert_to_tensor=True)
    ))
    marbert_sim = marbert_cls_similarity(original, recited) if use_marbert else 0.0
    return {"sbert_sim": round(sbert_sim, 3),
            "marbert_sim": round(marbert_sim, 3),
            "semantic_score": round(max(sbert_sim, marbert_sim), 3)}


# =========================
# Audio helper
# =========================
def ensure_audio_path(audio):
    """Return a filesystem path for the given audio input.

    Accepts either an existing path or a (ndarray, sample_rate) tuple; the
    latter is written to a temp .wav file (handle closed before writing so
    the path is usable on all platforms)."""
    if isinstance(audio, str):
        if not os.path.exists(audio):
            raise FileNotFoundError(f"Audio path not found: {audio}")
        return audio
    if isinstance(audio, tuple) and len(audio) == 2:
        data, sr = audio
        if isinstance(data, np.ndarray):
            tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            tmp.close()  # close the fd; sf.write reopens by name
            sf.write(tmp.name, data, sr)
            return tmp.name
    raise ValueError("Unsupported audio input format")


# =========================
# Pipeline (robust errors + logs)
# =========================
def transcribe_and_evaluate(audio, original_text, whisper_size=None, compute_type=None,
                            vad=True, use_marbert=True):
    """Full pipeline: ASR -> alignment -> guarded correction -> scores.

    Returns (corrected_text, asr_text, report_json, decisions_df). UI-supplied
    model options are ignored: forced globals always win. On any failure an
    error report and a one-row ERROR DataFrame are returned instead of raising."""
    try:
        if not original_text or not original_text.strip():
            raise ValueError("Original text is empty.")
        # Forced settings override whatever the UI passed in.
        whisper_size = FORCE_WHISPER_NAME
        compute_type = FORCE_COMPUTE_TYPE
        use_marbert = FORCE_USE_MARBERT
        print(f"[RUN] whisper={whisper_size}, compute={compute_type}, marbert={use_marbert}", flush=True)
        load_models(whisper_name=whisper_size, whisper_compute=compute_type, use_marbert=use_marbert)

        audio_path = ensure_audio_path(audio)
        print(f"[AUDIO] path={audio_path}", flush=True)
        segments, info = _WHISPER.transcribe(audio_path, **ASR_OPTS)
        segments = list(segments)
        print(f"[ASR] segments={len(segments)}", flush=True)

        # Build ASR text from word-level output.
        words = []
        for seg in segments:
            for w in (seg.words or []):
                tok = clean_ar_token(w.word)
                if tok:
                    words.append(tok)
        asr_text = " ".join(words)

        # Tokens & alignment
        ref_tokens = simple_tokenize(original_text)
        hyp_tokens = simple_tokenize(asr_text)
        aligned = align_texts(ref_tokens, hyp_tokens)

        # Guard & budget
        guard = global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, _SBERT)
        off_topic = guard["off_topic"]
        guard_metrics = guard["metrics"]
        if FORCE_BUDGET_MODE == "off":
            budget_tokens = None
            guard_note = "budget_off"
        elif FORCE_BUDGET_MODE == "fixed":
            budget_tokens = int(FIXED_BUDGET_TOKENS)
            guard_note = f"budget_fixed_{budget_tokens}"
        elif FORCE_BUDGET_MODE == "ratio":
            budget_tokens = int(BUDGET_RATIO * len(hyp_tokens))
            guard_note = f"budget_ratio_{BUDGET_RATIO}"
        else:
            budget_tokens = guard["budget_tokens"]
            guard_note = "off-topic" if off_topic else "ok"
        print(f"[BUDGET] mode={FORCE_BUDGET_MODE}, budget={budget_tokens}, note={guard_note}", flush=True)

        # Word-level confidences
        df_words = extract_word_conf_table(segments)
        asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)
        print(f"[CONF] low_t={low_t:.3f}, high_t={high_t:.3f}", flush=True)

        # Classification
        results, corrected_text, local_stats = classify_alignment_optimized(
            aligned, ref_tokens, hyp_tokens,
            bert_thresh=0.75, max_bert=0.85,
            asr_token_conf=asr_token_conf,
            low_high=(low_t, high_t),
            replace_budget_tokens=budget_tokens,
            guard_note=guard_note
        )

        # Scores
        lit = literal_similarity(original_text, corrected_text)
        sem = semantic_similarity(original_text, corrected_text, use_marbert=use_marbert)

        # Extra global metrics for the report.
        all_probs = df_words["prob"].dropna().tolist()
        conf_summary = {
            "num_words_with_prob": int(len(all_probs)),
            "avg_prob": None if not all_probs else float(np.mean(all_probs)),
            "p15": None if not all_probs else float(np.quantile(all_probs, 0.15)),
            "p70": None if not all_probs else float(np.quantile(all_probs, 0.70)),
        }

        df = pd.DataFrame(results)
        report = {
            "requested": {"whisper_model": whisper_size, "compute_type": compute_type,
                          "use_marbert": use_marbert},
            "effective": {"whisper_model": whisper_size, "compute_type": compute_type,
                          "use_marbert": use_marbert},
            "guard": {"mode": FORCE_BUDGET_MODE, "off_topic": off_topic,
                      "budget_tokens": None if budget_tokens is None else int(budget_tokens),
                      **guard_metrics},
            "local_stats": local_stats,
            "confidence_summary": conf_summary,
            "original_text": original_text,
            "asr_text": asr_text,
            "corrected_text": corrected_text,
            "literal": lit,
            "semantic": sem,
            "low_t": float(low_t),
            "high_t": float(high_t),
        }
        return corrected_text, asr_text, json.dumps(report, ensure_ascii=False, indent=2), df
    except Exception as e:
        tb = traceback.format_exc()
        print("ERROR in transcribe_and_evaluate:\n", tb, flush=True)
        empty_df = pd.DataFrame([{"ASR_word": "", "GT_word": "", "status": "ERROR",
                                  "reason": str(e), "used": ""}])
        err_json = json.dumps({"error": str(e), "traceback": tb}, ensure_ascii=False, indent=2)
        gr.Warning(str(e))
        return "", "", err_json, empty_df


def api_predict(audio, original_text, whisper_size=None, compute_type=None,
                vad=True, use_marbert=True):
    """API wrapper: run the pipeline and return the parsed report dict."""
    corrected_text, asr_text, report_json, df = transcribe_and_evaluate(
        audio, original_text, whisper_size, compute_type, vad, use_marbert
    )
    try:
        return json.loads(report_json)
    except Exception:
        return {"error": "Failed to parse report_json."}


# =========================
# Gradio UI
# =========================
def build_ui():
    """Build the Gradio Blocks UI (visible controls + hidden /predict endpoint)."""
    with gr.Blocks(title="Samaali ASR Post-Processing", theme=gr.themes.Soft()) as demo:
        gr.Markdown("## Samaali — ASR Post-Processing (Whisper + Alignment + Confidence + Semantics)")
        with gr.Row():
            audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
            original = gr.Textbox(lines=8, label="Original Text (Ground Truth)")
        with gr.Row():
            whisper_size = gr.Dropdown(choices=["large-v3"], value="large-v3",
                                       label="Whisper model size (forced)")
            compute_type = gr.Dropdown(choices=["int8"], value="int8",
                                       label="compute_type (forced)")
            vad = gr.Checkbox(value=True, label="VAD filter")
            use_marbert = gr.Checkbox(value=True, label="Use MARBERT (forced)")
        btn = gr.Button("Transcribe & Evaluate", variant="primary")
        corrected = gr.Textbox(lines=6, label="Corrected Transcript (ASR errors restored)")
        asr_out = gr.Textbox(lines=6, label="Raw ASR Transcript")
        report = gr.JSON(label="Report (scores & thresholds)")
        table = gr.Dataframe(headers=["ASR_word", "GT_word", "status", "reason", "used"],
                             label="Token-level Decisions", wrap=True)
        btn.click(
            fn=transcribe_and_evaluate,
            inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
            outputs=[corrected, asr_out, report, table],
            api_name="evaluate"
        )
        # Hidden button exposes a JSON-only /predict API endpoint.
        gr.Button(visible=False).click(
            fn=api_predict,
            inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
            outputs=gr.JSON(),
            api_name="predict"
        )
    return demo


if __name__ == "__main__":
    demo = build_ui()
    demo.launch()