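# NOTE(review): the next three lines look like web-page banner residue (author /
# commit message / commit hash); kept as comments so the module parses.
# MuhammadHijazii's picture
# Update app.py
# 60cff39 verified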
import os, re, json, math, tempfile, traceback
import numpy as np
import pandas as pd
import torch
import textdistance
import gradio as gr
from faster_whisper import WhisperModel
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModel
import soundfile as sf
# =========================
# Global config (forced per your request)
# =========================
# These values are the defaults of load_models() and are re-applied inside
# transcribe_and_evaluate(), overriding whatever the caller passed.
FORCE_WHISPER_NAME = "large-v3"
FORCE_COMPUTE_TYPE = "int8"
FORCE_USE_MARBERT = True
# ======= Budget Config =======
# "auto": relies on the global guard (SBERT/ROUGE/WER)
# "fixed": a fixed number of replacements (0 means never replace)
# "ratio": a fraction of the length of the spoken (ASR) text
# "off": no cap (legacy behaviour)
FORCE_BUDGET_MODE = "auto" # "auto" | "fixed" | "ratio" | "off"
FIXED_BUDGET_TOKENS = 0
BUDGET_RATIO = 0.15
# =============================
# Fixed transcription options, to reduce run-to-run differences.
# Passed as keyword arguments to the Whisper model's transcribe() call.
ASR_OPTS = dict(
    word_timestamps=True,
    vad_filter=True,
    vad_parameters={"min_silence_duration_ms": 200},
    beam_size=5,
    best_of=5,
    temperature=0.0,
)
# =========================
# Device
# =========================
# "cuda" when torch reports an available GPU, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[INIT] DEVICE={DEVICE}", flush=True)
# =========================
# Lazy models
# =========================
# Module-level model handles; all start as None and are filled in by
# load_models() the first time it runs.
_SBERT = None
_MARBERT_TOK = None
_MARBERT = None
_WHISPER = None
def load_models(
    sbert_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    marbert_name="UBC-NLP/MARBERT",
    whisper_name=FORCE_WHISPER_NAME,
    whisper_compute=FORCE_COMPUTE_TYPE,
    use_marbert=FORCE_USE_MARBERT
):
    """Populate the module-level model handles, loading each model at most once.

    A handle that is already set is left untouched, so calling this again with
    different names does not reload anything. MARBERT (tokenizer + model) is
    only loaded when ``use_marbert`` is truthy.
    """
    global _SBERT, _MARBERT_TOK, _MARBERT, _WHISPER
    # Anything other than "cuda" falls back to CPU.
    target = "cuda" if DEVICE == "cuda" else "cpu"

    if _SBERT is None:
        _SBERT = SentenceTransformer(sbert_name, device=target)
        print(f"[LOAD] SBERT: {sbert_name}", flush=True)

    if use_marbert and _MARBERT is None:
        _MARBERT_TOK = AutoTokenizer.from_pretrained(marbert_name)
        _MARBERT = AutoModel.from_pretrained(marbert_name).to(target)
        _MARBERT.eval()  # inference only
        print(f"[LOAD] MARBERT: {marbert_name} (device={DEVICE})", flush=True)

    if _WHISPER is None:
        _WHISPER = WhisperModel(whisper_name, device=target, compute_type=whisper_compute)
        print(f"[LOAD] Whisper: {whisper_name} (compute={whisper_compute})", flush=True)
# =========================
# Normalization / Tokenization / Alignment
# =========================
def normalize_ar_orth(text: str) -> str:
    """Return ``text`` normalised for alignment.

    Three regex passes are applied in order: remove Arabic diacritics and the
    tatweel, turn the listed punctuation marks into spaces, then collapse runs
    of whitespace. The result is stripped at both ends.
    """
    passes = (
        (r"[ًٌٍَُِّْـ]", ""),
        (r"[“”\"',:؛؟.!()\[\]{}،\-–—_]", " "),
        (r"\s+", " "),
    )
    for pattern, replacement in passes:
        text = re.sub(pattern, replacement, text)
    return text.strip()
def _normalize_for_models(s: str) -> str:
# تطبيع خاص لمدخلات SBERT/MARBERT
s = re.sub(r"[ًٌٍَُِّْـ]", "", s)
s = re.sub(r"[“”\"',:؛؟.!()\[\]{}،\-–—_]", " ", s)
s = re.sub(r"\s+", " ", s).strip()
return s
def simple_tokenize(text: str):
    """Normalise ``text`` and split it into word tokens.

    NLTK's ``word_tokenize`` is used when it works (the 'punkt' resource is
    downloaded if it cannot be found); on any failure the normalised text is
    split on whitespace instead.
    """
    normalized = normalize_ar_orth(text)
    try:
        import nltk
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt', quiet=True)
        tokens = nltk.word_tokenize(normalized)
    except Exception:
        # Best-effort fallback: NLTK missing, resource unavailable, etc.
        tokens = normalized.split()
    return tokens
def align_texts(ref_tokens, hyp_tokens):
    """Align reference and hypothesis token lists with difflib opcodes.

    Returns one dict per opcode with keys ``type`` ('equal', 'replace',
    'delete' or 'insert'), ``ref`` / ``hyp`` (the token slices) and
    ``ref_idx`` / ``hyp_idx`` (the ``(start, end)`` slice bounds).
    """
    import difflib
    # autojunk must be off: with the default heuristic, once the hypothesis has
    # 200+ tokens any token making up more than ~1% of it is ignored as a match
    # anchor, which misaligns long passages full of frequent words.
    matcher = difflib.SequenceMatcher(None, ref_tokens, hyp_tokens, autojunk=False)
    return [
        {
            'type': tag,
            'ref': ref_tokens[i1:i2],
            'hyp': hyp_tokens[j1:j2],
            'ref_idx': (i1, i2),
            'hyp_idx': (j1, j2),
        }
        for tag, i1, i2, j1, j2 in matcher.get_opcodes()
    ]
# =========================
# Phonetic / Levenshtein
# =========================
def arabic_soundex(word):
    """Encode ``word`` as a string of coarse sound-group codes.

    The word is normalised first; every character belonging to one of the
    groups below contributes that group's code letter, and every other
    character is dropped. The result may therefore be an empty string.
    """
    groups = {
        'b': 'بف', 'j': 'جشص', 'd': 'دض', 't': 'طت', 'q': 'قغ', 'k': 'كخ',
        's': 'سصز', 'z': 'ثذظ', 'h': 'ح', 'a': 'ع', 'm': 'م', 'n': 'ن',
        'l': 'ل', 'r': 'ر', 'w': 'و', 'y': 'ي'
    }
    # Per-letter lookup. 'ص' is listed under both 'j' and 's'; setdefault keeps
    # the first group in dict order ('j'), matching a first-match scan.
    code_of = {}
    for code, letters in groups.items():
        for letter in letters:
            code_of.setdefault(letter, code)
    normalized = normalize_ar_orth(word)
    return "".join(code_of[ch] for ch in normalized if ch in code_of)
def phonetic_similarity(w1, w2):
    """Return True when both words share the same non-empty sound code.

    Empty/None inputs are never similar. A word containing no letter known to
    ``arabic_soundex`` (digits, Latin text, words made only of unmapped
    letters) encodes to "", so two such words must not be treated as a match
    merely because both codes are empty.
    """
    if not w1 or not w2:
        return False
    code1 = arabic_soundex(w1)
    if not code1:
        return False
    return code1 == arabic_soundex(w2)
def is_levenshtein_1(w1, w2):
    """Return True when both words are non-empty and exactly one edit apart."""
    if w1 and w2:
        return textdistance.levenshtein(w1, w2) == 1
    return False
# =========================
# Numbers
# =========================
# Translation table: Arabic-Indic digits -> ASCII digits.
AR_DIGITS = str.maketrans("٠١٢٣٤٥٦٧٨٩", "0123456789")

# Spelled-out number words. Lookups are exact string matches, so every
# spelling variant (with/without hamza, taa-marbuta vs. haa, nominative vs.
# accusative ending) has to be listed explicitly.
UNITS = {
    "صفر": 0,
    "واحد": 1, "واحدة": 1, "واحده": 1,
    "اثنان": 2, "اثنين": 2, "اثنتان": 2, "اثنتين": 2,
    "ثلاث": 3, "ثلاثه": 3, "ثلاثة": 3,
    "اربع": 4, "اربعه": 4, "اربعة": 4, "أربع": 4, "أربعه": 4, "أربعة": 4,
    "خمس": 5, "خمسه": 5, "خمسة": 5,
    "ست": 6, "سته": 6, "ستة": 6,
    "سبع": 7, "سبعه": 7, "سبعة": 7,
    "ثمان": 8, "ثماني": 8, "ثمانيه": 8, "ثمانية": 8,
    "تسع": 9, "تسعه": 9, "تسعة": 9,
}
TENS = {
    "عشر": 10, "عشرة": 10, "عشره": 10,
    "عشرون": 20, "عشرين": 20,
    "ثلاثون": 30, "ثلاثين": 30,
    "اربعون": 40, "أربعون": 40, "اربعين": 40, "أربعين": 40,
    "خمسون": 50, "خمسين": 50,
    "ستون": 60, "ستين": 60,
    "سبعون": 70, "سبعين": 70,
    "ثمانون": 80, "ثمانين": 80,
    "تسعون": 90, "تسعين": 90,
}
HUND = {"مئه": 100, "مئة": 100, "مائه": 100, "مائة": 100}
SCALE = {"الف": 1000, "ألف": 1000, "آلاف": 1000, "مليون": 10**6, "مليار": 10**9}


def normalize_digits(s: str) -> str:
    """Return ``s`` with Arabic-Indic digits replaced by ASCII digits."""
    return s.translate(AR_DIGITS)


def words_to_number(tokens):
    """Convert a sequence of Arabic number words to an integer.

    Returns None when no token is a known number word. Zero ("صفر") is a valid
    result and is returned as 0, not None. A unit directly followed by a
    hundred word multiplies ("ثلاث مئة" -> 300); a scale word (thousand,
    million, ...) multiplies the pending amount and flushes it into the total.
    The standalone conjunction "و" is skipped; any other unknown token flushes
    the pending amount.
    """
    total = 0
    current = 0
    seen_number = False
    prev_was_unit = False
    for w in tokens:
        unit_before = prev_was_unit
        prev_was_unit = False
        if w in UNITS:
            current += UNITS[w]
            prev_was_unit = True
            seen_number = True
        elif w in TENS:
            current += TENS[w]
            seen_number = True
        elif w in HUND:
            # "<unit> <hundred>" is a product; otherwise the hundred is added.
            if unit_before and 0 < current < 10:
                current *= HUND[w]
            else:
                current += HUND[w]
            seen_number = True
        elif w in SCALE:
            current = max(1, current) * SCALE[w]
            total += current
            current = 0
            seen_number = True
        elif w == "و":
            continue
        else:
            total += current
            current = 0
    total += current
    return total if seen_number else None
def to_numeric_value(token: str):
    """Interpret ``token`` as a number, or return None.

    An empty token gives None. After normalisation, a token made only of
    digits (Arabic-Indic digits are mapped to ASCII first) is returned as an
    int; anything else is split on whitespace and handed to
    ``words_to_number``.
    """
    if not token:
        return None
    cleaned = normalize_ar_orth(token)
    as_digits = normalize_digits(cleaned)
    if re.fullmatch(r"\d+", as_digits) is None:
        return words_to_number(cleaned.split())
    return int(as_digits)
# =========================
# Semantic similarities (MARBERT fixed)
# =========================
def _mean_pool(last_hidden_state, attention_mask):
mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
summed = (last_hidden_state * mask).sum(dim=1)
counts = mask.sum(dim=1).clamp(min=1e-9)
return summed / counts
def marbert_cls_similarity(a: str, b: str) -> float:
    """Return a MARBERT similarity in [0, 1] for two texts.

    Despite the name, embeddings are mean-pooled rather than taken from CLS.
    Returns 0.0 when either text is empty, when MARBERT is not loaded, when a
    text tokenises to nothing, or when more than half of either text's tokens
    are the unknown token.
    """
    if not a or not b or _MARBERT is None:
        return 0.0
    texts = (_normalize_for_models(a), _normalize_for_models(b))

    # UNK ratio check on the raw token ids (no special tokens).
    token_ids = [_MARBERT_TOK(t, add_special_tokens=False).input_ids for t in texts]
    unk_id = _MARBERT_TOK.unk_token_id
    if any(len(ids) == 0 for ids in token_ids):
        return 0.0
    if unk_id is not None:
        unk_ratios = [ids.count(unk_id) / len(ids) for ids in token_ids]
    else:
        unk_ratios = [0.0, 0.0]
    if max(unk_ratios) > 0.5:
        # too many unknowns -> ignore MARBERT
        return 0.0

    device = "cuda" if DEVICE == "cuda" else "cpu"
    with torch.no_grad():
        pooled = []
        for t in texts:
            enc = _MARBERT_TOK(t, return_tensors='pt', truncation=True, padding=True).to(device)
            pooled.append(_mean_pool(_MARBERT(**enc).last_hidden_state, enc["attention_mask"]))
        sim = util.cos_sim(pooled[0], pooled[1]).item()  # -1..1
    return (sim + 1) / 2  # 0..1
def multi_bert_similarity(a: str, b: str):
    """Score two texts with SBERT and MARBERT.

    Returns a dict with keys ``sbert``, ``marbert``, ``max``, ``avg`` and
    ``note``. ``note`` is "empty" when either input is empty (all scores 0.0),
    "models_disagree" when the two scores differ by more than 0.35, and None
    otherwise.
    """
    if not (a and b):
        return {"sbert":0.0, "marbert":0.0, "max":0.0, "avg":0.0, "note":"empty"}
    a_n = _normalize_for_models(a)
    b_n = _normalize_for_models(b)
    emb_a = _SBERT.encode(a_n, convert_to_tensor=True)
    emb_b = _SBERT.encode(b_n, convert_to_tensor=True)
    sbert_sim = float(util.pytorch_cos_sim(emb_a, emb_b))
    marbert_sim = marbert_cls_similarity(a_n, b_n)
    scores = [sbert_sim, marbert_sim]
    disagree = abs(sbert_sim - marbert_sim) > 0.35
    return {
        "sbert": sbert_sim,
        "marbert": marbert_sim,
        "max": max(scores),
        "avg": sum(scores)/len(scores),
        "note": "models_disagree" if disagree else None,
    }
# =========================
# Faster-Whisper helpers
# =========================
def clean_ar_token(t: str) -> str:
    """Clean one ASR word: trim, drop edge symbols, then normalise.

    Leading/trailing characters that are neither word characters nor in the
    U+0600-U+06FF range are removed before ``normalize_ar_orth`` is applied.
    """
    edge_trimmed = re.sub(r'^[^\w\u0600-\u06FF]+|[^\w\u0600-\u06FF]+$', '', t.strip())
    return normalize_ar_orth(edge_trimmed)
def extract_word_conf_table(segments):
    """Flatten Whisper segments into a per-word DataFrame.

    Each row carries the owning segment's start/end, the word's start/end, the
    cleaned word text and its probability. Segments whose ``words`` is empty
    or None contribute no rows.

    The column list is always set explicitly, so the frame has the expected
    columns even when no word was recognised (otherwise ``df["prob"]`` on the
    empty frame raises KeyError).
    """
    columns = ["seg_start", "seg_end", "word_start", "word_end", "word", "prob"]
    rows = [
        {
            "seg_start": float(seg.start),
            "seg_end": float(seg.end),
            "word_start": float(w.start),
            "word_end": float(w.end),
            "word": clean_ar_token(w.word),
            "prob": float(w.probability),
        }
        for seg in segments
        for w in (seg.words or [])
    ]
    return pd.DataFrame(rows, columns=columns)
def build_asr_token_conf(df_words: pd.DataFrame, hyp_tokens: list):
    """Map each hypothesis token index to its ASR probability and duration.

    Returns ``(asr_token_conf, low_t, high_t)`` where ``asr_token_conf[i]`` is
    ``{"prob": ..., "duration_ms": ...}`` for token ``i``, and ``low_t`` /
    ``high_t`` are the 15% / 70% quantiles of the assigned probabilities
    (0.5 / 0.85 when none is available).

    Rows are expanded to follow the tokens of the transcript: a row whose
    cleaned ``word`` is empty yields no entry (such words are left out of the
    transcript), and a word that normalises into several whitespace-separated
    tokens yields one entry per token, sharing the word's probability and an
    equal share of its duration. Without this, one such word shifts every
    later probability onto the wrong token. If the frame has no ``word``
    column each row counts as one token.
    """
    toks_probs, toks_durs = [], []
    for _, row in df_words.iterrows():
        word = row.get("word")
        n_parts = 1 if word is None else len(str(word).split())
        if n_parts == 0:
            continue
        dur = (row["word_end"] - row["word_start"]) * 1000.0 / n_parts
        toks_probs.extend([row["prob"]] * n_parts)
        toks_durs.extend([dur] * n_parts)

    # Truncate or pad with None so both lists match the token count exactly.
    L = len(hyp_tokens)
    pad = max(0, L - len(toks_probs))
    toks_probs = toks_probs[:L] + [None] * pad
    toks_durs = toks_durs[:L] + [None] * pad

    arr = np.array([p for p in toks_probs if p is not None])
    if arr.size:
        low_t = float(np.quantile(arr, 0.15))
        high_t = float(np.quantile(arr, 0.70))
    else:
        low_t, high_t = 0.5, 0.85
    asr_token_conf = {
        i: {"prob": prob, "duration_ms": dur}
        for i, (prob, dur) in enumerate(zip(toks_probs, toks_durs))
    }
    return asr_token_conf, low_t, high_t
# =========================
# Confidence gate
# =========================
def gate_by_word_conf(base_decision: str, prob: float, sbert_sim: float,
                      is_short: bool, lev1: bool, duration_ms: float = None,
                      low_t: float = 0.6, high_t: float = 0.9, sbert_lo=0.60):
    """Override a pair decision when the ASR word probability is low.

    When ``prob`` is None or above ``low_t`` the base decision is returned
    unchanged (``high_t`` does not alter the outcome). Otherwise the word is
    reported as an ASR error, with the label naming the first supporting
    signal: short word one edit away, duration under 120 ms, SBERT similarity
    of at least ``sbert_lo``, or none of these.
    """
    low_confidence = prob is not None and prob <= low_t
    if not low_confidence:
        return base_decision
    if is_short and lev1:
        return 'ASR error (low p + short+lev1)'
    if duration_ms is not None and duration_ms < 120:
        return 'ASR error (low p + very short)'
    if sbert_sim >= sbert_lo:
        return 'ASR error (low p + semantic)'
    return 'ASR error (low p)'
# =========================
# Pair + main classifiers (tightened)
# =========================
def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
                  bert_thresh=0.75, max_bert=0.85):
    """Label a (reference, hypothesis) word pair.

    Returns an 'ASR error (...)' label when the pair looks like a transcription
    slip of the same word, otherwise 'Memorization error'. Checks, in order:
    equal numeric value, short word one edit away, then semantic/phonetic
    evidence from ``bert_scores`` (keys ``sbert``, ``avg``, ``max``, ``note``).
    """
    # Same number written differently (digits vs. words, spelling variants).
    ref_num = to_numeric_value(ref_w)
    hyp_num = to_numeric_value(hyp_w)
    if ref_num is not None and hyp_num is not None and ref_num == hyp_num:
        return 'ASR error (numbers equal)'

    if short_word and lev1:
        return 'ASR error (short+lev1)'

    sbert_ok = bert_scores["sbert"] >= 0.80
    avg_ok = bert_scores["avg"] >= bert_thresh
    max_ok = (bert_scores["max"] > max_bert) and sbert_ok
    surface_close = phon_sim or lev1

    if bert_scores.get("note") == "models_disagree":
        # The two models disagree: rely on SBERT.
        # NOTE(review): the source lost its indentation; the "semantic" branch
        # is read here as the else of the phonetic/lev1 test — confirm.
        if surface_close:
            if sbert_ok and avg_ok:
                return 'ASR error (semantic/phonetic)'
        elif sbert_ok:
            return 'ASR error (semantic)'
    elif (surface_close and avg_ok) or max_ok:
        return 'ASR error (semantic/phonetic)'
    return 'Memorization error'
def classify_alignment_optimized(
    aligned, ref_tokens, hyp_tokens,
    bert_thresh=0.75, max_bert=0.85,
    asr_token_conf=None, low_high=None,
    replace_budget_tokens=None,  # replacement cap (None = unlimited)
    guard_note=None              # tag such as "off-topic"/"ok"/"budget_off"
):
    """Classify every aligned token and build the corrected transcript.

    Args:
        aligned: opcode dicts with ``type``, ``ref``, ``hyp`` and ``hyp_idx``.
        ref_tokens, hyp_tokens: the token lists (not read here; kept for the
            interface).
        bert_thresh, max_bert: thresholds forwarded to ``classify_pair``.
        asr_token_conf: ``{hyp_index: {"prob", "duration_ms"}}`` or None.
        low_high: ``(low_t, high_t)`` probability thresholds; when None they
            are the 15% / 70% quantiles of ``asr_token_conf`` probabilities,
            or 0.5 / 0.85 if there are none.
        replace_budget_tokens: maximum number of hypothesis words that may be
            replaced by the reference word; None means no limit.
        guard_note: free-form tag appended to each reason string.

    Returns:
        ``(results, corrected_text, stats)``: one result dict per token
        (``ASR_word``, ``GT_word``, ``status``, ``reason``, ``used``), the
        space-joined ``used`` words, and a dict of counters.
    """
    # Thresholds derived from the word probabilities.
    if low_high is not None:
        low_t, high_t = low_high
    else:
        probs = [v["prob"] for v in (asr_token_conf or {}).values() if v["prob"] is not None]
        if probs:
            low_t = float(np.quantile(probs, 0.15))
            high_t = float(np.quantile(probs, 0.70))
        else:
            low_t, high_t = 0.5, 0.85

    results, corrected_words = [], []
    replaced_count = 0
    for entry in aligned:
        tag = entry['type']
        j1, j2 = entry.get('hyp_idx', (None, None))

        if tag == 'equal':
            for ref_w, hyp_w in zip(entry['ref'], entry['hyp']):
                results.append({'ASR_word': hyp_w, 'GT_word': ref_w, 'status': 'Correct', 'reason': '', 'used': hyp_w})
                corrected_words.append(hyp_w)
            continue
        if tag not in ('replace', 'delete', 'insert'):
            continue

        ref_block, hyp_block = entry['ref'], entry['hyp']
        # Pair the k-th reference word with the k-th hypothesis word; the
        # shorter side is padded with ''.
        for k in range(max(len(ref_block), len(hyp_block))):
            ref_w = ref_block[k] if k < len(ref_block) else ''
            hyp_w = hyp_block[k] if k < len(hyp_block) else ''
            if not ref_w and not hyp_w:
                continue
            has_pair = bool(ref_w and hyp_w)

            # similarities (only meaningful when both words exist)
            phon_sim = phonetic_similarity(ref_w, hyp_w) if has_pair else False
            lev1 = is_levenshtein_1(ref_w, hyp_w) if has_pair else False
            bert_scores = multi_bert_similarity(ref_w, hyp_w) if has_pair else {"sbert":0,"marbert":0,"max":0,"avg":0}
            short_word = bool(has_pair and max(len(ref_w), len(hyp_w)) <= 6)

            # base status
            if has_pair:
                base_status = classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
                                            bert_thresh, max_bert)
            elif hyp_w == '':
                base_status = 'Missing (possible omission)'
            elif ref_w == '':
                base_status = 'Extra (possible ASR insertion)'
            else:
                base_status = 'Undefined Case'

            # Word-level confidence. Only looked up when this position really
            # has a hypothesis word: for padded positions j1 + k points past
            # the block, i.e. at some unrelated later token.
            word_prob = None
            word_dur = None
            if (k < len(hyp_block)) and (j1 is not None) and (j2 is not None):
                hyp_abs_idx = j1 + k
                if asr_token_conf and hyp_abs_idx in asr_token_conf:
                    word_prob = asr_token_conf[hyp_abs_idx].get("prob")
                    word_dur = asr_token_conf[hyp_abs_idx].get("duration_ms")

            final_status = base_status
            if has_pair:
                final_status = gate_by_word_conf(
                    base_decision=base_status, prob=word_prob,
                    sbert_sim=bert_scores["sbert"],
                    is_short=short_word, lev1=lev1,
                    duration_ms=word_dur,
                    low_t=low_t, high_t=high_t, sbert_lo=0.60
                )

            # Choose the emitted token, honouring the replacement budget.
            used = hyp_w
            budget_info = ""
            if has_pair:
                if final_status.startswith("ASR error"):
                    if (replace_budget_tokens is None) or (replaced_count < replace_budget_tokens):
                        used = ref_w
                        replaced_count += 1
                        if replace_budget_tokens is not None:
                            budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
                    else:
                        used = hyp_w
                        final_status += " [guard: budget reached]"
                        budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
                else:
                    used = hyp_w
            elif hyp_w == '':
                used = ''
            elif ref_w == '':
                used = hyp_w

            reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
                      f'SBERT={bert_scores["sbert"]:.2f}, '
                      f'MARBERT={bert_scores["marbert"]:.2f}, '
                      f'MAX={bert_scores["max"]:.2f}, '
                      f'AVG={bert_scores["avg"]:.2f}, short={short_word}, '
                      f'prob={None if word_prob is None else round(word_prob,2)}, '
                      f'dur_ms={None if word_dur is None else int(word_dur)}, '
                      f'low_t={round(low_t,2)}, high_t={round(high_t,2)}')
            if bert_scores.get("note"):
                reason += f", note={bert_scores['note']}"
            if guard_note:
                reason += f", guard='{guard_note}'"
            if budget_info:
                reason += budget_info

            results.append({
                'ASR_word': hyp_w, 'GT_word': ref_w,
                'status': final_status, 'reason': reason, 'used': used
            })
            if used:
                corrected_words.append(used)

    corrected_text = " ".join([w for w in corrected_words if w])

    # Local statistics for the report.
    stats = {
        "replacements_made": sum(1 for r in results
                                 if r.get("used") and r.get("GT_word") and r["used"] == r["GT_word"]
                                 and r.get("ASR_word") and r["ASR_word"] != r["GT_word"]),
        "budget_reached_count": sum(1 for r in results if isinstance(r.get("status"), str) and "budget reached" in r["status"]),
        "asr_error_count": sum(1 for r in results if isinstance(r.get("status"), str) and r["status"].startswith("ASR error")),
        "memorization_error_count": sum(1 for r in results if r.get("status") == "Memorization error"),
        "missing_count": sum(1 for r in results if r.get("status","").startswith("Missing")),
        "extra_count": sum(1 for r in results if r.get("status","").startswith("Extra")),
        "total_tokens": len(results)
    }
    return results, corrected_text, stats
# =========================
# ROUGE-L / WER-like / Guard
# =========================
def lcs_len(a, b):
    """Return the length of the longest common subsequence of ``a`` and ``b``."""
    width = len(b)
    # previous[c] = LCS length of the items of `a` seen so far vs. b[:c].
    previous = [0] * (width + 1)
    for item_a in a:
        current = [0]
        for col, item_b in enumerate(b, start=1):
            if item_a == item_b:
                current.append(previous[col - 1] + 1)
            else:
                current.append(max(previous[col], current[col - 1]))
        previous = current
    return previous[width]
def rouge_l_f1_tokens(ref_tokens, hyp_tokens, beta=1.2):
    """Return ``(f_score, precision, recall)`` of ROUGE-L over token lists.

    Precision and recall are the LCS length divided by the hypothesis and
    reference lengths; the score combines them with weight ``beta``. All three
    are 0.0 when either list is empty or there is no common subsequence.
    """
    nothing = (0.0, 0.0, 0.0)
    if not (ref_tokens and hyp_tokens):
        return nothing
    common = lcs_len(ref_tokens, hyp_tokens)
    prec = common / len(hyp_tokens)
    rec = common / len(ref_tokens)
    if prec == 0 and rec == 0:
        return nothing
    # The 1e-12 term keeps the denominator non-zero.
    f1 = ((1+beta**2) * prec * rec) / (rec + beta**2 * prec + 1e-12)
    return float(f1), float(prec), float(rec)
def compute_wer_like(aligned, ref_tokens_len):
    """Return a WER-style error rate from alignment opcodes.

    A 'replace' counts the longer of its two sides, a 'delete' its reference
    words and an 'insert' its hypothesis words; the sum is divided by the
    reference length (at least 1).
    """
    cost_of = {
        'replace': lambda op: max(len(op['ref']), len(op['hyp'])),
        'delete': lambda op: len(op['ref']),
        'insert': lambda op: len(op['hyp']),
    }
    errors = sum(cost_of[op['type']](op) for op in aligned if op['type'] in cost_of)
    return errors / max(ref_tokens_len, 1)
def global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, sbert_model):
    """Decide whether the recitation is off-topic and size the replacement budget.

    Combines whole-text SBERT similarity, ROUGE-L, the share of reference
    tokens in 'equal' opcodes and the WER-like rate. Returns a dict with
    ``off_topic`` (bool), ``budget_tokens`` (0 when off-topic, otherwise 15% or
    40% of the hypothesis length) and ``metrics`` (values rounded to 3 places).
    """
    ref_emb = sbert_model.encode(_normalize_for_models(original_text), convert_to_tensor=True)
    hyp_emb = sbert_model.encode(_normalize_for_models(asr_text), convert_to_tensor=True)
    sbert_sim_text = float(util.pytorch_cos_sim(ref_emb, hyp_emb))

    rouge_f1, rouge_p, rouge_r = rouge_l_f1_tokens(ref_tokens, hyp_tokens)

    matched = 0
    for op in aligned:
        if op['type'] == 'equal':
            matched += len(op['ref'])
    equal_ratio = matched / max(len(ref_tokens), 1)
    wer = compute_wer_like(aligned, len(ref_tokens))

    # Off-topic: all three overlap signals are weak, or the error rate is high.
    weak_overlap = sbert_sim_text < 0.70 and rouge_f1 < 0.45 and equal_ratio < 0.25
    off_topic = (weak_overlap or (wer > 0.65))

    hyp_len = len(hyp_tokens)
    if off_topic:
        budget = 0
    elif sbert_sim_text < 0.80 or rouge_f1 < 0.55:
        budget = int(0.15 * hyp_len)
    else:
        budget = int(0.40 * hyp_len)

    metrics = {
        "sbert_sim_text": round(sbert_sim_text, 3),
        "rougeL_f1": round(rouge_f1, 3),
        "rougeL_prec": round(rouge_p, 3),
        "rougeL_rec": round(rouge_r, 3),
        "equal_ratio": round(equal_ratio, 3),
        "wer_like": round(wer, 3),
    }
    print(f"[GUARD] off_topic={off_topic}, budget={budget}, metrics={metrics}", flush=True)
    return {"off_topic": off_topic, "budget_tokens": budget, "metrics": metrics}
# =========================
# Scores
# =========================
def literal_similarity(original, recited):
    """Score how literally ``recited`` matches ``original``.

    Returns a dict (values rounded to 3 places) with the normalised
    Levenshtein similarity, the share of reference positions holding the same
    word in both texts, unigram BLEU (0.0 if NLTK's BLEU is unavailable or
    fails) and ``literal_score`` = 0.5*lev + 0.3*overlap + 0.2*bleu1.
    """
    def _strip_marks(text):
        text = re.sub(r'[ًٌٍَُِّْـ]', '', text)
        text = re.sub(r'[“”",:؛؟.!()\[\]{}،\-–—_]', ' ', text)
        return re.sub(r'\s+', ' ', text).strip()

    orig_norm = _strip_marks(original)
    rec_norm = _strip_marks(recited)
    lev = textdistance.levenshtein.normalized_similarity(orig_norm, rec_norm)

    orig_toks = simple_tokenize(orig_norm)
    rec_toks = simple_tokenize(rec_norm)
    # Position-wise comparison: words count only when at the same index.
    same_position = 0
    for left, right in zip(orig_toks, rec_toks):
        if left == right:
            same_position += 1
    word_overlap = same_position / max(len(orig_toks), 1)

    bleu1 = 0.0
    try:
        import nltk.translate.bleu_score as bleu
        if orig_toks and rec_toks:
            bleu1 = bleu.sentence_bleu([orig_toks], rec_toks, weights=(1,0,0,0))
    except Exception:
        bleu1 = 0.0

    final_score = 0.5*lev + 0.3*word_overlap + 0.2*bleu1
    return {"levenshtein": round(lev,3), "word_overlap": round(word_overlap,3),
            "bleu1": round(bleu1,3), "literal_score": round(final_score,3)}
def semantic_similarity(original, recited, use_marbert=FORCE_USE_MARBERT):
    """Return whole-text semantic scores, rounded to 3 places.

    ``sbert_sim`` is the SBERT cosine similarity of the normalised texts,
    ``marbert_sim`` the MARBERT similarity (0.0 when ``use_marbert`` is
    falsy) and ``semantic_score`` the larger of the two.
    """
    emb_original = _SBERT.encode(_normalize_for_models(original), convert_to_tensor=True)
    emb_recited = _SBERT.encode(_normalize_for_models(recited), convert_to_tensor=True)
    sbert_sim = float(util.pytorch_cos_sim(emb_original, emb_recited))
    if use_marbert:
        marbert_sim = marbert_cls_similarity(original, recited)
    else:
        marbert_sim = 0.0
    best = max(sbert_sim, marbert_sim)
    return {"sbert_sim": round(sbert_sim,3), "marbert_sim": round(marbert_sim,3),
            "semantic_score": round(best,3)}
# =========================
# Audio helper
# =========================
def ensure_audio_path(audio):
    """Return a filesystem path for the given audio input.

    Accepts either a path string, which must exist, or a 2-tuple holding a
    sample rate and a numpy array. Gradio's numpy audio format is
    ``(sample_rate, data)``; the reverse ``(data, sample_rate)`` order is
    accepted as well. Array input is written to a new temporary ``.wav`` file
    that the caller is responsible for deleting.

    Raises:
        FileNotFoundError: the path string does not exist.
        ValueError: ``audio`` is None or in any other unsupported form.
    """
    if audio is None:
        raise ValueError("Unsupported audio input format: no audio provided")
    if isinstance(audio, str):
        if not os.path.exists(audio):
            raise FileNotFoundError(f"Audio path not found: {audio}")
        return audio
    if isinstance(audio, tuple) and len(audio) == 2:
        first, second = audio
        if isinstance(first, np.ndarray):
            data, sr = first, second
        elif isinstance(second, np.ndarray):
            sr, data = first, second
        else:
            data, sr = None, None
        if data is not None:
            # mkstemp + close: no handle is left open while soundfile writes.
            fd, tmp_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            try:
                sf.write(tmp_path, data, int(sr))
            except Exception:
                os.remove(tmp_path)  # do not leave a broken temp file behind
                raise
            return tmp_path
    raise ValueError("Unsupported audio input format")
# =========================
# Pipeline (robust errors + logs)
# =========================
def transcribe_and_evaluate(audio, original_text, whisper_size=None,
                            compute_type=None, vad=True, use_marbert=True):
    """Transcribe ``audio``, compare it with ``original_text`` and report.

    ``whisper_size``, ``compute_type`` and ``use_marbert`` are recorded as
    "requested" in the report but replaced by the forced module settings.
    ``vad`` is accepted for the interface but not used: transcription always
    runs with ``ASR_OPTS``.

    Returns:
        ``(corrected_text, asr_text, report_json, df)``. On any error the
        texts are empty, ``report_json`` holds the error and traceback, and
        ``df`` is a single "ERROR" row; the exception is not propagated.
    """
    tmp_audio_path = None
    try:
        if not original_text or not original_text.strip():
            raise ValueError("Original text is empty.")

        # Remember what the caller asked for before forcing the settings, so
        # the report's "requested" and "effective" entries can differ.
        requested = {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert}

        # Forced settings
        whisper_size = FORCE_WHISPER_NAME
        compute_type = FORCE_COMPUTE_TYPE
        use_marbert = FORCE_USE_MARBERT
        print(f"[RUN] whisper={whisper_size}, compute={compute_type}, marbert={use_marbert}", flush=True)
        load_models(whisper_name=whisper_size, whisper_compute=compute_type, use_marbert=use_marbert)

        audio_path = ensure_audio_path(audio)
        if not isinstance(audio, str):
            # Array input was written to a temp file; it is removed below.
            tmp_audio_path = audio_path
        print(f"[AUDIO] path={audio_path}", flush=True)

        segments, info = _WHISPER.transcribe(audio_path, **ASR_OPTS)
        segments = list(segments)
        print(f"[ASR] segments={len(segments)}", flush=True)

        # Build ASR text from words
        words = []
        for seg in segments:
            for w in (seg.words or []):
                tok = clean_ar_token(w.word)
                if tok:
                    words.append(tok)
        asr_text = " ".join(words)

        # Tokens & alignment
        ref_tokens = simple_tokenize(original_text)
        hyp_tokens = simple_tokenize(asr_text)
        aligned = align_texts(ref_tokens, hyp_tokens)

        # Guard & budget
        guard = global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, _SBERT)
        off_topic = guard["off_topic"]
        guard_metrics = guard["metrics"]
        if FORCE_BUDGET_MODE == "off":
            budget_tokens = None
            guard_note = "budget_off"
        elif FORCE_BUDGET_MODE == "fixed":
            budget_tokens = int(FIXED_BUDGET_TOKENS)
            guard_note = f"budget_fixed_{budget_tokens}"
        elif FORCE_BUDGET_MODE == "ratio":
            budget_tokens = int(BUDGET_RATIO * len(hyp_tokens))
            guard_note = f"budget_ratio_{BUDGET_RATIO}"
        else:
            budget_tokens = guard["budget_tokens"]
            guard_note = "off-topic" if off_topic else "ok"
        print(f"[BUDGET] mode={FORCE_BUDGET_MODE}, budget={budget_tokens}, note={guard_note}", flush=True)

        # Word-level confidences
        df_words = extract_word_conf_table(segments)
        asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)
        print(f"[CONF] low_t={low_t:.3f}, high_t={high_t:.3f}", flush=True)

        # Classification
        results, corrected_text, local_stats = classify_alignment_optimized(
            aligned, ref_tokens, hyp_tokens,
            bert_thresh=0.75, max_bert=0.85,
            asr_token_conf=asr_token_conf, low_high=(low_t, high_t),
            replace_budget_tokens=budget_tokens,
            guard_note=guard_note
        )

        # Scores
        lit = literal_similarity(original_text, corrected_text)
        sem = semantic_similarity(original_text, corrected_text, use_marbert=use_marbert)

        # Extra global metrics for report. When nothing was recognised the
        # word table may have no "prob" column at all.
        if "prob" in df_words.columns:
            all_probs = df_words["prob"].dropna().tolist()
        else:
            all_probs = []
        conf_summary = {
            "num_words_with_prob": int(len(all_probs)),
            "avg_prob": None if not all_probs else float(np.mean(all_probs)),
            "p15": None if not all_probs else float(np.quantile(all_probs, 0.15)),
            "p70": None if not all_probs else float(np.quantile(all_probs, 0.70)),
        }

        df = pd.DataFrame(results)
        report = {
            "requested": requested,
            "effective": {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert},
            "guard": {"mode": FORCE_BUDGET_MODE, "off_topic": off_topic, "budget_tokens": None if budget_tokens is None else int(budget_tokens), **guard_metrics},
            "local_stats": local_stats,
            "confidence_summary": conf_summary,
            "original_text": original_text,
            "asr_text": asr_text,
            "corrected_text": corrected_text,
            "literal": lit,
            "semantic": sem,
            "low_t": float(low_t), "high_t": float(high_t),
        }
        return corrected_text, asr_text, json.dumps(report, ensure_ascii=False, indent=2), df
    except Exception as e:
        # Top-level boundary: log, surface a warning, return an error payload.
        tb = traceback.format_exc()
        print("ERROR in transcribe_and_evaluate:\n", tb, flush=True)
        empty_df = pd.DataFrame([{"ASR_word":"","GT_word":"","status":"ERROR","reason":str(e),"used":""}])
        err_json = json.dumps({"error": str(e), "traceback": tb}, ensure_ascii=False, indent=2)
        gr.Warning(str(e))
        return "", "", err_json, empty_df
    finally:
        if tmp_audio_path is not None:
            try:
                os.remove(tmp_audio_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not fail the run.
                pass
def api_predict(audio, original_text, whisper_size=None, compute_type=None, vad=True, use_marbert=True):
    """Run the evaluation and return only the report, parsed into a dict.

    If the report JSON cannot be parsed, a dict with a single ``error`` key is
    returned instead.
    """
    _corrected, _asr, report_json, _table = transcribe_and_evaluate(
        audio, original_text, whisper_size, compute_type, vad, use_marbert
    )
    try:
        parsed = json.loads(report_json)
    except Exception:
        parsed = {"error": "Failed to parse report_json."}
    return parsed
# =========================
# Gradio UI
# =========================
def build_ui():
    """Build and return the Gradio Blocks app.

    Two endpoints are registered: "evaluate" (the visible button, returning
    corrected text, raw transcript, report and token table) and "predict"
    (a hidden button bound to ``api_predict``, returning the report only).
    """
    # NOTE(review): the source lost its indentation; which components sit
    # inside each gr.Row() is reconstructed from the statement order — confirm.
    with gr.Blocks(title="Samaali ASR Post-Processing", theme=gr.themes.Soft()) as demo:
        gr.Markdown("## Samaali — ASR Post-Processing (Whisper + Alignment + Confidence + Semantics)")
        # Inputs: audio (delivered to the callback as a file path) and reference text.
        with gr.Row():
            audio = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio")
            original = gr.Textbox(lines=8, label="Original Text (Ground Truth)")
        # Settings row; the dropdowns offer only the forced values.
        with gr.Row():
            whisper_size = gr.Dropdown(choices=["large-v3"], value="large-v3", label="Whisper model size (forced)")
            compute_type = gr.Dropdown(choices=["int8"], value="int8", label="compute_type (forced)")
            vad = gr.Checkbox(value=True, label="VAD filter")
            use_marbert = gr.Checkbox(value=True, label="Use MARBERT (forced)")
        btn = gr.Button("Transcribe & Evaluate", variant="primary")
        # Outputs, in the order returned by transcribe_and_evaluate().
        corrected = gr.Textbox(lines=6, label="Corrected Transcript (ASR errors restored)")
        asr_out = gr.Textbox(lines=6, label="Raw ASR Transcript")
        report = gr.JSON(label="Report (scores & thresholds)")
        table = gr.Dataframe(headers=["ASR_word","GT_word","status","reason","used"],
                             label="Token-level Decisions", wrap=True)
        btn.click(
            fn=transcribe_and_evaluate,
            inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
            outputs=[corrected, asr_out, report, table],
            api_name="evaluate"
        )
        # Hidden button: exists only to expose api_predict under api_name "predict".
        gr.Button(visible=False).click(
            fn=api_predict,
            inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
            outputs=gr.JSON(),
            api_name="predict"
        )
    return demo
# Script entry point: build the UI and start the Gradio server.
if __name__ == "__main__":
    demo = build_ui()
    demo.launch()