Update app.py
app.py
CHANGED
@@ -1,4 +1,4 @@
-import os, re, math,
+import os, re, json, math, tempfile, traceback
 import numpy as np
 import pandas as pd
 import torch
@@ -9,12 +9,22 @@ from faster_whisper import WhisperModel
 
 from sentence_transformers import SentenceTransformer, util
 from transformers import AutoTokenizer, AutoModel
+import soundfile as sf
 
 # =========================
-# Device &
+# Device & global config
 # =========================
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+CPU_MODE = (DEVICE != "cuda")
+
+# Memory safety on CPU
+DEFAULT_WHISPER_CPU = "small"
+DEFAULT_COMPUTE_CPU = "int8"
+DEFAULT_USE_MARBERT_CPU = False
 
+# =========================
+# Lazy models
+# =========================
 _SBERT = None
 _MARBERT_TOK = None
 _MARBERT = None
@@ -26,13 +36,23 @@ def load_models(
     whisper_name="small",
     whisper_compute="int8"
 ):
+    """Load models only once."""
     global _SBERT, _MARBERT_TOK, _MARBERT, _WHISPER
+
+    # CPU safeguard: force lighter models
+    if CPU_MODE:
+        whisper_name = DEFAULT_WHISPER_CPU
+        whisper_compute = DEFAULT_COMPUTE_CPU
+
     if _SBERT is None:
         _SBERT = SentenceTransformer(sbert_name, device=DEVICE)
-
+
+    # Load MARBERT only when needed (it can consume RAM)
+    if _MARBERT is None and (not CPU_MODE):
         _MARBERT_TOK = AutoTokenizer.from_pretrained(marbert_name)
         _MARBERT = AutoModel.from_pretrained(marbert_name).to(DEVICE)
         _MARBERT.eval()
+
     if _WHISPER is None:
         _WHISPER = WhisperModel(whisper_name, device=DEVICE, compute_type=whisper_compute)
 
@@ -48,12 +68,17 @@ def normalize_ar_orth(text: str) -> str:
     return text
 
 def simple_tokenize(text: str):
-
+    """Try punkt; on failure, fall back to simple whitespace splitting."""
+    t = normalize_ar_orth(text)
     try:
-        nltk
+        import nltk
+        try:
+            nltk.data.find('tokenizers/punkt')
+        except LookupError:
+            nltk.download('punkt', quiet=True)
+        return nltk.word_tokenize(t)
+    except Exception:
+        return t.split()
 
 def align_texts(ref_tokens, hyp_tokens):
     import difflib
@@ -95,7 +120,7 @@ def is_levenshtein_1(w1, w2):
     return textdistance.levenshtein(w1, w2) == 1
 
 # =========================
-# Numbers
+# Numbers
 # =========================
 AR_DIGITS = str.maketrans("٠١٢٣٤٥٦٧٨٩", "0123456789")
 UNITS = {"صفر":0,"واحد":1,"واحدة":1,"اثنان":2,"اثنين":2,"اثنتان":2,"اثنتين":2,
@@ -136,10 +161,12 @@ def to_numeric_value(token: str):
     return words_to_number(toks)
 
 # =========================
-# Semantic similarities
+# Semantic similarities
 # =========================
 def marbert_cls_similarity(a: str, b: str) -> float:
     if not a or not b: return 0.0
+    if _MARBERT is None:
+        return 0.0
     with torch.no_grad():
         ta = _MARBERT_TOK(a, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
         tb = _MARBERT_TOK(b, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
@@ -231,20 +258,17 @@ def gate_by_word_conf(base_decision: str, prob: float, sbert_sim: float,
     return base_decision
 
 # =========================
-# Pair
+# Pair + main classifiers
 # =========================
 def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
                   bert_thresh=0.75, max_bert=0.85):
-    # 1) numbers
     ref_num = to_numeric_value(ref_w)
     hyp_num = to_numeric_value(hyp_w)
     if (ref_num is not None) or (hyp_num is not None):
         if (ref_num is not None) and (hyp_num is not None) and (ref_num == hyp_num):
             return 'ASR error (numbers equal)'
-    # 2) short + lev1
     if short_word and lev1:
         return 'ASR error (short+lev1)'
-    # 3) semantic
     avg_ok = bert_scores["avg"] >= bert_thresh
     max_ok = bert_scores["max"] > max_bert
     if ((phon_sim or lev1) and avg_ok) or max_ok:
@@ -254,7 +278,6 @@ def classify_pair(ref_w, hyp_w, bert_scores, phon_sim, lev1, short_word,
 def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
                                  bert_thresh=0.75, max_bert=0.85,
                                  asr_token_conf=None, low_high=None):
-    # thresholds
     if low_high is None:
         if asr_token_conf:
             probs = [v["prob"] for v in asr_token_conf.values() if v["prob"] is not None]
@@ -268,8 +291,7 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
     else:
         low_t, high_t = low_high
 
-    results = []
-    corrected_words = []
+    results, corrected_words = [], []
 
     for entry in aligned:
         tag = entry['type']
@@ -285,7 +307,8 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
         for k in range(max_len):
             ref_w = entry['ref'][k] if k < len(entry['ref']) else ''
             hyp_w = entry['hyp'][k] if k < len(entry['hyp']) else ''
-            if not ref_w and not hyp_w:
+            if not ref_w and not hyp_w:
+                continue
 
             phon_sim = phonetic_similarity(ref_w, hyp_w) if ref_w and hyp_w else False
             lev1 = is_levenshtein_1(ref_w, hyp_w) if ref_w and hyp_w else False
@@ -302,7 +325,6 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
             else:
                 base_status = 'Undefined Case'
 
-            # word-level confidence
             word_prob = None; word_dur = None
             if (j1 is not None) and (j2 is not None):
                 hyp_abs_idx = j1 + k
|
|
| 320 |
low_t=low_t, high_t=high_t, sbert_lo=0.60
|
| 321 |
)
|
| 322 |
|
|
|
|
| 323 |
if ref_w and hyp_w:
|
| 324 |
used = ref_w if final_status.startswith("ASR error") else hyp_w
|
| 325 |
elif hyp_w == '':
|
| 326 |
used = ''
|
| 327 |
elif ref_w == '':
|
| 328 |
used = hyp_w
|
| 329 |
-
else:
|
| 330 |
-
used = hyp_w
|
| 331 |
|
| 332 |
reason = (f'Phonetic={phon_sim}, Lev1={lev1}, '
|
| 333 |
f'SBERT={bert_scores["sbert"]:.2f}, '
|
|
@@ -347,15 +368,9 @@ def classify_alignment_optimized(aligned, ref_tokens, hyp_tokens,
     return results, corrected_text
 
 # =========================
-#
+# Scores
 # =========================
 def literal_similarity(original, recited):
-    import nltk
-    try:
-        nltk.data.find('tokenizers/punkt')
-    except LookupError:
-        nltk.download('punkt')
-
     def norm(t):
         t = re.sub(r'[ًٌٍَُِّْـ]', '', t)
         t = re.sub(r'[“”",:؛؟.!()\[\]{}،\-–—_]', ' ', t)
@@ -363,41 +378,33 @@ def literal_similarity(original, recited):
         return t
     o = norm(original); r = norm(recited)
     lev = textdistance.levenshtein.normalized_similarity(o, r)
-    ot =
+    ot = simple_tokenize(o); rt = simple_tokenize(r)
     common = sum(1 for w1, w2 in zip(ot, rt) if w1 == w2)
     word_overlap = common / max(len(ot), 1)
-
-
+    try:
+        import nltk.translate.bleu_score as bleu
+        bleu1 = bleu.sentence_bleu([ot], rt, weights=(1,0,0,0)) if (ot and rt) else 0.0
+    except Exception:
+        bleu1 = 0.0
     final_score = 0.5*lev + 0.3*word_overlap + 0.2*bleu1
     return {"levenshtein": round(lev,3), "word_overlap": round(word_overlap,3),
             "bleu1": round(bleu1,3), "literal_score": round(final_score,3)}
 
-def semantic_similarity(original, recited):
+def semantic_similarity(original, recited, use_marbert=True):
     sbert_sim = float(util.pytorch_cos_sim(_SBERT.encode(original, convert_to_tensor=True),
                                            _SBERT.encode(recited, convert_to_tensor=True)))
-
-    ta = _MARBERT_TOK(original, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
-    tb = _MARBERT_TOK(recited, return_tensors='pt', truncation=True, padding=True).to(DEVICE)
-    ea = _MARBERT(**ta).last_hidden_state[:,0,:]
-    eb = _MARBERT(**tb).last_hidden_state[:,0,:]
-    sim = util.cos_sim(ea, eb).item()
-    marbert_sim = (sim + 1)/2
+    marbert_sim = marbert_cls_similarity(original, recited) if use_marbert else 0.0
     return {"sbert_sim": round(sbert_sim,3), "marbert_sim": round(marbert_sim,3),
             "semantic_score": round(max(sbert_sim, marbert_sim),3)}
 
 # =========================
-# Audio input helper
+# Audio input helper
 # =========================
-import soundfile as sf
-
 def ensure_audio_path(audio):
-    """
-    Accepts:
-    - str (filepath)
-    - tuple (numpy_array, sample_rate) if Gradio Audio type='numpy'
-    Returns a filepath suitable for faster-whisper.
-    """
+    """Accepts a filepath (str) OR (numpy_array, sample_rate). Returns a valid filepath."""
     if isinstance(audio, str):
+        if not os.path.exists(audio):
+            raise FileNotFoundError(f"Audio path not found: {audio}")
         return audio
     if isinstance(audio, tuple) and len(audio) == 2:
         data, sr = audio
@@ -408,79 +415,87 @@ def ensure_audio_path(audio):
     raise ValueError("Unsupported audio input format")
 
 # =========================
-#
+# Pipeline (with robust error reporting)
 # =========================
-def transcribe_and_evaluate(audio, original_text, whisper_size=
-                            compute_type=
-    words = []
-    for seg in segments:
-        for w in (seg.words or []):
-            tok = clean_ar_token(w.word)
-            if tok: words.append(tok)
-    asr_text = " ".join(words)
-
-    # Tokens + align
-    ref_tokens = simple_tokenize(original_text)
-    hyp_tokens = simple_tokenize(asr_text)
-    aligned = align_texts(ref_tokens, hyp_tokens)
-
-    # Word confidence map
-    df_words = extract_word_conf_table(segments)
-    asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)
-
-    # Classify + corrected text
-    results, corrected_text = classify_alignment_optimized(
-        aligned, ref_tokens, hyp_tokens,
-        bert_thresh=0.75, max_bert=0.85,
-        asr_token_conf=asr_token_conf, low_high=(low_t, high_t)
-    )
+def transcribe_and_evaluate(audio, original_text, whisper_size=None,
+                            compute_type=None, vad=True, use_marbert=True):
+    try:
+        if not original_text or not original_text.strip():
+            raise ValueError("Original text is empty.")
+
+        # Defaults per device
+        if CPU_MODE:
+            whisper_size = DEFAULT_WHISPER_CPU
+            compute_type = DEFAULT_COMPUTE_CPU
+            use_marbert = DEFAULT_USE_MARBERT_CPU
+        else:
+            whisper_size = whisper_size or "large-v3"
+            compute_type = compute_type or "float16"
+
+        load_models(whisper_name=whisper_size, whisper_compute=compute_type)
+
+        audio_path = ensure_audio_path(audio)
+        segments, info = _WHISPER.transcribe(
+            audio_path, word_timestamps=True,
+            vad_filter=vad, vad_parameters={"min_silence_duration_ms": 200}
+        )
+        segments = list(segments)
+
+        words = []
+        for seg in segments:
+            for w in (seg.words or []):
+                tok = clean_ar_token(w.word)
+                if tok: words.append(tok)
+        asr_text = " ".join(words)
+
+        ref_tokens = simple_tokenize(original_text)
+        hyp_tokens = simple_tokenize(asr_text)
+        aligned = align_texts(ref_tokens, hyp_tokens)
+
+        df_words = extract_word_conf_table(segments)
+        asr_token_conf, low_t, high_t = build_asr_token_conf(df_words, hyp_tokens)
+
+        results, corrected_text = classify_alignment_optimized(
+            aligned, ref_tokens, hyp_tokens,
+            bert_thresh=0.75, max_bert=0.85,
+            asr_token_conf=asr_token_conf, low_high=(low_t, high_t)
+        )
+
+        lit = literal_similarity(original_text, corrected_text)
+        sem = semantic_similarity(original_text, corrected_text, use_marbert=(use_marbert and not CPU_MODE))
+
+        df = pd.DataFrame(results)
+
+        report = {
+            "whisper_model": whisper_size,
+            "compute_type": compute_type,
+            "original_text": original_text,
+            "asr_text": asr_text,
+            "corrected_text": corrected_text,
+            "literal": lit,
+            "semantic": sem,
+            "low_t": low_t, "high_t": high_t,
+        }
+        return corrected_text, asr_text, json.dumps(report, ensure_ascii=False, indent=2), df
+
+    except Exception as e:
+        tb = traceback.format_exc()
+        print("ERROR in transcribe_and_evaluate:\n", tb, flush=True)
+        # Return the error as JSON instead of crashing the UI
+        empty_df = pd.DataFrame([{"ASR_word":"","GT_word":"","status":"ERROR","reason":str(e),"used":""}])
+        err_json = json.dumps({"error": str(e), "traceback": tb}, ensure_ascii=False, indent=2)
+        gr.Warning(str(e))
+        return "", "", err_json, empty_df
+
+def api_predict(audio, original_text, whisper_size=None, compute_type=None, vad=True, use_marbert=True):
+    # Same pipeline, but returns JSON only
     corrected_text, asr_text, report_json, df = transcribe_and_evaluate(
         audio, original_text, whisper_size, compute_type, vad, use_marbert
     )
-
+    try:
+        return json.loads(report_json)
+    except Exception:
+        return {"error": "Failed to parse report_json."}
 
 # =========================
 # Gradio UI
@@ -488,23 +503,25 @@ def api_predict(audio, original_text, whisper_size="small",
 def build_ui():
     with gr.Blocks(title="Samaali ASR Post-Processing", theme=gr.themes.Soft()) as demo:
         gr.Markdown("## Samaali — ASR Post-Processing (Whisper + Alignment + Confidence + Semantics)")
+
         with gr.Row():
+            # filepath is the safest type on Spaces
             audio = gr.Audio(sources=["microphone","upload"], type="filepath", label="Audio")
            original = gr.Textbox(lines=8, label="Original Text (Ground Truth)")
 
        with gr.Row():
            whisper_size = gr.Dropdown(
                choices=["tiny","base","small","medium","large-v3"],
-                value=("large-v3" if
+                value=("large-v3" if not CPU_MODE else DEFAULT_WHISPER_CPU),
                label="Whisper model size"
            )
            compute_type = gr.Dropdown(
                choices=["int8", "int8_float16", "float16", "float32"],
-                value=("float16" if
+                value=("float16" if not CPU_MODE else DEFAULT_COMPUTE_CPU),
                label="compute_type"
            )
            vad = gr.Checkbox(value=True, label="VAD filter")
-            use_marbert = gr.Checkbox(value=(
+            use_marbert = gr.Checkbox(value=(not CPU_MODE), label="Use MARBERT (semantic)")
 
        btn = gr.Button("Transcribe & Evaluate", variant="primary")
@@ -515,20 +532,18 @@ def build_ui():
        table = gr.Dataframe(headers=["ASR_word","GT_word","status","reason","used"],
                             label="Token-level Decisions", wrap=True)
 
-        # UI action + API endpoint
        btn.click(
            fn=transcribe_and_evaluate,
            inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
            outputs=[corrected, asr_out, report, table],
-            api_name="evaluate"
+            api_name="evaluate"
        )
 
-        # JSON-only endpoint (hidden button)
        gr.Button(visible=False).click(
            fn=api_predict,
            inputs=[audio, original, whisper_size, compute_type, vad, use_marbert],
            outputs=gr.JSON(),
-            api_name="predict"
+            api_name="predict"
        )
 
    return demo
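The commit keeps the two named endpoints (`api_name="evaluate"` and `api_name="predict"`), so the Space stays callable without the UI. Below is a minimal sketch of querying the JSON-only `/predict` endpoint with `gradio_client`; the Space id `user/samaali-asr` and the sample file are hypothetical placeholders, and the positional inputs assume the order wired in `build_ui()`:

```python
# Sketch only: call the hidden /predict endpoint added in this commit.
# "user/samaali-asr" is a hypothetical Space id; replace with the real one.
from gradio_client import Client, handle_file

client = Client("user/samaali-asr")
report = client.predict(
    handle_file("sample.wav"),   # audio (filepath)
    "النص المرجعي هنا",          # original_text (ground truth)
    "small",                     # whisper_size
    "int8",                      # compute_type
    True,                        # vad
    False,                       # use_marbert
    api_name="/predict",
)
print(report)  # dict with asr_text, corrected_text, literal/semantic scores
```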