|
|
import os, re, json, math, tempfile, traceback |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import torch |
|
|
import textdistance |
|
|
|
|
|
import gradio as gr |
|
|
from faster_whisper import WhisperModel |
|
|
|
|
|
from sentence_transformers import SentenceTransformer, util |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
import soundfile as sf |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when torch can see one; otherwise the whole app uses CPU settings.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CPU_MODE = DEVICE == "cpu"

# Settings forced on CPU-only machines (applied in load_models and
# transcribe_and_evaluate): small Whisper checkpoint, int8 compute, no MARBERT.
DEFAULT_WHISPER_CPU, DEFAULT_COMPUTE_CPU, DEFAULT_USE_MARBERT_CPU = "small", "int8", False

# Lazily-initialised model singletons; populated by load_models().
_SBERT = _MARBERT_TOK = _MARBERT = _WHISPER = None
|
|
|
|
|
_WHISPER_CFG = None  # (model name, compute_type) of the Whisper model currently in _WHISPER


def load_models(
    sbert_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    marbert_name="UBC-NLP/MARBERT",
    whisper_name="small",
    whisper_compute="int8",
):
    """Load (or reuse) the SBERT, MARBERT and Whisper models into module globals.

    SBERT and MARBERT are loaded once and reused. MARBERT is skipped entirely
    in CPU mode. Whisper is reloaded whenever the requested
    ``(whisper_name, whisper_compute)`` differs from the loaded one, so changing
    the model size or compute type in the UI actually takes effect. In CPU mode
    both Whisper arguments are replaced by DEFAULT_WHISPER_CPU /
    DEFAULT_COMPUTE_CPU.
    """
    global _SBERT, _MARBERT_TOK, _MARBERT, _WHISPER, _WHISPER_CFG

    if CPU_MODE:
        whisper_name = DEFAULT_WHISPER_CPU
        whisper_compute = DEFAULT_COMPUTE_CPU

    if _SBERT is None:
        _SBERT = SentenceTransformer(sbert_name, device=DEVICE)

    if _MARBERT is None and not CPU_MODE:
        _MARBERT_TOK = AutoTokenizer.from_pretrained(marbert_name)
        _MARBERT = AutoModel.from_pretrained(marbert_name).to(DEVICE)
        _MARBERT.eval()

    wanted = (whisper_name, whisper_compute)
    if _WHISPER is None or _WHISPER_CFG != wanted:
        # Drop the previous model first so it can be freed before the new one is allocated.
        _WHISPER = None
        _WHISPER = WhisperModel(whisper_name, device=DEVICE, compute_type=whisper_compute)
        _WHISPER_CFG = wanted
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def normalize_ar_orth(text: str) -> str:
    """Normalize Arabic orthography so spelling variants compare equal.

    Steps, in order:
    1. Remove harakat (U+064B-U+0652), tatweel (U+0640) and the superscript
       ("dagger") alef U+0670.
    2. Replace Arabic and ASCII punctuation/quotes/brackets/dashes with spaces.
    3. Unify alef variants (إ أ ٱ آ) to bare alef ا.
    4. Map taa marbuta ة -> ه and alef maqsura ى -> ي.
    5. Collapse runs of whitespace and strip the ends.
    """
    text = re.sub(r"[\u064B-\u0652\u0640\u0670]", "", text)
    # ASCII "?" and ";" sit next to their Arabic counterparts (؟ ؛) so both spellings are removed.
    text = re.sub(r"[“”\"',:;؛؟?.!()\[\]{}«»…،\-–—_]", " ", text)
    text = re.sub(r"[إأٱآا]", "ا", text)
    text = text.replace("ة", "ه").replace("ى", "ي")
    return re.sub(r"\s+", " ", text).strip()
|
|
|
|
|
_NLTK_UNRESOLVED = object()
# Cached nltk.word_tokenize; None once NLTK is known to be unusable; sentinel until first use.
_nltk_word_tokenize = _NLTK_UNRESOLVED


def _get_nltk_word_tokenize():
    """Return ``nltk.word_tokenize`` if NLTK imports, else None (resolved once, then cached).

    The one-off ``punkt`` download attempt happens here, at most once per
    process, instead of on every tokenization call.
    """
    global _nltk_word_tokenize
    if _nltk_word_tokenize is _NLTK_UNRESOLVED:
        try:
            import nltk
            try:
                nltk.data.find('tokenizers/punkt')
            except LookupError:
                nltk.download('punkt', quiet=True)
            _nltk_word_tokenize = nltk.word_tokenize
        except Exception:
            _nltk_word_tokenize = None
    return _nltk_word_tokenize


def simple_tokenize(text: str):
    """Normalize *text* with normalize_ar_orth and split it into word tokens.

    Uses NLTK's ``word_tokenize`` when available; otherwise (or if it fails)
    falls back to plain whitespace splitting.
    """
    global _nltk_word_tokenize
    t = normalize_ar_orth(text)
    tokenize = _get_nltk_word_tokenize()
    if tokenize is not None:
        try:
            return tokenize(t)
        except LookupError:
            # Tokenizer data missing (e.g. download failed / newer NLTK resource
            # names): it will fail identically every time, so stop trying NLTK.
            _nltk_word_tokenize = None
        except Exception:
            # Any other failure: whitespace fallback for this call only.
            return t.split()
    return t.split()
|
|
|
|
|
def align_texts(ref_tokens, hyp_tokens):
    """Align reference and hypothesis token lists with difflib.

    Returns one dict per SequenceMatcher opcode with keys 'type'
    ('equal' | 'replace' | 'delete' | 'insert'), 'ref' / 'hyp' (the token
    slices), and 'ref_idx' / 'hyp_idx' (the (start, end) slice bounds).
    """
    import difflib

    # autojunk=False: with the default heuristic, once a sequence has >= 200
    # tokens every token making up more than 1% of it (common function words)
    # is ignored as junk, which badly degrades alignment of long texts.
    matcher = difflib.SequenceMatcher(None, ref_tokens, hyp_tokens, autojunk=False)
    return [
        {
            'type': tag,
            'ref': ref_tokens[i1:i2],
            'hyp': hyp_tokens[j1:j2],
            'ref_idx': (i1, i2),
            'hyp_idx': (j1, j2),
        }
        for tag, i1, i2, j1, j2 in matcher.get_opcodes()
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Phonetic groups: letters sharing a code are treated as sounding alike.
# Each letter appears in exactly one group (ص belongs with the sibilants س/ز).
_AR_SOUNDEX_GROUPS = {
    'b': 'بف', 'j': 'جش', 'd': 'دض', 't': 'طت', 'q': 'قغ', 'k': 'كخ',
    's': 'سصز', 'z': 'ثذظ', 'h': 'ح', 'a': 'ع', 'm': 'م', 'n': 'ن',
    'l': 'ل', 'r': 'ر', 'w': 'و', 'y': 'ي',
}
# Flat letter -> code lookup built once from the groups above.
_AR_SOUNDEX_CODE = {ch: code for code, chars in _AR_SOUNDEX_GROUPS.items() for ch in chars}


def arabic_soundex(word):
    """Return a coarse phonetic code for an Arabic word.

    The word is normalized with normalize_ar_orth, then each letter is
    replaced by its group code. Letters not in any group (e.g. ا, ه, ء, digits,
    Latin letters) are dropped. No collapsing of repeated codes is done.
    """
    return "".join(_AR_SOUNDEX_CODE.get(ch, "") for ch in normalize_ar_orth(word))
|
|
|
|
|
def phonetic_similarity(w1, w2):
    """Return True when both words are non-empty and share the same non-empty soundex code.

    An empty code (word made only of letters arabic_soundex ignores, e.g.
    digits, Latin text, ا/ه/ء) is never a match. Otherwise any two such tokens
    would spuriously "sound alike".
    """
    if not w1 or not w2:
        return False
    code1 = arabic_soundex(w1)
    return bool(code1) and code1 == arabic_soundex(w2)
|
|
|
|
|
def is_levenshtein_1(w1, w2):
    """Return True only when both words are non-empty and exactly one edit apart."""
    if w1 and w2:
        return textdistance.levenshtein(w1, w2) == 1
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Arabic-Indic (U+0660-0669) and Extended Arabic-Indic / Persian (U+06F0-06F9)
# digits -> ASCII 0-9.
AR_DIGITS = str.maketrans(
    {chr(base + i): str(i) for base in (0x0660, 0x06F0) for i in range(10)}
)

# Number-word lookups. Besides the original spellings they contain the forms
# produced by normalize_ar_orth (أ/آ -> ا, ة -> ه), because to_numeric_value
# normalizes tokens before looking them up.
UNITS = {"صفر": 0, "واحد": 1, "واحدة": 1, "واحده": 1,
         "اثنان": 2, "اثنين": 2, "اثنتان": 2, "اثنتين": 2,
         "ثلاث": 3, "ثلاثه": 3, "ثلاثة": 3, "اربع": 4, "اربعه": 4, "أربع": 4, "أربعه": 4,
         "خمس": 5, "خمسه": 5, "ست": 6, "سته": 6, "سبع": 7, "سبعه": 7,
         "ثمان": 8, "ثماني": 8, "ثمانيه": 8, "تسع": 9, "تسعه": 9}

TENS = {"عشر": 10, "عشرة": 10, "عشره": 10, "عشرون": 20, "عشرين": 20,
        "ثلاثون": 30, "ثلاثين": 30, "اربعون": 40, "أربعون": 40, "اربعين": 40,
        "خمسون": 50, "خمسين": 50, "ستون": 60, "ستين": 60, "سبعون": 70, "سبعين": 70,
        "ثمانون": 80, "ثمانين": 80, "تسعون": 90, "تسعين": 90}

HUND = {"مئه": 100, "مئة": 100, "مائه": 100, "مائة": 100,
        "مئتان": 200, "مئتين": 200, "مائتان": 200, "مائتين": 200}

SCALE = {"الف": 1000, "ألف": 1000, "آلاف": 1000, "الاف": 1000,
         "مليون": 10**6, "ملايين": 10**6, "مليار": 10**9}


def normalize_digits(s: str) -> str:
    """Translate Arabic-Indic and Persian digits in *s* to ASCII digits."""
    return s.translate(AR_DIGITS)


def words_to_number(tokens):
    """Parse a sequence of Arabic number words into an int.

    Units and tens are added to the current group. A hundreds word multiplies
    the group (ثلاث مئه -> 300). A scale word (الف, مليون, ...) multiplies the
    group and commits it to the total. A standalone "و" is skipped, and any
    other token closes the current group.

    Returns:
        The parsed integer, including 0 for "صفر", or None when no number
        word was recognized at all.
    """
    total = current = 0
    seen = False  # distinguishes a genuine zero from "not a number"
    for w in tokens:
        if w in UNITS:
            current += UNITS[w]
            seen = True
        elif w in TENS:
            current += TENS[w]
            seen = True
        elif w in HUND:
            current = max(1, current) * HUND[w]
            seen = True
        elif w in SCALE:
            current = max(1, current) * SCALE[w]
            total += current
            current = 0
            seen = True
        elif w == "و":
            continue
        else:
            total += current
            current = 0
    total += current
    return total if seen else None
|
|
|
|
|
def to_numeric_value(token: str):
    """Interpret a token as an integer, from digits or from Arabic number words.

    The token is normalized with normalize_ar_orth. If what remains is all
    digits (after mapping Arabic-Indic digits to ASCII), that integer is
    returned. Otherwise the whitespace-split words are parsed by
    words_to_number. Empty/None tokens give None.
    """
    if not token:
        return None
    normalized = normalize_ar_orth(token)
    as_ascii = normalize_digits(normalized)
    if re.fullmatch(r"\d+", as_ascii) is not None:
        return int(as_ascii)
    return words_to_number(normalized.split())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def marbert_cls_similarity(a: str, b: str) -> float:
    """Cosine similarity of MARBERT [CLS] embeddings, rescaled from [-1, 1] to [0, 1].

    Returns 0.0 when either text is empty or MARBERT has not been loaded.
    NOTE(review): SBERT scores elsewhere in this file are raw cosine values
    (range [-1, 1]), so the two are on different scales when combined.
    """
    if not (a and b) or _MARBERT is None:
        return 0.0

    def _cls_embedding(text):
        encoded = _MARBERT_TOK(text, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
        return _MARBERT(**encoded).last_hidden_state[:, 0, :]

    with torch.no_grad():
        cosine = util.cos_sim(_cls_embedding(a), _cls_embedding(b)).item()
    return (cosine + 1) / 2
|
|
|
|
|
def multi_bert_similarity(a: str, b: str):
    """Return SBERT and MARBERT similarities of two strings plus their max and mean.

    Returns a dict with keys "sbert", "marbert", "max", "avg" (all 0.0 when
    either string is empty). "max"/"avg" aggregate only over the models that
    actually ran: when MARBERT is not loaded (e.g. CPU mode) its placeholder
    0.0 is excluded. Otherwise it would halve "avg" and make thresholds on it
    unreachable.
    """
    if not a or not b:
        return {"sbert": 0.0, "marbert": 0.0, "max": 0.0, "avg": 0.0}

    sbert_sim = float(util.pytorch_cos_sim(_SBERT.encode(a, convert_to_tensor=True),
                                           _SBERT.encode(b, convert_to_tensor=True)))
    marbert_sim = marbert_cls_similarity(a, b)

    vals = [sbert_sim] if _MARBERT is None else [sbert_sim, marbert_sim]
    return {"sbert": sbert_sim, "marbert": marbert_sim,
            "max": max(vals), "avg": sum(vals) / len(vals)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_ar_token(t: str) -> str:
    """Clean one Whisper word: trim edge symbols, then normalize orthography.

    Leading/trailing characters that are neither word characters nor in the
    Arabic Unicode block (U+0600-U+06FF) are removed from the stripped token.
    The result is then passed through normalize_ar_orth.
    """
    trimmed = re.sub(r'^[^\w\u0600-\u06FF]+|[^\w\u0600-\u06FF]+$', '', t.strip())
    return normalize_ar_orth(trimmed)
|
|
|
|
|
def extract_word_conf_table(segments):
    """Flatten Whisper segments into a per-word DataFrame.

    One row per word, with columns seg_start, seg_end, word_start, word_end
    (floats), word (cleaned with clean_ar_token) and prob (the word
    probability). Segments whose ``words`` is None/empty contribute no rows.
    With no rows, an empty DataFrame without columns is returned.
    """
    rows = [
        {
            "seg_start": float(seg.start),
            "seg_end": float(seg.end),
            "word_start": float(word.start),
            "word_end": float(word.end),
            "word": clean_ar_token(word.word),
            "prob": float(word.probability),
        }
        for seg in segments
        for word in (seg.words or [])
    ]
    return pd.DataFrame(rows)
|
|
|
|
|
def build_asr_token_conf(df_words: pd.DataFrame, hyp_tokens: list):
    """Attach Whisper word probability/duration to hypothesis tokens by position.

    Args:
        df_words: per-word table (columns used: "word", "prob", "word_start",
            "word_end"; start/end are multiplied by 1000 to get milliseconds).
        hyp_tokens: tokens of the ASR transcript, in order.

    Whisper words are first expanded to line up with transcript tokens:
    - a word whose cleaned text is empty is skipped, since it never reaches the
      transcript and would otherwise shift every later confidence by one;
    - a cleaned word containing spaces yields one entry per whitespace-separated
      piece, sharing its probability and splitting its duration evenly;
    - rows without a string "word" count as a single token.
    NOTE(review): if hyp_tokens were produced by NLTK rather than whitespace
    splitting, token boundaries may still differ slightly.

    The lists are then truncated or padded with None to len(hyp_tokens).

    Returns:
        (asr_token_conf, low_t, high_t). asr_token_conf maps token index ->
        {"prob", "duration_ms"} (None when unknown). low_t/high_t are the
        15th/70th percentiles of the known probabilities, or 0.5/0.85 if none.
    """
    toks_probs, toks_durs = [], []
    for _, row in df_words.iterrows():
        prob = row["prob"]
        dur = (row["word_end"] - row["word_start"]) * 1000.0
        word = row.get("word")
        pieces = word.split() if isinstance(word, str) else [word]
        for _piece in pieces:
            toks_probs.append(prob)
            toks_durs.append(dur / len(pieces))

    n_tokens = len(hyp_tokens)
    # Truncate extra entries or pad missing ones with None so indices match hyp_tokens.
    toks_probs = (toks_probs + [None] * n_tokens)[:n_tokens]
    toks_durs = (toks_durs + [None] * n_tokens)[:n_tokens]

    known = np.array([p for p in toks_probs if p is not None])
    if known.size:
        low_t = float(np.quantile(known, 0.15))
        high_t = float(np.quantile(known, 0.70))
    else:
        low_t, high_t = 0.5, 0.85

    asr_token_conf = {i: {"prob": toks_probs[i], "duration_ms": toks_durs[i]}
                      for i in range(n_tokens)}
    return asr_token_conf, low_t, high_t
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gate_by_word_conf(base_decision: str, prob: float, sbert_sim: float,
                      is_short: bool, lev1: bool, duration_ms: float = None,
                      low_t: float = 0.6, high_t: float = 0.9, sbert_lo=0.60):
    """Override a pair decision when Whisper was unsure about the hypothesis word.

    Only a word in the "low" band (``prob is not None and prob <= low_t``) is
    re-labelled as an ASR error. The reason is picked in priority order:
    short word one edit away, very short audio (< 120 ms), SBERT similarity
    >= ``sbert_lo``, otherwise plain low probability. Every other word (prob
    unknown, mid band, or >= ``high_t``) keeps ``base_decision``, so
    ``high_t`` does not affect the result.
    """
    in_low_band = prob is not None and prob <= low_t
    if not in_low_band:
        return base_decision

    if is_short and lev1:
        return 'ASR error (low p + short+lev1)'
    if duration_ms is not None and duration_ms < 120:
        return 'ASR error (low p + very short)'
    if sbert_sim >= sbert_lo:
        return 'ASR error (low p + semantic)'
    return 'ASR error (low p)'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
                  bert_thresh=0.75, max_bert=0.85):
    """Decide whether a substituted word is an ASR error or a memorization error.

    In order:
    1. Both words parse to the same number (digits or number words) -> ASR error.
    2. Short word one edit away -> ASR error.
    3. (phonetically similar or one edit away) with mean BERT similarity >=
       ``bert_thresh``, or max BERT similarity > ``max_bert`` -> ASR error.
    4. Otherwise -> memorization error.
    Differing numbers are not decided in step 1; they fall through to 2-4.
    """
    ref_num, hyp_num = to_numeric_value(ref_w), to_numeric_value(hyp_w)
    if ref_num is not None and ref_num == hyp_num:
        return 'ASR error (numbers equal)'

    if short_word and lev1:
        return 'ASR error (short+lev1)'

    mean_is_high = bert_scores["avg"] >= bert_thresh
    peak_is_high = bert_scores["max"] > max_bert
    surface_match = phon_sim or lev1
    if (surface_match and mean_is_high) or peak_is_high:
        return 'ASR error (semantic/phonetic)'

    return 'Memorization error'
|
|
|
|
|
def _confidence_thresholds(asr_token_conf, low_high):
    """Return (low_t, high_t): explicit values, else 15th/70th percentile of known probs, else (0.5, 0.85)."""
    if low_high is not None:
        low_t, high_t = low_high
        return low_t, high_t
    probs = [v["prob"] for v in (asr_token_conf or {}).values() if v["prob"] is not None]
    if not probs:
        return 0.5, 0.85
    return float(np.quantile(probs, 0.15)), float(np.quantile(probs, 0.70))


def _classify_token_pair(ref_w, hyp_w, hyp_idx, asr_token_conf,
                         bert_thresh, max_bert, low_t, high_t):
    """Classify one non-equal aligned (ref, hyp) word pair and build its result row."""
    both = bool(ref_w and hyp_w)
    if both:
        phon_sim = phonetic_similarity(ref_w, hyp_w)
        lev1 = is_levenshtein_1(ref_w, hyp_w)
        bert_scores = multi_bert_similarity(ref_w, hyp_w)
        short_word = max(len(ref_w), len(hyp_w)) <= 6
    else:
        phon_sim = lev1 = short_word = False
        bert_scores = {"sbert": 0, "marbert": 0, "max": 0, "avg": 0}

    # Confidence only exists for a real hypothesis token of this opcode;
    # deleted reference words have none.
    word_prob = word_dur = None
    if asr_token_conf and hyp_idx is not None and hyp_idx in asr_token_conf:
        tok_conf = asr_token_conf[hyp_idx]
        word_prob = tok_conf.get("prob")
        word_dur = tok_conf.get("duration_ms")

    if both:
        base_status = classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
                                    bert_thresh, max_bert)
        final_status = gate_by_word_conf(
            base_decision=base_status, prob=word_prob,
            sbert_sim=bert_scores["sbert"],
            is_short=short_word, lev1=lev1,
            duration_ms=word_dur,
            low_t=low_t, high_t=high_t, sbert_lo=0.60,
        )
        # ASR errors are "repaired" with the reference word; real mistakes keep what was said.
        used = ref_w if final_status.startswith("ASR error") else hyp_w
    elif hyp_w == '':
        final_status = 'Missing (possible omission)'
        used = ''
    elif ref_w == '':
        final_status = 'Extra (possible ASR insertion)'
        used = hyp_w
    else:
        final_status = 'Undefined Case'
        used = hyp_w

    reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
              f'SBERT={bert_scores["sbert"]:.2f}, '
              f'MARBERT={bert_scores["marbert"]:.2f}, '
              f'MAX={bert_scores["max"]:.2f}, '
              f'AVG={bert_scores["avg"]:.2f}, short={short_word}, '
              f'prob={None if word_prob is None else round(word_prob,2)}, '
              f'dur_ms={None if word_dur is None else int(word_dur)}, '
              f'low_t={round(low_t,2)}, high_t={round(high_t,2)}')

    return {'ASR_word': hyp_w, 'GT_word': ref_w,
            'status': final_status, 'reason': reason, 'used': used}


def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
                                 bert_thresh=0.75, max_bert=0.85,
                                 asr_token_conf=None, low_high=None):
    """Label every aligned token and build the corrected transcript.

    Args:
        aligned: opcode dicts from align_texts ('type', 'ref', 'hyp', 'ref_idx', 'hyp_idx').
        ref_tokens, hyp_tokens: accepted for interface compatibility; not used.
        bert_thresh, max_bert: thresholds forwarded to classify_pair.
        asr_token_conf: optional map hyp token index -> {"prob", "duration_ms"}.
        low_high: optional explicit (low_t, high_t) confidence thresholds;
            otherwise derived from asr_token_conf (see _confidence_thresholds).

    'equal' tokens are marked Correct. In non-equal opcodes, ref/hyp words are
    paired by position and classified individually.

    Returns:
        (results, corrected_text). results is one dict per token with keys
        ASR_word, GT_word, status, reason, used. corrected_text joins the
        non-empty 'used' words with spaces.
    """
    low_t, high_t = _confidence_thresholds(asr_token_conf, low_high)
    results, corrected_words = [], []

    for entry in aligned:
        tag = entry['type']

        if tag == 'equal':
            for ref_w, hyp_w in zip(entry['ref'], entry['hyp']):
                results.append({'ASR_word': hyp_w, 'GT_word': ref_w,
                                'status': 'Correct', 'reason': '', 'used': hyp_w})
                corrected_words.append(hyp_w)
            continue

        if tag not in ('replace', 'delete', 'insert'):
            continue

        j1, j2 = entry.get('hyp_idx', (None, None))
        refs, hyps = entry['ref'], entry['hyp']
        for k in range(max(len(refs), len(hyps))):
            ref_w = refs[k] if k < len(refs) else ''
            hyp_w = hyps[k] if k < len(hyps) else ''
            if not ref_w and not hyp_w:
                continue
            has_hyp_token = j1 is not None and j2 is not None and k < len(hyps)
            hyp_idx = j1 + k if has_hyp_token else None

            row = _classify_token_pair(ref_w, hyp_w, hyp_idx, asr_token_conf,
                                       bert_thresh, max_bert, low_t, high_t)
            results.append(row)
            if row['used']:
                corrected_words.append(row['used'])

    corrected_text = " ".join(w for w in corrected_words if w)
    return results, corrected_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def literal_similarity(original, recited):
    """Score how literally *recited* reproduces *original*.

    Both texts are lightly normalized (harakat/tatweel removed, punctuation
    turned into spaces, whitespace collapsed) and compared with:
      * levenshtein: textdistance normalized character-level similarity;
      * word_overlap: original words matched *in order* in the recitation
        (difflib matching blocks), divided by the original's word count. A
        single omitted or inserted word no longer shifts and zeroes out the
        rest, as a purely positional comparison would;
      * bleu1: NLTK unigram BLEU (0.0 if NLTK is unavailable or a side is empty).
    literal_score = 0.5*levenshtein + 0.3*word_overlap + 0.2*bleu1.
    All returned values are rounded to 3 decimals.
    """
    import difflib

    def norm(t):
        t = re.sub(r'[ًٌٍَُِّْـ]', '', t)
        t = re.sub(r'[“”",:؛؟.!()\[\]{}،\-–—_]', ' ', t)
        t = re.sub(r'\s+', ' ', t).strip()
        return t

    o = norm(original)
    r = norm(recited)
    lev = textdistance.levenshtein.normalized_similarity(o, r)

    ot = simple_tokenize(o)
    rt = simple_tokenize(r)
    matcher = difflib.SequenceMatcher(None, ot, rt, autojunk=False)
    common = sum(block.size for block in matcher.get_matching_blocks())
    word_overlap = common / max(len(ot), 1)

    try:
        import nltk.translate.bleu_score as bleu
        bleu1 = bleu.sentence_bleu([ot], rt, weights=(1, 0, 0, 0)) if (ot and rt) else 0.0
    except Exception:
        bleu1 = 0.0

    final_score = 0.5 * lev + 0.3 * word_overlap + 0.2 * bleu1
    return {"levenshtein": round(lev, 3), "word_overlap": round(word_overlap, 3),
            "bleu1": round(bleu1, 3), "literal_score": round(final_score, 3)}
|
|
|
|
|
def semantic_similarity(original, recited, use_marbert=True):
    """Sentence-level semantic similarity between the original and recited texts.

    Returns a dict with "sbert_sim" (SBERT cosine), "marbert_sim"
    (marbert_cls_similarity, or 0.0 when ``use_marbert`` is False) and
    "semantic_score" (the larger of the two). All values are rounded to 3 decimals.
    """
    original_emb = _SBERT.encode(original, convert_to_tensor=True)
    recited_emb = _SBERT.encode(recited, convert_to_tensor=True)
    sbert_sim = float(util.pytorch_cos_sim(original_emb, recited_emb))

    marbert_sim = 0.0
    if use_marbert:
        marbert_sim = marbert_cls_similarity(original, recited)

    best = max(sbert_sim, marbert_sim)
    return {"sbert_sim": round(sbert_sim, 3), "marbert_sim": round(marbert_sim, 3),
            "semantic_score": round(best, 3)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_audio_path(audio):
    """Return a filesystem path for *audio*.

    Accepts:
      * a path (str or os.PathLike). It must exist and is returned as str;
      * a 2-tuple of a numpy array and a sample rate, in either order:
        (data, sr) or Gradio's numpy format (sr, data). The array is written
        to a new temporary .wav file whose path is returned. The caller is
        responsible for deleting it.

    Raises:
        FileNotFoundError: a path was given but does not exist.
        ValueError: any other input.
    """
    if isinstance(audio, (str, os.PathLike)):
        path = os.fspath(audio)
        if not os.path.exists(path):
            raise FileNotFoundError(f"Audio path not found: {path}")
        return path

    if isinstance(audio, tuple) and len(audio) == 2:
        first, second = audio
        if isinstance(first, np.ndarray):
            data, sr = first, second
        elif isinstance(second, np.ndarray):
            sr, data = first, second  # Gradio numpy audio is (sample_rate, data)
        else:
            data = sr = None
        if data is not None:
            # Close our handle before writing by name (required on Windows, avoids an fd leak).
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp_path = tmp.name
            sf.write(tmp_path, data, sr)
            return tmp_path

    raise ValueError("Unsupported audio input format")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _resolve_run_config(whisper_size, compute_type, use_marbert):
    """Return the (whisper_size, compute_type, use_marbert) actually used for this run.

    In CPU mode the requested values are replaced by the CPU defaults. On GPU,
    missing values fall back to "large-v3" / "float16".
    """
    if CPU_MODE:
        return DEFAULT_WHISPER_CPU, DEFAULT_COMPUTE_CPU, DEFAULT_USE_MARBERT_CPU
    return whisper_size or "large-v3", compute_type or "float16", use_marbert


def _collect_asr_words(segments):
    """Return the cleaned, non-empty words of Whisper *segments*, in order."""
    words = []
    for seg in segments:
        for w in (seg.words or []):
            tok = clean_ar_token(w.word)
            if tok:
                words.append(tok)
    return words


def transcribe_and_evaluate(audio, original_text, whisper_size=None,
                            compute_type=None, vad=True, use_marbert=True):
    """Transcribe *audio*, align it with *original_text* and score the recitation.

    Args:
        audio: file path, or a (array, sample_rate) tuple accepted by ensure_audio_path.
        original_text: ground-truth text; must be non-blank.
        whisper_size, compute_type, use_marbert: model options (overridden in CPU mode).
        vad: enable faster-whisper's VAD filter.

    Returns:
        (corrected_text, asr_text, report_json, token_table_df). On any error the
        traceback is printed, a gr.Warning is shown, and ("", "", error_json,
        one-row ERROR table) is returned instead of raising.
    Any temporary WAV written for in-memory audio is deleted before returning.
    """
    tmp_audio_path = None  # set only when ensure_audio_path wrote a temp file for us
    try:
        if not original_text or not original_text.strip():
            raise ValueError("Original text is empty.")
        if audio is None:
            raise ValueError("No audio provided.")

        whisper_size, compute_type, use_marbert = _resolve_run_config(
            whisper_size, compute_type, use_marbert)
        load_models(whisper_name=whisper_size, whisper_compute=compute_type)

        audio_path = ensure_audio_path(audio)
        if not isinstance(audio, (str, os.PathLike)):
            tmp_audio_path = audio_path

        segments, _info = _WHISPER.transcribe(
            audio_path, word_timestamps=True,
            vad_filter=vad, vad_parameters={"min_silence_duration_ms": 200}
        )
        # faster-whisper yields segments lazily; materialize them so they can be iterated twice.
        segments = list(segments)

        asr_text = " ".join(_collect_asr_words(segments))

        ref_tokens = simple_tokenize(original_text)
        hyp_tokens = simple_tokenize(asr_text)
        aligned = align_texts(ref_tokens, hyp_tokens)

        df_words = extract_word_conf_table(segments)
        asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)

        results, corrected_text = classify_alignment_optimized(
            aligned, ref_tokens, hyp_tokens,
            bert_thresh=0.75, max_bert=0.85,
            asr_token_conf=asr_token_conf, low_high=(low_t, high_t)
        )

        lit = literal_similarity(original_text, corrected_text)
        sem = semantic_similarity(original_text, corrected_text,
                                  use_marbert=(use_marbert and not CPU_MODE))

        df = pd.DataFrame(results)

        report = {
            "whisper_model": whisper_size,
            "compute_type": compute_type,
            "original_text": original_text,
            "asr_text": asr_text,
            "corrected_text": corrected_text,
            "literal": lit,
            "semantic": sem,
            "low_t": low_t, "high_t": high_t,
        }
        return corrected_text, asr_text, json.dumps(report, ensure_ascii=False, indent=2), df

    except Exception as e:
        tb = traceback.format_exc()
        print("ERROR in transcribe_and_evaluate:\n", tb, flush=True)

        empty_df = pd.DataFrame([{"ASR_word": "", "GT_word": "", "status": "ERROR",
                                  "reason": str(e), "used": ""}])
        err_json = json.dumps({"error": str(e), "traceback": tb}, ensure_ascii=False, indent=2)
        gr.Warning(str(e))
        return "", "", err_json, empty_df

    finally:
        if tmp_audio_path:
            try:
                os.remove(tmp_audio_path)
            except OSError:
                pass  # best-effort cleanup; must not mask the real result or error
|
|
|
|
|
def api_predict(audio, original_text, whisper_size=None, compute_type=None, vad=True, use_marbert=True):
    """API endpoint: run transcribe_and_evaluate and return only its report, parsed into a dict.

    If the report string is not valid JSON, returns
    {"error": "Failed to parse report_json."} instead.
    """
    _corrected, _asr, report_json, _table = transcribe_and_evaluate(
        audio, original_text, whisper_size, compute_type, vad, use_marbert
    )
    try:
        parsed = json.loads(report_json)
    except Exception:
        parsed = {"error": "Failed to parse report_json."}
    return parsed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_ui():
    """Build and return the Gradio Blocks app.

    Inputs: audio (microphone/upload, passed as a file path), ground-truth
    text, Whisper size, compute type, VAD toggle, MARBERT toggle. Outputs:
    corrected transcript, raw ASR transcript, JSON report, token table.
    Registers two API endpoints: "evaluate" (all four outputs) and "predict"
    (report dict only, wired to a hidden button and a hidden JSON output).
    """
    with gr.Blocks(title="Samaali ASR Post-Processing", theme=gr.themes.Soft()) as demo:
        gr.Markdown("## Samaali — ASR Post-Processing (Whisper + Alignment + Confidence + Semantics)")

        with gr.Row():
            audio = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
            original = gr.Textbox(lines=8, label="Original Text (Ground Truth)")

        with gr.Row():
            whisper_size = gr.Dropdown(
                choices=["tiny", "base", "small", "medium", "large-v3"],
                value=("large-v3" if not CPU_MODE else DEFAULT_WHISPER_CPU),
                label="Whisper model size"
            )
            compute_type = gr.Dropdown(
                choices=["int8", "int8_float16", "float16", "float32"],
                value=("float16" if not CPU_MODE else DEFAULT_COMPUTE_CPU),
                label="compute_type"
            )
            vad = gr.Checkbox(value=True, label="VAD filter")
            use_marbert = gr.Checkbox(value=(not CPU_MODE), label="Use MARBERT (semantic)")

        btn = gr.Button("Transcribe & Evaluate", variant="primary")

        corrected = gr.Textbox(lines=6, label="Corrected Transcript (ASR errors restored)")
        asr_out = gr.Textbox(lines=6, label="Raw ASR Transcript")
        report = gr.JSON(label="Report (scores & thresholds)")
        table = gr.Dataframe(headers=["ASR_word", "GT_word", "status", "reason", "used"],
                             label="Token-level Decisions", wrap=True)

        run_inputs = [audio, original, whisper_size, compute_type, vad, use_marbert]

        btn.click(
            fn=transcribe_and_evaluate,
            inputs=run_inputs,
            outputs=[corrected, asr_out, report, table],
            api_name="evaluate"
        )

        # API-only endpoint. Its output component must live inside the Blocks,
        # but it is hidden so it does not show up as an empty box in the UI.
        api_report = gr.JSON(visible=False)
        gr.Button(visible=False).click(
            fn=api_predict,
            inputs=run_inputs,
            outputs=api_report,
            api_name="predict"
        )

    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Gradio UI and start the server.
    demo = build_ui()
    demo.launch()
|
|
|