# Author: Mjo98
# Commit bc98c80: add more examples & update reqs
import os
import re
import math
import json
import unicodedata
from functools import lru_cache
import numpy as np
import pandas as pd
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from arabert import ArabertPreprocessor
# ===== Constants =====
# Hugging Face Hub coordinates of the quantized ONNX turn-detector model.
SAUDI_HG_MODEL = "xmjo/arabic-eou-model-v1"
SAUDI_ONNX_FILENAME = "model_quantized.onnx"
SAUDI_REVISION = "onnx"

# ASCII digit -> spelled-out Arabic number word (applied character by
# character during text normalization).
digit_map = dict(
    zip(
        "0123456789",
        (
            'صفر', 'واحد', 'اثنين', 'ثلاثة', 'أربعة',
            'خمسة', 'ستة', 'سبعة', 'ثمانية', 'تسعة',
        ),
    )
)
# ===== Utilities =====
def log_odds(p, eps=0.0):
    """Return the natural log-odds ``log(p / (1 - p + eps))`` of probability *p*.

    Works element-wise on NumPy scalars/arrays as well as Python numbers.

    Args:
        p: Probability (or array of probabilities). NaN propagates to NaN.
        eps: Extra constant added to the denominator (kept for backward
            compatibility; 0.0 by default).

    Returns:
        The log-odds as a NumPy scalar or array.
    """
    # A saturated softmax can emit exactly 0.0 or 1.0. Clip first so those
    # map to a large finite value instead of -inf, NaN or ZeroDivisionError
    # (the latter happens for p == 1.0 when p is a plain Python float).
    floor = 1e-12
    p = np.clip(p, floor, 1.0 - floor)
    return np.log(p / (1 - p + eps))
# ===== Model Runner =====
class SaudiModelRunner:
    """Arabic end-of-utterance (EoU) scorer backed by a quantized ONNX model.

    On construction it downloads the ONNX file from the Hugging Face Hub and
    loads the matching tokenizer and AraBERT preprocessor. At query time it
    scores every whitespace-token prefix of the normalized input, so the UI
    can show where the model thinks the speaker's turn may end.
    """

    # Normalization patterns, compiled once because they run on every request.
    _PUNCT_RE = re.compile(r"[\[\]\(\)\{\}<>.،,؟?!«»\"'“”‘’\-—_]")
    _SPACE_RE = re.compile(r"\s+")
    _DIACRITICS_RE = re.compile(r'[\u064B-\u065F\u0670]')
    _ALEF_RE = re.compile(r'[أإآ]')

    def __init__(self):
        """Download the model and build the ONNX session, tokenizer and preprocessor.

        Raises:
            Exception: whatever ``hf_hub_download`` raises, re-raised unchanged
                after being printed.
        """
        print(f"Loading model {SAUDI_HG_MODEL}...")
        self.model_id = SAUDI_HG_MODEL
        self.revision = SAUDI_REVISION

        # Download model
        try:
            model_path = hf_hub_download(
                repo_id=SAUDI_HG_MODEL,
                filename=SAUDI_ONNX_FILENAME,
                revision=SAUDI_REVISION,
            )
        except Exception as e:
            print(f"Error downloading model: {e}")
            raise  # bare raise preserves the original traceback

        # Init ONNX session
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = 4
        sess_options.inter_op_num_threads = 1
        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"],
            sess_options=sess_options,
        )
        # Only tensors the exported graph declares may be fed; feeding an
        # undeclared one (e.g. token_type_ids) makes session.run fail.
        self._input_names = {node.name for node in self.session.get_inputs()}

        # Tokenizer & Preprocessor
        self.tokenizer = AutoTokenizer.from_pretrained(SAUDI_HG_MODEL, revision=SAUDI_REVISION)
        self.preprocessor = ArabertPreprocessor("aubmindlab/bert-base-arabertv02-twitter")

        # Decision threshold on P(EoU) (taken from the plugin).
        self.thresh = 0.685

    def normalize_arabic(self, text: str) -> str:
        """Normalize *text* the same way turn_detector_plugin.py does.

        Steps, in order (the order must not change, to stay in sync with the
        plugin): lowercase + NFKC, punctuation -> space, whitespace collapse,
        digit -> Arabic word, strip diacritics, unify alef forms, ة -> ه,
        ى -> ي, then the AraBERT preprocessor.
        """
        # 1. Basic normalization
        text = unicodedata.normalize("NFKC", text.lower())
        # 2. Punctuation to spaces, then collapse runs of whitespace
        text = self._PUNCT_RE.sub(" ", text)
        text = self._SPACE_RE.sub(" ", text).strip()
        # 3. Digit mapping.
        # NOTE(review): digits are replaced in place with no separator, so "12"
        # becomes "واحداثنين". Kept on the assumption that it mirrors the plugin
        # and training data; confirm before changing.
        text = ''.join(digit_map.get(ch, ch) for ch in text)
        # 4. Arabic-specific folding
        text = self._DIACRITICS_RE.sub('', text)
        text = self._ALEF_RE.sub('ا', text)
        text = text.replace('ة', 'ه')
        text = text.replace('ى', 'ي')
        # 5. Arabert Preprocessor
        return self.preprocessor.preprocess(text)

    def _run_inference(self, text):
        """Return the model's EoU probability (class index 1) for *text* as a float."""
        # NOTE(review): with the tokenizer's (presumably default, right-side)
        # truncation, prefixes longer than 128 tokens lose their *ending*, which
        # is the part that matters for EoU; confirm whether left truncation is wanted.
        inputs = self.tokenizer(
            text,
            return_tensors="np",
            truncation=True,
            max_length=128,
        )
        input_ids = np.asarray(inputs["input_ids"], dtype=np.int64)
        feed = {
            "input_ids": input_ids,
            "attention_mask": np.asarray(inputs["attention_mask"], dtype=np.int64),
        }
        if "token_type_ids" in inputs:
            feed["token_type_ids"] = np.asarray(inputs["token_type_ids"], dtype=np.int64)
        else:
            feed["token_type_ids"] = np.zeros_like(input_ids)
        feed = {name: arr for name, arr in feed.items() if name in self._input_names}

        logits = self.session.run(None, feed)[0]
        # Row-wise, numerically stable softmax.
        exp_logits = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs = exp_logits / exp_logits.sum(axis=-1, keepdims=True)
        # Python float (not float32) so later tiny-epsilon arithmetic is not lost.
        return float(probs[0, 1])

    def predict_eou_scores(self, text: str):
        """Score each growing token prefix of the normalized *text*.

        Returns:
            DataFrame with columns ``token`` (the token that ends the prefix)
            and ``pred`` (float64 EoU probability of that prefix). Empty when
            normalization leaves no tokens.
        """
        tokens = self.normalize_arabic(text).split()
        results = []
        prefix = ""
        # Simulate streaming: run the model once per prefix ending at each token.
        for token in tokens:
            prefix = f"{prefix} {token}" if prefix else token
            results.append((token, self._run_inference(prefix)))
        df = pd.DataFrame(results, columns=["token", "pred"])
        # An empty result would otherwise leave "pred" as object dtype.
        df["pred"] = df["pred"].astype("float64")
        return df

    @staticmethod
    def _relative_log_odds(pred: pd.Series, thresh: float, eps: float) -> pd.Series:
        """Log-odds of *pred* minus log-odds of *thresh*; NaN where *pred* is NaN.

        Probabilities are clipped into ``[eps, 1 - eps]`` so exact 0.0/1.0
        model outputs produce large finite values rather than inf/NaN/errors.
        """
        clipped = pred.astype("float64").fillna(thresh).clip(eps, 1.0 - eps)
        return clipped.apply(log_odds).sub(log_odds(thresh)).mask(pred.isna())

    def make_styled_df(self, df: pd.DataFrame, cmap="coolwarm") -> str:
        """Render per-token scores as an HTML table.

        Args:
            df: Frame with ``token`` and ``pred`` columns.
            cmap: Matplotlib colormap name for the bar and text gradient.

        Returns:
            HTML with columns ``token``, ``log_odds`` (threshold-relative,
            drawn as a zero-aligned bar) and ``Prob(EoT) as %``.
        """
        eps = 1e-12
        _df = df.copy()
        _df["token"] = _df["token"].replace({"\n": "⏎", " ": "␠"})
        _df["log_odds"] = self._relative_log_odds(_df["pred"], self.thresh, eps)
        _df["Prob(EoT) as %"] = _df["pred"].mul(100).fillna(0).astype(int)

        # Symmetric color range around 0 with 50% headroom. Fall back to 1.0
        # when there are no finite scores, or when they are all exactly 0.
        vmin, vmax = _df["log_odds"].min(), _df["log_odds"].max()
        if pd.notna(vmin) and pd.notna(vmax):
            vmax_abs = max(abs(vmin), abs(vmax)) * 1.5 or 1.0
        else:
            vmax_abs = 1.0

        fmt = (
            _df.drop(columns=["pred"])
            .style
            .bar(
                subset=["log_odds"],
                align="zero",
                vmin=-vmax_abs,
                vmax=vmax_abs,
                cmap=cmap,
                height=70,
                width=100,
            )
            .text_gradient(subset=["log_odds"], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
            .format(na_rep="", precision=1, subset=["log_odds"])
            .format("{:3d}", subset=["Prob(EoT) as %"])
            .hide(axis="index")
        )
        return fmt.to_html()

    def generate_highlighted_text(self, text: str):
        """Score *text* and build both Gradio outputs.

        Returns:
            ``(highlighted, html)``.

            * ``highlighted`` is a list of ``(token, score)`` pairs for
              ``gr.HighlightedText``. ``score`` is the threshold-relative
              log-odds divided by 1.5x the largest magnitude, so it lies within
              about [-2/3, 2/3]. It is ``None`` (no highlight) when ``pred`` is
              NaN or rounds to 0.00.
            * ``html`` is the table from ``make_styled_df``.

            Blank input, or input that normalizes to no tokens, returns an
            empty list plus a short HTML notice.
        """
        eps = 1e-12
        threshold = self.thresh
        if not text or not text.strip():
            return [], "<div>No input.</div>"

        df = self.predict_eou_scores(text)
        if df.empty:
            return [], "<div>No tokens to analyze.</div>"

        score = self._relative_log_odds(df["pred"], threshold, eps)
        score = score.mask(df["pred"].round(2).eq(0))
        max_abs_score = score.abs().max()
        if pd.notna(max_abs_score) and max_abs_score > 0:
            score = score / (max_abs_score * 1.5)

        highlighted = [
            (token, None if pd.isna(value) else float(value))
            for token, value in zip(df["token"], score)
        ]
        styled_df = self.make_styled_df(df[["token", "pred"]])
        return highlighted, styled_df
# ===== Cached Loader =====
@lru_cache(maxsize=1)
def get_runner():
    """Return the shared SaudiModelRunner, constructing it on the first call.

    Building the runner downloads and loads the model, so the single instance
    is memoized with a one-slot lru_cache instead of being built per request.
    """
    runner = SaudiModelRunner()
    return runner
# ===== Gradio App =====
def run_model(text: str):
    """Gradio click handler: return (highlighted tokens, HTML score table) for *text*."""
    highlighted, table_html = get_runner().generate_highlighted_text(text)
    return highlighted, table_html
# Demo utterances; each row holds one value per input component (the text box).
EXAMPLES = [
    [sentence]
    for sentence in (
        "كيف حالك بشرنا عنك عساك بخير",
        "رقم جوالي صفر خمسة سبعة ستة ستة واحد ثلاثة سبعة صفر صفر",
        "او صخره صلبه تستخدم كاساس للمبنى وقال ان الزعماء الدينيين سيرفضون",
        "هل يمكنك أن تخبرني عن",
        "جمهورية الدومينيكان هي دولة تقع في الكاريبي على جزيرة هيسبانيولا التي تشترك فيها مع هايتي",
    )
]
# Gradio UI: input box + examples, a run button, and two outputs (token
# highlighting and the raw score table), all fed by run_model.
with gr.Blocks(theme="soft", title="Arabic Turn Detector Debugger") as demo:
    gr.Markdown(
        """# Arabic Turn Detector Debugger
Visualize predicted turn endings from **Arabic EOU Model**.
Red ⇒ agent should reply • Blue ⇒ agent should wait"""
    )
    with gr.Row():
        text_in = gr.Textbox(
            label="Input Text",
            info="Enter Arabic text to analyze.",
            value=EXAMPLES[0][0],
            lines=4,
            text_align="right",
            rtl=True,
        )
    gr.Examples(
        examples=EXAMPLES,
        inputs=[text_in],
        label="Examples",
    )
    run_btn = gr.Button("Run Analysis", variant="primary")
    with gr.Row():
        with gr.Column():
            # Scores are numbers, which HighlightedText colors on its own
            # built-in scale. color_map expects a {category: color} dict, so
            # the previous colormap name ("coolwarm") had no meaningful effect.
            out_ht = gr.HighlightedText(
                label="EoT Predictions",
                scale=1.5,
                rtl=True,
            )
            out_html = gr.HTML(label="Raw scores")
    run_btn.click(
        fn=run_model,
        inputs=[text_in],
        outputs=[out_ht, out_html],
    )

if __name__ == "__main__":
    # share=True additionally requests a public Gradio share link.
    demo.launch(share=True)