|
|
import os, re, json, math, tempfile, traceback |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
import textdistance |
|
|
|
|
|
import gradio as gr |
|
|
from faster_whisper import WhisperModel |
|
|
|
|
|
from sentence_transformers import SentenceTransformer, util |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
import soundfile as sf |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Forced model configuration ---------------------------------------------
# transcribe_and_evaluate overwrites the caller-supplied whisper/compute/MARBERT
# options with these values, and load_models uses them as its defaults.
FORCE_WHISPER_NAME = "large-v3"   # model name handed to WhisperModel
FORCE_COMPUTE_TYPE = "int8"       # compute_type handed to WhisperModel
FORCE_USE_MARBERT = True          # whether MARBERT is loaded and scored
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Replacement budget policy (read by transcribe_and_evaluate) -------------
#   "off"   -> no cap on replacements
#   "fixed" -> cap is FIXED_BUDGET_TOKENS
#   "ratio" -> cap is int(BUDGET_RATIO * number of hypothesis tokens)
#   other   -> (e.g. "auto") cap comes from global_offtopic_guard
FORCE_BUDGET_MODE = "auto"
FIXED_BUDGET_TOKENS = 0
BUDGET_RATIO = 0.15
|
|
|
|
|
|
|
|
|
|
|
# Keyword arguments forwarded verbatim to WhisperModel.transcribe().
ASR_OPTS = {
    "word_timestamps": True,   # per-word timing/probability is consumed downstream
    "vad_filter": True,
    "vad_parameters": {"min_silence_duration_ms": 200},
    "beam_size": 5,
    "best_of": 5,
    "temperature": 0.0,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Use the GPU when torch reports one, otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"[INIT] DEVICE={DEVICE}", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Lazily-populated model singletons; load_models() fills them on first use.
_SBERT = _MARBERT_TOK = _MARBERT = _WHISPER = None
|
|
|
|
|
def load_models(
    sbert_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    marbert_name="UBC-NLP/MARBERT",
    whisper_name=FORCE_WHISPER_NAME,
    whisper_compute=FORCE_COMPUTE_TYPE,
    use_marbert=FORCE_USE_MARBERT
):
    """Populate the module-level model singletons on first use.

    Each model is created only while its global is still ``None``, so a later
    call with different names does not reload an already-loaded model.
    MARBERT is skipped entirely when ``use_marbert`` is falsy.
    """
    global _SBERT, _MARBERT_TOK, _MARBERT, _WHISPER

    target = "cuda" if DEVICE == "cuda" else "cpu"

    if _SBERT is None:
        _SBERT = SentenceTransformer(sbert_name, device=target)
        print(f"[LOAD] SBERT: {sbert_name}", flush=True)

    if _MARBERT is None and use_marbert:
        _MARBERT_TOK = AutoTokenizer.from_pretrained(marbert_name)
        _MARBERT = AutoModel.from_pretrained(marbert_name).to(target)
        _MARBERT.eval()
        print(f"[LOAD] MARBERT: {marbert_name} (device={DEVICE})", flush=True)

    if _WHISPER is None:
        _WHISPER = WhisperModel(whisper_name, device=target, compute_type=whisper_compute)
        print(f"[LOAD] Whisper: {whisper_name} (compute={whisper_compute})", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_ar_orth(text: str) -> str:
    """Strip Arabic diacritics/tatweel, blank out punctuation, squeeze spaces.

    Diacritics and the tatweel are deleted; the listed punctuation marks are
    replaced by a space; runs of whitespace collapse to one space and the
    result is trimmed. Letter forms themselves are left untouched.
    """
    without_marks = re.sub(r"[ًٌٍَُِّْـ]", "", text)
    without_punct = re.sub(r"[“”\"',:؛؟.!()\[\]{}،\-–—_]", " ", without_marks)
    return re.sub(r"\s+", " ", without_punct).strip()
|
|
|
|
|
def _normalize_for_models(s: str) -> str: |
|
|
|
|
|
s = re.sub(r"[ًٌٍَُِّْـ]", "", s) |
|
|
s = re.sub(r"[“”\"',:؛؟.!()\[\]{}،\-–—_]", " ", s) |
|
|
s = re.sub(r"\s+", " ", s).strip() |
|
|
return s |
|
|
|
|
|
def simple_tokenize(text: str):
    """Tokenize normalized text with NLTK, falling back to ``str.split``.

    The text is first passed through ``normalize_ar_orth``. If NLTK cannot be
    imported, or tokenizing raises for any reason, whitespace splitting of the
    normalized text is returned instead.
    """
    normalized = normalize_ar_orth(text)
    try:
        import nltk

        # Fetch the punkt resource only when it is not already installed.
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt', quiet=True)
        tokens = nltk.word_tokenize(normalized)
    except Exception:
        tokens = normalized.split()
    return tokens
|
|
|
|
|
def align_texts(ref_tokens, hyp_tokens):
    """Align two token lists with ``difflib.SequenceMatcher``.

    Returns one dict per opcode with keys ``type`` (the opcode tag), ``ref`` /
    ``hyp`` (the token slices) and ``ref_idx`` / ``hyp_idx`` (the half-open
    index ranges of those slices).
    """
    from difflib import SequenceMatcher

    matcher = SequenceMatcher(None, ref_tokens, hyp_tokens)
    return [
        {
            'type': op,
            'ref': ref_tokens[r_lo:r_hi],
            'hyp': hyp_tokens[h_lo:h_hi],
            'ref_idx': (r_lo, r_hi),
            'hyp_idx': (h_lo, h_hi),
        }
        for op, r_lo, r_hi, h_lo, h_hi in matcher.get_opcodes()
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def arabic_soundex(word):
    """Map a word to a string of coarse sound-class codes.

    The word is normalized with ``normalize_ar_orth``; each character that
    belongs to a group contributes that group's Latin code, and characters in
    no group are skipped.
    """
    groups = {
        'b': 'بف', 'j': 'جشص', 'd': 'دض', 't': 'طت', 'q': 'قغ', 'k': 'كخ',
        's': 'سصز', 'z': 'ثذظ', 'h': 'ح', 'a': 'ع', 'm': 'م', 'n': 'ن',
        'l': 'ل', 'r': 'ر', 'w': 'و', 'y': 'ي'
    }
    # First group listed wins for a letter that appears in more than one.
    # NOTE(review): 'ص' is in both 'j' and 's', so it always codes as 'j' and
    # its 's' entry is dead — confirm which class was intended.
    letter_to_code = {}
    for code, letters in groups.items():
        for letter in letters:
            letter_to_code.setdefault(letter, code)

    normalized = normalize_ar_orth(word)
    return "".join(letter_to_code[ch] for ch in normalized if ch in letter_to_code)
|
|
|
|
|
def phonetic_similarity(w1, w2):
    """Return True when both words share the same non-empty soundex code.

    Returns False when either word is empty/None. It also returns False when
    a word's code is empty (none of its letters belongs to a sound group):
    two empty codes would otherwise compare equal and report unrelated words
    as phonetically similar.
    """
    if not w1 or not w2:
        return False
    code1 = arabic_soundex(w1)
    if not code1:
        # No coded letters: there is no phonetic evidence to compare.
        return False
    return code1 == arabic_soundex(w2)
|
|
|
|
|
def is_levenshtein_1(w1, w2):
    """Return True when the two words are exactly one edit apart.

    Empty or ``None`` input on either side yields False without computing
    a distance.
    """
    if w1 and w2:
        return textdistance.levenshtein(w1, w2) == 1
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Translation table: Arabic-Indic digits -> ASCII digits.
AR_DIGITS = str.maketrans("٠١٢٣٤٥٦٧٨٩", "0123456789")

# Number-word lexicons. Keys are matched verbatim against tokens, so every
# spelling variant (hamza on/off, final ه vs ة, nominative vs oblique case)
# needs its own entry. The ة forms and oblique tens were only partly listed
# before; they are now filled in alongside the existing entries.
UNITS = {
    "صفر": 0,
    "واحد": 1, "واحدة": 1,
    "اثنان": 2, "اثنين": 2, "اثنتان": 2, "اثنتين": 2,
    "ثلاث": 3, "ثلاثه": 3, "ثلاثة": 3,
    "اربع": 4, "اربعه": 4, "اربعة": 4, "أربع": 4, "أربعه": 4, "أربعة": 4,
    "خمس": 5, "خمسه": 5, "خمسة": 5,
    "ست": 6, "سته": 6, "ستة": 6,
    "سبع": 7, "سبعه": 7, "سبعة": 7,
    "ثمان": 8, "ثماني": 8, "ثمانيه": 8, "ثمانية": 8,
    "تسع": 9, "تسعه": 9, "تسعة": 9,
}
TENS = {
    "عشر": 10, "عشرة": 10, "عشره": 10,
    "عشرون": 20, "عشرين": 20,
    "ثلاثون": 30, "ثلاثين": 30,
    "اربعون": 40, "أربعون": 40, "اربعين": 40, "أربعين": 40,
    "خمسون": 50, "خمسين": 50,
    "ستون": 60, "ستين": 60,
    "سبعون": 70, "سبعين": 70,
    "ثمانون": 80, "ثمانين": 80,
    "تسعون": 90, "تسعين": 90,
}
HUND = {"مئه": 100, "مئة": 100, "مائه": 100, "مائة": 100}
SCALE = {"الف": 1000, "ألف": 1000, "آلاف": 1000, "مليون": 10**6, "مليار": 10**9}
|
|
|
|
|
def normalize_digits(s: str) -> str:
    """Return ``s`` with Arabic-Indic digits replaced by ASCII digits."""
    ascii_form = str.translate(s, AR_DIGITS)
    return ascii_form
|
|
|
|
|
def words_to_number(tokens):
    """Convert a sequence of Arabic number words to an integer.

    Units and tens are added to the running group. A hundred word multiplies
    a directly preceding unit ("ثلاث مئة" -> 300) and is added otherwise
    ("مئة وخمسة" -> 105). A scale word (thousand/million/...) multiplies the
    running group, treated as 1 when empty, and flushes it into the total.
    The bare conjunction "و" is skipped; any other word flushes the group.

    Returns:
        The integer value, or ``None`` when no number word was seen. An
        explicit zero word yields 0, not ``None``, so it stays comparable
        with the digit "0".
    """
    total = 0
    current = 0
    seen_number = False        # separates "no number words" from a real zero
    after_conjunction = False  # "X و مئة" means X + 100, never X * 100

    for w in tokens:
        if w == "و":
            after_conjunction = True
            continue

        if w in UNITS:
            current += UNITS[w]
        elif w in TENS:
            current += TENS[w]
        elif w in HUND:
            if 0 < current < 10 and not after_conjunction:
                current *= HUND[w]
            else:
                current += HUND[w]
        elif w in SCALE:
            current = max(1, current) * SCALE[w]
            total += current
            current = 0
        else:
            # Non-number word: close the running group and move on.
            total += current
            current = 0
            after_conjunction = False
            continue

        seen_number = True
        after_conjunction = False

    total += current
    return total if seen_number else None
|
|
|
|
|
def to_numeric_value(token: str):
    """Interpret a token as a number, from digits or from number words.

    Returns ``None`` for empty input. After ``normalize_ar_orth``, a token
    consisting solely of digits (Arabic-Indic digits are mapped to ASCII
    first) is returned as ``int``; anything else is split on whitespace and
    handed to ``words_to_number``.
    """
    if not token:
        return None
    normalized = normalize_ar_orth(token)
    digits = normalize_digits(normalized)
    if re.fullmatch(r"\d+", digits) is None:
        return words_to_number(normalized.split())
    return int(digits)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _mean_pool(last_hidden_state, attention_mask): |
|
|
mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float() |
|
|
summed = (last_hidden_state * mask).sum(dim=1) |
|
|
counts = mask.sum(dim=1).clamp(min=1e-9) |
|
|
return summed / counts |
|
|
|
|
|
def marbert_cls_similarity(a: str, b: str) -> float:
    """Score two texts with mean-pooled MARBERT embeddings.

    Returns 0.0 when either text is empty, when MARBERT is not loaded, when
    either text tokenizes to nothing, or when more than half the tokens of
    either text are the tokenizer's unknown token. Otherwise the cosine
    similarity of the two pooled embeddings is returned as ``(sim + 1) / 2``.
    Despite the name, mean pooling is used rather than the CLS vector.
    """
    if not a or not b or _MARBERT is None:
        return 0.0

    texts = (_normalize_for_models(a), _normalize_for_models(b))

    # Tokenize without special tokens just to measure [UNK] coverage.
    id_lists = [_MARBERT_TOK(t, add_special_tokens=False).input_ids for t in texts]
    unk_id = _MARBERT_TOK.unk_token_id
    if any(len(ids) == 0 for ids in id_lists):
        return 0.0

    if unk_id is not None:
        worst_unk_ratio = max(ids.count(unk_id) / len(ids) for ids in id_lists)
    else:
        worst_unk_ratio = 0.0
    if worst_unk_ratio > 0.5:
        # Mostly unknown pieces: the embedding would be meaningless.
        return 0.0

    target = "cuda" if DEVICE == "cuda" else "cpu"
    with torch.no_grad():
        pooled = []
        for t in texts:
            enc = _MARBERT_TOK(t, return_tensors='pt', truncation=True, padding=True).to(target)
            pooled.append(_mean_pool(_MARBERT(**enc).last_hidden_state, enc["attention_mask"]))
        sim = util.cos_sim(pooled[0], pooled[1]).item()
    return (sim + 1) / 2
|
|
|
|
|
def multi_bert_similarity(a: str, b: str):
    """Score two strings with SBERT and, when loaded, MARBERT.

    Returns a dict with keys ``sbert``, ``marbert``, ``max``, ``avg`` and
    ``note``. ``note`` is ``"empty"`` when either input is empty,
    ``"models_disagree"`` when the two scores differ by more than 0.35, and
    ``None`` otherwise.

    When MARBERT is not loaded its score is reported as 0.0 but left out of
    ``max``/``avg`` and the disagreement check. Counting a placeholder 0.0
    would halve the average and mark nearly every pair as a disagreement.
    """
    if not a or not b:
        return {"sbert": 0.0, "marbert": 0.0, "max": 0.0, "avg": 0.0, "note": "empty"}

    a_n = _normalize_for_models(a)
    b_n = _normalize_for_models(b)
    sbert_sim = float(util.pytorch_cos_sim(
        _SBERT.encode(a_n, convert_to_tensor=True),
        _SBERT.encode(b_n, convert_to_tensor=True)
    ))

    if _MARBERT is None:
        # Only one model is available, so there is nothing to disagree with.
        return {"sbert": sbert_sim, "marbert": 0.0,
                "max": sbert_sim, "avg": sbert_sim, "note": None}

    marbert_sim = marbert_cls_similarity(a_n, b_n)
    note = "models_disagree" if abs(sbert_sim - marbert_sim) > 0.35 else None

    vals = [sbert_sim, marbert_sim]
    return {"sbert": sbert_sim, "marbert": marbert_sim,
            "max": max(vals), "avg": sum(vals) / len(vals), "note": note}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_ar_token(t: str) -> str:
    """Trim a raw ASR word and normalize it.

    Surrounding whitespace is stripped, then leading/trailing characters that
    are neither word characters nor in U+0600–U+06FF are removed, and the
    remainder goes through ``normalize_ar_orth``.
    """
    trimmed = re.sub(r'^[^\w\u0600-\u06FF]+|[^\w\u0600-\u06FF]+$', '', t.strip())
    return normalize_ar_orth(trimmed)
|
|
|
|
|
def extract_word_conf_table(segments):
    """Flatten per-word timing and probability from ASR segments into a table.

    Each word of each segment becomes one row with the segment bounds, the
    word bounds, the cleaned word text and its probability. Segments whose
    ``words`` is ``None`` or empty contribute nothing.

    The column set is fixed even when there are no rows, so callers can
    index ``df["prob"]`` without a KeyError when no words were recognized.
    """
    columns = ["seg_start", "seg_end", "word_start", "word_end", "word", "prob"]
    rows = [
        {
            "seg_start": float(seg.start),
            "seg_end": float(seg.end),
            "word_start": float(w.start),
            "word_end": float(w.end),
            "word": clean_ar_token(w.word),
            "prob": float(w.probability),
        }
        for seg in segments
        for w in (seg.words or [])
    ]
    return pd.DataFrame(rows, columns=columns)
|
|
|
|
|
def build_asr_token_conf(df_words: pd.DataFrame, hyp_tokens: list):
    """Map each hypothesis token index to its ASR probability and duration.

    Rows of ``df_words`` are expanded to one confidence entry per
    whitespace-separated piece of the row's ``word``. A row whose cleaned
    word is empty therefore contributes nothing, and a row holding two
    pieces contributes two entries sharing that row's probability and
    duration. This keeps the positions in step with ``hyp_tokens``; a plain
    row-per-token mapping shifts every later token by one in those cases.
    Without a ``word`` column, each row yields exactly one entry.

    The list is truncated or padded with ``None`` to ``len(hyp_tokens)``.

    Returns:
        ``(asr_token_conf, low_t, high_t)``: ``asr_token_conf`` maps index to
        ``{"prob", "duration_ms"}``; the thresholds are the 15% and 70%
        quantiles of the known probabilities, or ``(0.5, 0.85)`` if none.
    """
    toks_probs, toks_durs = [], []
    if len(df_words):
        has_words = "word" in df_words.columns
        words = df_words["word"] if has_words else [None] * len(df_words)
        for word, prob, start, end in zip(
            words, df_words["prob"], df_words["word_start"], df_words["word_end"]
        ):
            dur = (end - start) * 1000.0
            pieces = len(str(word).split()) if has_words else 1
            toks_probs.extend([prob] * pieces)
            toks_durs.extend([dur] * pieces)

    L = len(hyp_tokens)
    if len(toks_probs) >= L:
        toks_probs = toks_probs[:L]
        toks_durs = toks_durs[:L]
    else:
        pad = L - len(toks_probs)
        toks_probs += [None] * pad
        toks_durs += [None] * pad

    arr = np.array([p for p in toks_probs if p is not None])
    if arr.size:
        low_t = float(np.quantile(arr, 0.15))
        high_t = float(np.quantile(arr, 0.70))
    else:
        low_t, high_t = 0.5, 0.85

    asr_token_conf = {
        i: {"prob": p, "duration_ms": d}
        for i, (p, d) in enumerate(zip(toks_probs, toks_durs))
    }
    return asr_token_conf, low_t, high_t
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gate_by_word_conf(base_decision: str, prob: float, sbert_sim: float,
                      is_short: bool, lev1: bool, duration_ms: float = None,
                      low_t: float = 0.6, high_t: float = 0.9, sbert_lo=0.60):
    """Override a base decision when the ASR word probability is low.

    Only a known probability at or below ``low_t`` changes anything; the
    word is then labelled an ASR error, with the most specific matching
    reason (short + one edit, very short duration under 120 ms, semantic
    similarity of at least ``sbert_lo``, or plain low probability). In every
    other case ``base_decision`` is returned unchanged; ``high_t`` does not
    affect the result.
    """
    if prob is None or not prob <= low_t:
        return base_decision

    if is_short and lev1:
        return 'ASR error (low p + short+lev1)'
    if duration_ms is not None and duration_ms < 120:
        return 'ASR error (low p + very short)'
    if sbert_sim >= sbert_lo:
        return 'ASR error (low p + semantic)'
    return 'ASR error (low p)'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
                  bert_thresh=0.75, max_bert=0.85):
    """Label a substituted word pair as an ASR error or a memorization error.

    Checks run in order:
      1. both words parse to the same number -> 'ASR error (numbers equal)';
      2. short word exactly one edit away -> 'ASR error (short+lev1)';
      3. embedding evidence: when the models agree, a surface match
         (phonetic or one edit) with a high average, or a high max score
         with SBERT >= 0.80, counts as an ASR error; when they disagree,
         SBERT >= 0.80 is always required, plus the average for surface
         matches.
    Anything left over is a 'Memorization error'.
    """
    ref_num = to_numeric_value(ref_w)
    hyp_num = to_numeric_value(hyp_w)
    if ref_num is not None and hyp_num is not None and ref_num == hyp_num:
        return 'ASR error (numbers equal)'

    if short_word and lev1:
        return 'ASR error (short+lev1)'

    sbert_ok = bert_scores["sbert"] >= 0.80
    avg_ok = bert_scores["avg"] >= bert_thresh
    max_ok = (bert_scores["max"] > max_bert) and sbert_ok
    surface_match = phon_sim or lev1

    if bert_scores.get("note") != "models_disagree":
        if (surface_match and avg_ok) or max_ok:
            return 'ASR error (semantic/phonetic)'
    elif surface_match:
        if sbert_ok and avg_ok:
            return 'ASR error (semantic/phonetic)'
    elif sbert_ok:
        return 'ASR error (semantic)'

    return 'Memorization error'
|
|
|
|
|
def _resolve_conf_thresholds(asr_token_conf, low_high):
    """Return ``(low_t, high_t)``: the caller's pair, else quantiles, else defaults."""
    if low_high is not None:
        low_t, high_t = low_high
        return low_t, high_t
    probs = [v["prob"] for v in (asr_token_conf or {}).values() if v["prob"] is not None]
    if probs:
        return float(np.quantile(probs, 0.15)), float(np.quantile(probs, 0.70))
    return 0.5, 0.85


def _summarize_alignment(results):
    """Count decision categories over the token-level result rows."""
    return {
        "replacements_made": sum(1 for r in results
                                 if r.get("used") and r.get("GT_word") and r["used"] == r["GT_word"]
                                 and r.get("ASR_word") and r["ASR_word"] != r["GT_word"]),
        "budget_reached_count": sum(1 for r in results if isinstance(r.get("status"), str) and "budget reached" in r["status"]),
        "asr_error_count": sum(1 for r in results if isinstance(r.get("status"), str) and r["status"].startswith("ASR error")),
        "memorization_error_count": sum(1 for r in results if r.get("status") == "Memorization error"),
        "missing_count": sum(1 for r in results if r.get("status", "").startswith("Missing")),
        "extra_count": sum(1 for r in results if r.get("status", "").startswith("Extra")),
        "total_tokens": len(results)
    }


def classify_alignment_optimized(
    aligned, ref_tokens, hyp_tokens,
    bert_thresh=0.75, max_bert=0.85,
    asr_token_conf=None, low_high=None,
    replace_budget_tokens=None,
    guard_note=None
):
    """Classify every aligned token and build the corrected transcript.

    Args:
        aligned: opcode dicts as produced by ``align_texts``.
        ref_tokens, hyp_tokens: the token lists (not read here; kept for
            interface compatibility).
        bert_thresh, max_bert: thresholds forwarded to ``classify_pair``.
        asr_token_conf: optional ``{hyp_index: {"prob", "duration_ms"}}``.
        low_high: optional ``(low_t, high_t)``; when omitted, thresholds are
            the 15%/70% quantiles of the known probabilities, or
            ``(0.5, 0.85)`` if none.
        replace_budget_tokens: maximum number of hypothesis words that may be
            replaced by the reference word; ``None`` means unlimited.
        guard_note: free text appended to each non-equal row's reason.

    Returns:
        ``(results, corrected_text, stats)``. ``results`` has one dict per
        token (``ASR_word``, ``GT_word``, ``status``, ``reason``, ``used``);
        ``corrected_text`` joins the non-empty ``used`` words; ``stats`` holds
        category counts.

    Word confidence is attached only where a hypothesis token exists at that
    offset. Rows for missing words, and the surplus reference side of an
    uneven replace, report ``prob=None`` rather than a neighbouring token's
    confidence.
    """
    low_t, high_t = _resolve_conf_thresholds(asr_token_conf, low_high)

    results, corrected_words = [], []
    replaced_count = 0

    for entry in aligned:
        tag = entry['type']
        ref_span, hyp_span = entry['ref'], entry['hyp']
        j1, j2 = entry.get('hyp_idx', (None, None))

        if tag == 'equal':
            for ref_w, hyp_w in zip(ref_span, hyp_span):
                results.append({'ASR_word': hyp_w, 'GT_word': ref_w, 'status': 'Correct', 'reason': '', 'used': hyp_w})
                corrected_words.append(hyp_w)
            continue

        if tag not in ('replace', 'delete', 'insert'):
            continue

        for k in range(max(len(ref_span), len(hyp_span))):
            ref_w = ref_span[k] if k < len(ref_span) else ''
            hyp_w = hyp_span[k] if k < len(hyp_span) else ''
            if not ref_w and not hyp_w:
                continue

            is_pair = bool(ref_w and hyp_w)

            # --- similarity signals and base decision -----------------------
            if is_pair:
                phon_sim = phonetic_similarity(ref_w, hyp_w)
                lev1 = is_levenshtein_1(ref_w, hyp_w)
                bert_scores = multi_bert_similarity(ref_w, hyp_w)
                short_word = max(len(ref_w), len(hyp_w)) <= 6
                base_status = classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
                                            bert_thresh, max_bert)
            else:
                phon_sim = lev1 = short_word = False
                bert_scores = {"sbert": 0, "marbert": 0, "max": 0, "avg": 0}
                if hyp_w == '':
                    base_status = 'Missing (possible omission)'
                elif ref_w == '':
                    base_status = 'Extra (possible ASR insertion)'
                else:
                    base_status = 'Undefined Case'

            # --- ASR confidence for this hypothesis token, if it exists -----
            word_prob = None
            word_dur = None
            if asr_token_conf and (j1 is not None) and (j2 is not None) and k < len(hyp_span):
                conf = asr_token_conf.get(j1 + k)
                if conf is not None:
                    word_prob = conf.get("prob")
                    word_dur = conf.get("duration_ms")

            final_status = base_status
            if is_pair:
                final_status = gate_by_word_conf(
                    base_decision=base_status, prob=word_prob,
                    sbert_sim=bert_scores["sbert"],
                    is_short=short_word, lev1=lev1,
                    duration_ms=word_dur,
                    low_t=low_t, high_t=high_t, sbert_lo=0.60
                )

            # --- choose the emitted word, honouring the replacement budget --
            used = hyp_w
            budget_info = ""
            if is_pair and final_status.startswith("ASR error"):
                if (replace_budget_tokens is None) or (replaced_count < replace_budget_tokens):
                    used = ref_w
                    replaced_count += 1
                    if replace_budget_tokens is not None:
                        budget_info = f", budget={replaced_count}/{replace_budget_tokens}"
                else:
                    final_status += " [guard: budget reached]"
                    budget_info = f", budget={replaced_count}/{replace_budget_tokens}"

            reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
                      f'SBERT={bert_scores["sbert"]:.2f}, '
                      f'MARBERT={bert_scores["marbert"]:.2f}, '
                      f'MAX={bert_scores["max"]:.2f}, '
                      f'AVG={bert_scores["avg"]:.2f}, short={short_word}, '
                      f'prob={None if word_prob is None else round(word_prob,2)}, '
                      f'dur_ms={None if word_dur is None else int(word_dur)}, '
                      f'low_t={round(low_t,2)}, high_t={round(high_t,2)}')

            if bert_scores.get("note"):
                reason += f", note={bert_scores['note']}"
            if guard_note:
                reason += f", guard='{guard_note}'"
            if budget_info:
                reason += budget_info

            results.append({
                'ASR_word': hyp_w, 'GT_word': ref_w,
                'status': final_status, 'reason': reason, 'used': used
            })
            if used:
                corrected_words.append(used)

    corrected_text = " ".join([w for w in corrected_words if w])
    stats = _summarize_alignment(results)
    return results, corrected_text, stats
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def lcs_len(a, b):
    """Return the length of the longest common subsequence of ``a`` and ``b``.

    Standard dynamic programme, kept to two rows: ``prev`` holds the LCS
    lengths for the prefix of ``a`` processed so far against every prefix
    of ``b``.
    """
    width = len(b)
    prev = [0] * (width + 1)
    for item_a in a:
        curr = [0]
        for col, item_b in enumerate(b, start=1):
            if item_a == item_b:
                curr.append(prev[col - 1] + 1)
            else:
                curr.append(max(prev[col], curr[col - 1]))
        prev = curr
    return prev[width]
|
|
|
|
|
def rouge_l_f1_tokens(ref_tokens, hyp_tokens, beta=1.2):
    """Compute a ROUGE-L style F-score from token lists.

    Precision and recall are the LCS length divided by the hypothesis and
    reference lengths; they are combined as a beta-weighted F-measure.

    Returns:
        ``(f_score, precision, recall)``; all zeros when either list is
        empty or nothing overlaps.
    """
    zeros = (0.0, 0.0, 0.0)
    if not (ref_tokens and hyp_tokens):
        return zeros

    overlap = lcs_len(ref_tokens, hyp_tokens)
    precision = overlap / len(hyp_tokens)
    recall = overlap / len(ref_tokens)
    if precision == 0 and recall == 0:
        return zeros

    # The 1e-12 term keeps the denominator non-zero.
    numerator = (1 + beta**2) * precision * recall
    score = numerator / (recall + beta**2 * precision + 1e-12)
    return float(score), float(precision), float(recall)
|
|
|
|
|
def compute_wer_like(aligned, ref_tokens_len):
    """Return a WER-like error rate from alignment opcodes.

    A replace counts the longer of its two sides, a delete counts its
    reference tokens and an insert its hypothesis tokens. The sum is divided
    by the reference length, treated as 1 when zero.
    """
    errors = 0
    for op in aligned:
        kind = op['type']
        if kind == 'replace':
            errors += max(len(op['ref']), len(op['hyp']))
        elif kind == 'delete':
            errors += len(op['ref'])
        elif kind == 'insert':
            errors += len(op['hyp'])
    return errors / max(ref_tokens_len, 1)
|
|
|
|
|
def global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, sbert_model):
    """Judge whether the transcript is off-topic and size the replacement budget.

    Combines whole-text SBERT similarity, ROUGE-L, the share of reference
    tokens in 'equal' opcodes and a WER-like rate. The transcript is
    off-topic when all three similarity measures are low, or when the
    WER-like rate exceeds 0.65. The budget is 0 when off-topic, 15% of the
    hypothesis length for middling similarity, and 40% otherwise.

    Returns:
        dict with ``off_topic``, ``budget_tokens`` and rounded ``metrics``.
    """
    ref_emb = sbert_model.encode(_normalize_for_models(original_text), convert_to_tensor=True)
    hyp_emb = sbert_model.encode(_normalize_for_models(asr_text), convert_to_tensor=True)
    sbert_sim_text = float(util.pytorch_cos_sim(ref_emb, hyp_emb))

    rouge_f1, rouge_p, rouge_r = rouge_l_f1_tokens(ref_tokens, hyp_tokens)

    matched = 0
    for op in aligned:
        if op['type'] == 'equal':
            matched += len(op['ref'])
    equal_ratio = matched / max(len(ref_tokens), 1)
    wer = compute_wer_like(aligned, len(ref_tokens))

    all_measures_low = sbert_sim_text < 0.70 and rouge_f1 < 0.45 and equal_ratio < 0.25
    off_topic = all_measures_low or (wer > 0.65)

    hyp_len = len(hyp_tokens)
    if off_topic:
        budget = 0
    elif sbert_sim_text < 0.80 or rouge_f1 < 0.55:
        budget = int(0.15 * hyp_len)
    else:
        budget = int(0.40 * hyp_len)

    metrics = {
        "sbert_sim_text": round(sbert_sim_text, 3),
        "rougeL_f1": round(rouge_f1, 3),
        "rougeL_prec": round(rouge_p, 3),
        "rougeL_rec": round(rouge_r, 3),
        "equal_ratio": round(equal_ratio, 3),
        "wer_like": round(wer, 3),
    }
    print(f"[GUARD] off_topic={off_topic}, budget={budget}, metrics={metrics}", flush=True)
    return {"off_topic": off_topic, "budget_tokens": budget, "metrics": metrics}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def literal_similarity(original, recited):
    """Score surface similarity between the original and the recited text.

    Blends three measures on normalized text: normalized Levenshtein
    similarity (weight 0.5), position-wise word agreement relative to the
    original's length (0.3), and unigram BLEU (0.2). BLEU falls back to 0.0
    when NLTK is unavailable or raises.

    Returns:
        dict with ``levenshtein``, ``word_overlap``, ``bleu1`` and
        ``literal_score``, each rounded to 3 decimals.
    """
    def norm(t):
        for pattern, replacement in (
            (r'[ًٌٍَُِّْـ]', ''),
            (r'[“”",:؛؟.!()\[\]{}،\-–—_]', ' '),
            (r'\s+', ' '),
        ):
            t = re.sub(pattern, replacement, t)
        return t.strip()

    o = norm(original)
    r = norm(recited)
    lev = textdistance.levenshtein.normalized_similarity(o, r)

    ot = simple_tokenize(o)
    rt = simple_tokenize(r)
    # Words are compared position by position, not as sets.
    same_position = [w1 == w2 for w1, w2 in zip(ot, rt)]
    word_overlap = same_position.count(True) / max(len(ot), 1)

    try:
        import nltk.translate.bleu_score as bleu
        if ot and rt:
            bleu1 = bleu.sentence_bleu([ot], rt, weights=(1,0,0,0))
        else:
            bleu1 = 0.0
    except Exception:
        bleu1 = 0.0

    final_score = 0.5*lev + 0.3*word_overlap + 0.2*bleu1
    return {"levenshtein": round(lev,3), "word_overlap": round(word_overlap,3),
            "bleu1": round(bleu1,3), "literal_score": round(final_score,3)}
|
|
|
|
|
def semantic_similarity(original, recited, use_marbert=FORCE_USE_MARBERT):
    """Score whole-text semantic similarity with SBERT and optionally MARBERT.

    Returns:
        dict with ``sbert_sim``, ``marbert_sim`` (0.0 when ``use_marbert`` is
        falsy) and ``semantic_score`` (the larger of the two), each rounded
        to 3 decimals.
    """
    original_emb = _SBERT.encode(_normalize_for_models(original), convert_to_tensor=True)
    recited_emb = _SBERT.encode(_normalize_for_models(recited), convert_to_tensor=True)
    sbert_sim = float(util.pytorch_cos_sim(original_emb, recited_emb))

    if use_marbert:
        marbert_sim = marbert_cls_similarity(original, recited)
    else:
        marbert_sim = 0.0

    best = max(sbert_sim, marbert_sim)
    return {"sbert_sim": round(sbert_sim,3), "marbert_sim": round(marbert_sim,3),
            "semantic_score": round(best,3)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_audio_path(audio):
    """Return a filesystem path for the given audio input.

    Args:
        audio: either a path string, or a 2-tuple holding a numpy array and a
            sample rate. Both ``(data, sample_rate)`` and
            ``(sample_rate, data)`` are accepted; the latter is the order
            Gradio uses for ``type="numpy"`` audio.

    Returns:
        The given path, or the path of a newly written temporary ``.wav``
        file. The caller owns that temporary file.

    Raises:
        FileNotFoundError: the path string does not exist.
        ValueError: the input has any other shape.
    """
    if isinstance(audio, str):
        if not os.path.exists(audio):
            raise FileNotFoundError(f"Audio path not found: {audio}")
        return audio

    if isinstance(audio, tuple) and len(audio) == 2:
        first, second = audio
        if isinstance(first, np.ndarray):
            data, sr = first, second
        elif isinstance(second, np.ndarray):
            data, sr = second, first
        else:
            data = None

        if data is not None:
            tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            # Close our handle before writing: it is otherwise leaked, and on
            # some platforms an open handle blocks a second writer.
            tmp.close()
            try:
                sf.write(tmp.name, data, sr)
            except Exception:
                # Do not leave a half-written temp file behind.
                try:
                    os.remove(tmp.name)
                except OSError:
                    pass
                raise
            return tmp.name

    raise ValueError("Unsupported audio input format")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def transcribe_and_evaluate(audio, original_text, whisper_size=None,
                            compute_type=None, vad=True, use_marbert=True):
    """Transcribe audio, compare it with the reference text and build a report.

    Pipeline: load models, transcribe with Whisper, tokenize and align with
    the reference, run the off-topic guard to size the replacement budget,
    classify every aligned token, then score the corrected transcript.

    Args:
        audio: path string or numpy audio tuple (see ``ensure_audio_path``).
        original_text: the ground-truth text; must not be blank.
        whisper_size, compute_type, use_marbert: recorded in the report as
            "requested" but overridden by the FORCE_* constants.
        vad: whether to enable Whisper's VAD filter; ``None`` keeps the
            ``ASR_OPTS`` default.

    Returns:
        ``(corrected_text, asr_text, report_json, dataframe)``. On any error
        nothing is raised: the first two are empty strings, the JSON carries
        ``error`` and ``traceback``, and the dataframe holds one ERROR row.
    """
    temp_audio_path = None
    try:
        if not original_text or not original_text.strip():
            raise ValueError("Original text is empty.")

        # Record the caller's wishes before the forced configuration replaces them.
        requested = {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert}

        whisper_size = FORCE_WHISPER_NAME
        compute_type = FORCE_COMPUTE_TYPE
        use_marbert = FORCE_USE_MARBERT

        print(f"[RUN] whisper={whisper_size}, compute={compute_type}, marbert={use_marbert}", flush=True)

        load_models(whisper_name=whisper_size, whisper_compute=compute_type, use_marbert=use_marbert)

        audio_path = ensure_audio_path(audio)
        if not isinstance(audio, str):
            # ensure_audio_path wrote a temp file for array input; we own it.
            temp_audio_path = audio_path
        print(f"[AUDIO] path={audio_path}", flush=True)

        asr_opts = dict(ASR_OPTS)
        if vad is not None:
            asr_opts["vad_filter"] = bool(vad)
        segments, info = _WHISPER.transcribe(audio_path, **asr_opts)
        segments = list(segments)
        print(f"[ASR] segments={len(segments)}", flush=True)

        # Raw ASR transcript from cleaned, non-empty words.
        words = []
        for seg in segments:
            for w in (seg.words or []):
                tok = clean_ar_token(w.word)
                if tok:
                    words.append(tok)
        asr_text = " ".join(words)

        ref_tokens = simple_tokenize(original_text)
        hyp_tokens = simple_tokenize(asr_text)
        aligned = align_texts(ref_tokens, hyp_tokens)

        guard = global_offtopic_guard(original_text, asr_text, ref_tokens, hyp_tokens, aligned, _SBERT)
        off_topic = guard["off_topic"]
        guard_metrics = guard["metrics"]

        if FORCE_BUDGET_MODE == "off":
            budget_tokens = None
            guard_note = "budget_off"
        elif FORCE_BUDGET_MODE == "fixed":
            budget_tokens = int(FIXED_BUDGET_TOKENS)
            guard_note = f"budget_fixed_{budget_tokens}"
        elif FORCE_BUDGET_MODE == "ratio":
            budget_tokens = int(BUDGET_RATIO * len(hyp_tokens))
            guard_note = f"budget_ratio_{BUDGET_RATIO}"
        else:
            budget_tokens = guard["budget_tokens"]
            guard_note = "off-topic" if off_topic else "ok"

        print(f"[BUDGET] mode={FORCE_BUDGET_MODE}, budget={budget_tokens}, note={guard_note}", flush=True)

        df_words = extract_word_conf_table(segments)
        asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)
        print(f"[CONF] low_t={low_t:.3f}, high_t={high_t:.3f}", flush=True)

        results, corrected_text, local_stats = classify_alignment_optimized(
            aligned, ref_tokens, hyp_tokens,
            bert_thresh=0.75, max_bert=0.85,
            asr_token_conf=asr_token_conf, low_high=(low_t, high_t),
            replace_budget_tokens=budget_tokens,
            guard_note=guard_note
        )

        lit = literal_similarity(original_text, corrected_text)
        sem = semantic_similarity(original_text, corrected_text, use_marbert=use_marbert)

        # A transcript with no words may yield a table without a "prob" column.
        if "prob" in df_words.columns:
            all_probs = df_words["prob"].dropna().tolist()
        else:
            all_probs = []
        conf_summary = {
            "num_words_with_prob": int(len(all_probs)),
            "avg_prob": None if not all_probs else float(np.mean(all_probs)),
            "p15": None if not all_probs else float(np.quantile(all_probs, 0.15)),
            "p70": None if not all_probs else float(np.quantile(all_probs, 0.70)),
        }

        df = pd.DataFrame(results)

        report = {
            "requested": requested,
            "effective": {"whisper_model": whisper_size, "compute_type": compute_type, "use_marbert": use_marbert},
            "guard": {"mode": FORCE_BUDGET_MODE, "off_topic": off_topic, "budget_tokens": None if budget_tokens is None else int(budget_tokens), **guard_metrics},
            "local_stats": local_stats,
            "confidence_summary": conf_summary,
            "original_text": original_text,
            "asr_text": asr_text,
            "corrected_text": corrected_text,
            "literal": lit,
            "semantic": sem,
            "low_t": float(low_t), "high_t": float(high_t),
        }
        return corrected_text, asr_text, json.dumps(report, ensure_ascii=False, indent=2), df

    except Exception as e:
        tb = traceback.format_exc()
        print("ERROR in transcribe_and_evaluate:\n", tb, flush=True)
        empty_df = pd.DataFrame([{"ASR_word":"","GT_word":"","status":"ERROR","reason":str(e),"used":""}])
        err_json = json.dumps({"error": str(e), "traceback": tb}, ensure_ascii=False, indent=2)
        gr.Warning(str(e))
        return "", "", err_json, empty_df

    finally:
        # Remove the temp wav written for array input (best effort).
        if temp_audio_path is not None:
            try:
                os.remove(temp_audio_path)
            except OSError:
                pass
|
|
|
|
|
def api_predict(audio, original_text, whisper_size=None, compute_type=None, vad=True, use_marbert=True):
    """JSON-only entry point: run the pipeline and return the parsed report.

    Only the report element of ``transcribe_and_evaluate``'s result is used.
    If it cannot be parsed as JSON, a dict with an ``error`` message is
    returned instead.
    """
    outputs = transcribe_and_evaluate(
        audio, original_text, whisper_size, compute_type, vad, use_marbert
    )
    report_json = outputs[2]
    try:
        parsed = json.loads(report_json)
    except Exception:
        return {"error": "Failed to parse report_json."}
    return parsed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_ui():
    """Assemble and return the Gradio Blocks app.

    Layout: audio and reference text inputs, a row of option controls, the
    run button, then the corrected transcript, raw transcript, JSON report
    and token-level decision table. Two endpoints are registered: "evaluate"
    on the visible button and "predict" on a hidden button that returns only
    the JSON report.
    """
    with gr.Blocks(title="Samaali ASR Post-Processing", theme=gr.themes.Soft()) as demo:
        gr.Markdown("## Samaali — ASR Post-Processing (Whisper + Alignment + Confidence + Semantics)")

        with gr.Row():
            audio = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio")
            original = gr.Textbox(lines=8, label="Original Text (Ground Truth)")

        with gr.Row():
            whisper_size = gr.Dropdown(choices=["large-v3"], value="large-v3", label="Whisper model size (forced)")
            compute_type = gr.Dropdown(choices=["int8"], value="int8", label="compute_type (forced)")
            vad = gr.Checkbox(value=True, label="VAD filter")
            use_marbert = gr.Checkbox(value=True, label="Use MARBERT (forced)")

        btn = gr.Button("Transcribe & Evaluate", variant="primary")

        corrected = gr.Textbox(lines=6, label="Corrected Transcript (ASR errors restored)")
        asr_out = gr.Textbox(lines=6, label="Raw ASR Transcript")
        report = gr.JSON(label="Report (scores & thresholds)")

        table = gr.Dataframe(headers=["ASR_word","GT_word","status","reason","used"],
                             label="Token-level Decisions", wrap=True)

        # Interactive endpoint: fills all four output widgets.
        btn.click(
            fn=transcribe_and_evaluate,
            inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
            outputs=[corrected, asr_out, report, table],
            api_name="evaluate"
        )

        # API-only endpoint on a hidden button, same inputs, JSON result.
        hidden_trigger = gr.Button(visible=False)
        hidden_trigger.click(
            fn=api_predict,
            inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
            outputs=gr.JSON(),
            api_name="predict"
        )

    return demo
|
|
|
|
|
# Script entry point: build the interface and start the Gradio server.
if __name__ == "__main__":
    demo = build_ui()
    demo.launch()
|
|
|