Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import math | |
| import json | |
| import unicodedata | |
| from functools import lru_cache | |
| import numpy as np | |
| import pandas as pd | |
| import gradio as gr | |
| import onnxruntime as ort | |
| from transformers import AutoTokenizer | |
| from huggingface_hub import hf_hub_download | |
| from arabert import ArabertPreprocessor | |
# ===== Constants =====
SAUDI_HG_MODEL = "xmjo/arabic-eou-model-v1"
SAUDI_ONNX_FILENAME = "model_quantized.onnx"
SAUDI_REVISION = "onnx"

# Arabic spelled-out words for the ten Western digits; used during text
# normalization so numerals match the spoken-form training data.
digit_map = dict(zip(
    "0123456789",
    ["صفر", "واحد", "اثنين", "ثلاثة", "أربعة",
     "خمسة", "ستة", "سبعة", "ثمانية", "تسعة"],
))
# ===== Utilities =====
def log_odds(p, eps=0.0):
    """Return the logit (log-odds) of probability ``p``.

    ``eps`` is added to the denominator so a positive value keeps
    ``p == 1`` finite; the default of 0.0 leaves the computation exact.
    """
    odds = p / (1 - p + eps)
    return np.log(odds)
# ===== Model Runner =====
class SaudiModelRunner:
    """Runs the Saudi Arabic end-of-utterance (EOU) classifier.

    Downloads a quantized ONNX checkpoint from the Hugging Face Hub,
    normalizes Arabic text the same way the turn-detector plugin does,
    and scores every token prefix of an input so turn-taking decisions
    can be visualized in the Gradio UI.
    """

    def __init__(self):
        """Download the ONNX model and set up the tokenizer/preprocessor.

        Re-raises any exception from ``hf_hub_download`` (network error,
        missing revision, ...) after logging it.
        """
        print(f"Loading model {SAUDI_HG_MODEL}...")
        self.model_id = SAUDI_HG_MODEL
        self.revision = SAUDI_REVISION

        # Fetch the quantized ONNX weights from the Hub.
        try:
            model_path = hf_hub_download(
                repo_id=SAUDI_HG_MODEL,
                filename=SAUDI_ONNX_FILENAME,
                revision=SAUDI_REVISION,
            )
        except Exception as e:
            print(f"Error downloading model: {e}")
            raise  # bare raise preserves the original traceback (was `raise e`)

        # CPU-only ONNX Runtime session; modest thread counts keep the
        # demo responsive on shared hardware.
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = 4
        sess_options.inter_op_num_threads = 1
        self.session = ort.InferenceSession(
            model_path,
            providers=["CPUExecutionProvider"],
            sess_options=sess_options,
        )

        # Tokenizer matches the fine-tuned checkpoint; the AraBERT
        # preprocessor mirrors the normalization used at training time.
        self.tokenizer = AutoTokenizer.from_pretrained(SAUDI_HG_MODEL, revision=SAUDI_REVISION)
        self.preprocessor = ArabertPreprocessor("aubmindlab/bert-base-arabertv02-twitter")

        # Decision threshold taken from the turn-detector plugin.
        self.thresh = 0.685

    def normalize_arabic(self, text: str) -> str:
        """Normalize Arabic ``text`` to match the plugin's preprocessing.

        Steps: NFKC + lowercase, punctuation stripping, whitespace
        collapsing, digit spelling, Arabic letter folding, then the
        AraBERT preprocessor.
        """
        # 1. Basic normalization (plugin's _normalize_text does NFKC + lower).
        text = unicodedata.normalize("NFKC", text.lower())
        # 2. Replace punctuation/brackets with spaces, collapse whitespace.
        text = re.sub(r"[\[\]\(\)\{\}<>.،,؟?!«»\"'“”‘’\-—_]", " ", text)
        text = re.sub(r"\s+", " ", text).strip()
        # 3. Spell out Western digits in Arabic.
        text = ''.join(digit_map.get(ch, ch) for ch in text)
        # 4. Arabic-specific folding: drop diacritics/tatweel marks,
        #    unify alef variants, taa marbuta -> haa, alef maqsura -> yaa.
        text = re.sub(r'[\u064B-\u065F\u0670]', '', text)
        text = re.sub(r'[أإآ]', 'ا', text)
        text = re.sub(r'ة', 'ه', text)
        text = re.sub(r'ى', 'ي', text)
        # 5. Final AraBERT preprocessing (matches training-time pipeline).
        text = self.preprocessor.preprocess(text)
        return text

    def _run_inference(self, text):
        """Tokenize ``text``, run the ONNX session, and return the EOU probability."""
        inputs = self.tokenizer(
            text,
            return_tensors="np",
            truncation=True,
            max_length=128,
        )
        feed_dict = {
            "input_ids": inputs["input_ids"].astype("int64"),
            "attention_mask": inputs["attention_mask"].astype("int64"),
        }
        # Some exported graphs require token_type_ids even when the
        # tokenizer does not emit them; feed zeros in that case.
        if "token_type_ids" in inputs:
            feed_dict["token_type_ids"] = inputs["token_type_ids"].astype("int64")
        else:
            feed_dict["token_type_ids"] = np.zeros_like(inputs["input_ids"], dtype=np.int64)
        outputs = self.session.run(None, feed_dict)
        logits = outputs[0]
        # Numerically stable softmax over the class logits.
        exp_logits = np.exp(logits - np.max(logits))
        probs = exp_logits / exp_logits.sum(axis=-1, keepdims=True)
        # NOTE(review): assumes label index 1 is the EOU class — matches the
        # checkpoint this app was built for; confirm if the model changes.
        return probs[0][1]  # EOU probability

    def predict_eou_scores(self, text: str):
        """Score every token prefix of ``text``.

        Returns a DataFrame with columns ``token`` (the token ending the
        prefix) and ``pred`` (EOU probability of the prefix up to and
        including that token). Simulates streaming turn detection.
        """
        norm_text = self.normalize_arabic(text)
        # Tokens are space separated after AraBERT preprocessing.
        tokens = norm_text.split()
        results = []
        current_text = ""
        # Prefix loop: re-run inference at each point, as a live detector would.
        for token in tokens:
            if current_text:
                current_text += " " + token
            else:
                current_text = token
            prob = self._run_inference(current_text)
            results.append((token, prob))
        return pd.DataFrame(results, columns=["token", "pred"])

    def make_styled_df(self, df: pd.DataFrame, cmap="coolwarm") -> str:
        """Render per-token scores as an HTML table with diverging color bars.

        ``df`` must have ``token`` and ``pred`` columns; returns HTML text.
        """
        EPS = 1e-12
        thresh = self.thresh
        _df = df.copy()
        # Make whitespace tokens visible in the rendered table.
        _df.token = _df.token.replace({"\n": "⏎", " ": "␠"})
        # Log-odds relative to the decision threshold: positive means
        # "more end-of-turn than the threshold". NaN preds stay masked.
        _df["log_odds"] = (
            _df.pred.fillna(thresh)
            .add(EPS)
            .apply(log_odds).sub(log_odds(thresh))
            .mask(_df.pred.isna())
        )
        _df["Prob(EoT) as %"] = _df.pred.mul(100).fillna(0).astype(int)
        # Symmetric color range (padded 1.5x) so zero sits mid-bar.
        vmin, vmax = _df.log_odds.min(), _df.log_odds.max()
        vmax_abs = max(abs(vmin), abs(vmax)) * 1.5 if pd.notna(vmin) and pd.notna(vmax) else 1.0
        fmt = (
            _df.drop(columns=["pred"])
            .style
            .bar(
                subset=["log_odds"],
                align="zero",
                vmin=-vmax_abs,
                vmax=vmax_abs,
                cmap=cmap,
                height=70,
                width=100,
            )
            .text_gradient(subset=["log_odds"], cmap=cmap, vmin=-vmax_abs, vmax=vmax_abs)
            .format(na_rep="", precision=1, subset=["log_odds"])
            .format("{:3d}", subset=["Prob(EoT) as %"])
            .hide(axis="index")
        )
        return fmt.to_html()

    def generate_highlighted_text(self, text: str):
        """Returns: (highlighted_list, styled_html) for Gradio.

        ``highlighted_list`` pairs each token with a normalized
        threshold-relative score for ``gr.HighlightedText``; the HTML is
        the detailed score table from :meth:`make_styled_df`.
        """
        eps = 1e-12
        threshold = self.thresh
        if not text:
            return [], "<div>No input.</div>"
        df = self.predict_eou_scores(text)
        # Threshold-relative log-odds; NaN or ~zero probabilities are
        # masked so they render uncolored.
        df["score"] = (
            df.pred.fillna(threshold)
            .add(eps)
            .apply(log_odds).sub(log_odds(threshold))
            .mask(df.pred.isna() | df.pred.round(2).eq(0))
        )
        # Normalize into roughly [-2/3, 2/3] so colors are not saturated.
        max_abs_score = df["score"].abs().max()
        if pd.notna(max_abs_score) and max_abs_score > 0:
            df.score = df.score / (max_abs_score * 1.5)
        styled_df = self.make_styled_df(df[["token", "pred"]])
        return list(zip(df.token, df.score)), styled_df
# ===== Cached Loader =====
@lru_cache(maxsize=1)
def get_runner():
    """Create the model runner once and reuse it across Gradio callbacks.

    This section is labelled "Cached Loader" but the original rebuilt
    (re-downloaded / re-loaded) the model on every call; memoizing with
    ``lru_cache`` makes the label true and keeps the UI responsive.
    """
    return SaudiModelRunner()
# ===== Gradio App =====
def run_model(text: str):
    """Gradio callback: score ``text`` and return (highlighted tokens, HTML table)."""
    highlighted, table_html = get_runner().generate_highlighted_text(text)
    return highlighted, table_html
# Sample inputs for the Examples widget: greetings, spelled-out phone
# digits, mid-sentence cut-offs, and complete declarative sentences.
EXAMPLES = [
    ["كيف حالك بشرنا عنك عساك بخير"],
    ["رقم جوالي صفر خمسة سبعة ستة ستة واحد ثلاثة سبعة صفر صفر"],
    ["او صخره صلبه تستخدم كاساس للمبنى وقال ان الزعماء الدينيين سيرفضون"],
    ["هل يمكنك أن تخبرني عن"],
    ["جمهورية الدومينيكان هي دولة تقع في الكاريبي على جزيرة هيسبانيولا التي تشترك فيها مع هايتي"],
]
# Build the debugger UI: one RTL text input with examples, a run button,
# and two outputs (token highlights + raw score table).
with gr.Blocks(theme="soft", title="Arabic Turn Detector Debugger") as demo:
    gr.Markdown(
        """# Arabic Turn Detector Debugger
Visualize predicted turn endings from **Arabic EOU Model**.
Red ⇒ agent should reply • Blue ⇒ agent should wait"""
    )

    with gr.Row():
        text_in = gr.Textbox(
            label="Input Text",
            info="Enter Arabic text to analyze.",
            value=EXAMPLES[0][0],
            lines=4,
            text_align="right",
            rtl=True,
        )

    gr.Examples(
        examples=EXAMPLES,
        inputs=[text_in],
        label="Examples",
    )

    run_btn = gr.Button("Run Analysis", variant="primary")

    with gr.Row():
        with gr.Column():
            out_ht = gr.HighlightedText(
                label="EoT Predictions",
                color_map="coolwarm",
                scale=1.5,
                rtl=True,
            )
            out_html = gr.HTML(label="Raw scores")

    run_btn.click(
        fn=run_model,
        inputs=[text_in],
        outputs=[out_ht, out_html],
    )

if __name__ == "__main__":
    demo.launch(share=True)