import os
import re
import math
import json
import unicodedata
from functools import lru_cache

import numpy as np
import pandas as pd
import gradio as gr
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from arabert import ArabertPreprocessor

# ===== Constants =====
SAUDI_HG_MODEL = "xmjo/arabic-eou-model-v1"
SAUDI_ONNX_FILENAME = "model_quantized.onnx"
SAUDI_REVISION = "onnx"

# Arabic-word spelling for each ASCII digit; used to verbalize digits before
# tokenization so the EOU model never sees raw numerals.
digit_map = {
    '0': 'صفر', '1': 'واحد', '2': 'اثنين', '3': 'ثلاثة', '4': 'أربعة',
    '5': 'خمسة', '6': 'ستة', '7': 'سبعة', '8': 'ثمانية', '9': 'تسعة'
}


# ===== Utilities =====
def log_odds(p, eps=0.0):
    """Return log(p / (1 - p + eps)) — the log-odds of probability ``p``.

    ``eps`` pads only the denominator, guarding against division by zero
    when ``p`` is exactly 1 (callers add their own epsilon to ``p`` when
    they need protection at p == 0).
    """
    return np.log(p / (1 - p + eps))


# ===== Model Runner =====
class SaudiModelRunner:
    """Runs a quantized ONNX end-of-utterance (EOU) classifier for Saudi Arabic.

    Wraps model download, ONNX Runtime session setup, AraBERT-style text
    normalization, per-prefix inference, and presentation helpers for Gradio.
    """

    def __init__(self):
        """Download the ONNX model, build the CPU inference session, and load
        the tokenizer/preprocessor.

        Raises:
            Exception: re-raised from ``hf_hub_download`` if the model cannot
                be fetched (network/auth/revision errors).
        """
        print(f"Loading model {SAUDI_HG_MODEL}...")
        self.model_id = SAUDI_HG_MODEL
        self.revision = SAUDI_REVISION

        # Download model
        try:
            model_path = hf_hub_download(
                repo_id=SAUDI_HG_MODEL,
                filename=SAUDI_ONNX_FILENAME,
                revision=SAUDI_REVISION
            )
        except Exception as e:
            print(f"Error downloading model: {e}")
            # Bare `raise` preserves the original traceback (`raise e` would
            # re-anchor it here).
            raise

        # Init ONNX session — small fixed thread counts for CPU-only serving.
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = 4
        sess_options.inter_op_num_threads = 1
        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"],
            sess_options=sess_options
        )

        # Tokenizer & Preprocessor
        self.tokenizer = AutoTokenizer.from_pretrained(SAUDI_HG_MODEL, revision=SAUDI_REVISION)
        self.preprocessor = ArabertPreprocessor("aubmindlab/bert-base-arabertv02-twitter")

        # Decision threshold taken from the turn-detector plugin.
        self.thresh = 0.685

    def normalize_arabic(self, text: str) -> str:
        """Normalize raw Arabic text the same way the turn-detector plugin does.

        Pipeline: NFKC + lowercase, punctuation stripping, digit-to-word
        mapping, diacritic removal, letter unification (alef/teh-marbuta/
        alef-maqsura), then the AraBERT preprocessor.
        """
        # 1. Basic normalization (plugin calls self._normalize_text(text),
        #    which usually does NFKC and lower).
        text = unicodedata.normalize("NFKC", text.lower())
        # 2. Regex replacements: drop brackets/punctuation, collapse whitespace.
        text = re.sub(r"[\[\]\(\)\{\}<>.،,؟?!«»\"'“”‘’\-—_]", " ", text)
        text = re.sub(r"\s+", " ", text).strip()
        # 3. Digit mapping: spell out ASCII digits as Arabic words.
        text = ''.join(digit_map.get(ch, ch) for ch in text)
        # 4. Arabic specific: strip diacritics, unify alef variants,
        #    teh marbuta -> heh, alef maqsura -> yeh.
        text = re.sub(r'[\u064B-\u065F\u0670]', '', text)
        text = re.sub(r'[أإآ]', 'ا', text)
        text = re.sub(r'ة', 'ه', text)
        text = re.sub(r'ى', 'ي', text)
        # 5. AraBERT preprocessor (segmentation/cleanup matching the model's
        #    training-time preprocessing).
        text = self.preprocessor.preprocess(text)
        return text

    def _run_inference(self, text):
        """Tokenize ``text`` and run the ONNX session.

        Returns:
            float: probability of the EOU class (index 1) after softmax.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="np",
            truncation=True,
            max_length=128
        )
        feed_dict = {
            "input_ids": inputs["input_ids"].astype("int64"),
            "attention_mask": inputs["attention_mask"].astype("int64"),
        }
        # Some tokenizers omit token_type_ids; the ONNX graph still expects
        # the input, so feed zeros in that case.
        if "token_type_ids" in inputs:
            feed_dict["token_type_ids"] = inputs["token_type_ids"].astype("int64")
        else:
            feed_dict["token_type_ids"] = np.zeros_like(inputs["input_ids"], dtype=np.int64)

        outputs = self.session.run(None, feed_dict)
        logits = outputs[0]

        # Numerically stable softmax (subtract max before exponentiating).
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / exp_logits.sum(axis=-1, keepdims=True)
        return probs[0][1]  # EOU probability

    def predict_eou_scores(self, text: str):
        """Score every prefix of ``text`` with the EOU model.

        Normalizes the text, then runs inference on each growing prefix of
        the token sequence to simulate streaming turn detection. Quadratic in
        token count by design (one full forward pass per prefix).

        Returns:
            pd.DataFrame: columns ``token`` and ``pred`` (EOU probability at
            that prefix). Empty input yields an empty frame.
        """
        # Normalize
        norm_text = self.normalize_arabic(text)
        # Split into tokens (space separated after AraBERT preprocessing).
        tokens = norm_text.split()

        results = []
        current_text = ""
        # Prefix loop to simulate streaming/turn detection at each point.
        for token in tokens:
            if current_text:
                current_text += " " + token
            else:
                current_text = token
            prob = self._run_inference(current_text)
            results.append((token, prob))

        return pd.DataFrame(results, columns=["token", "pred"])

    def make_styled_df(self, df: pd.DataFrame, cmap="coolwarm") -> str:
        """Render the per-token EOU scores as styled HTML.

        Adds a ``log_odds`` column (log-odds relative to the decision
        threshold, so 0 is the decision boundary) with a diverging bar +
        text gradient, and an integer percentage column.

        Args:
            df: output of :meth:`predict_eou_scores`.
            cmap: matplotlib colormap name for the bars/gradient.

        Returns:
            str: HTML table produced by the pandas Styler.
        """
        EPS = 1e-12
        thresh = self.thresh
        _df = df.copy()
        # Make whitespace tokens visible in the rendered table.
        _df.token = _df.token.replace({"\n": "⏎", " ": "␠"})
        # Log-odds vs. threshold; EPS guards log(0), and NaN predictions are
        # masked back out after the computation.
        _df["log_odds"] = (
            _df.pred.fillna(thresh)
            .add(EPS)
            .apply(log_odds).sub(log_odds(thresh))
            .mask(_df.pred.isna())
        )
        _df["Prob(EoT) as %"] = _df.pred.mul(100).fillna(0).astype(int)
        # Symmetric color range around zero, widened 1.5x so bars don't clip.
        vmin, vmax = _df.log_odds.min(), _df.log_odds.max()
        vmax_abs = max(abs(vmin), abs(vmax)) * 1.5 if pd.notna(vmin) and pd.notna(vmax) else 1.0
        fmt = (
            _df.drop(columns=["pred"])
            .style
            .bar(
                subset=["log_odds"],
                align="zero",
                vmin=-vmax_abs,
                vmax=vmax_abs,
                cmap=cmap,
                height=70,
                width=100,
            )
            .text_gradient(subset=["log_odds"], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
            .format(na_rep="", precision=1, subset=["log_odds"])
            .format("{:3d}", subset=["Prob(EoT) as %"])
            .hide(axis="index")
        )
        return fmt.to_html()

    def generate_highlighted_text(self, text: str):
        """Returns: (highlighted_list, styled_html) for Gradio.

        NOTE(review): the original source was truncated immediately after the
        empty-input guard below. The remainder is a conservative
        reconstruction from the visible helpers (score prefixes, style the
        frame, pair each token with its probability) — confirm against the
        original implementation.
        """
        if not text:
            return [], ""
        # --- reconstructed from here down; verify against the original ---
        df = self.predict_eou_scores(text)
        styled_html = self.make_styled_df(df)
        highlighted = [
            (f"{token} ", f"{prob:.2f}") for token, prob in zip(df.token, df.pred)
        ]
        return highlighted, styled_html