"""Gradio demo for a Saudi-dialect End-of-Utterance (EOU) detector.

The script loads a sequence-classification model from the Hugging Face Hub,
normalizes the Arabic text typed by the user, and reports whether the
conversational agent should REPLY (the user's turn is complete) or WAIT
(the turn is still incomplete).
"""

import re
import string
from typing import Tuple

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# 1. Model Configuration
MODEL_REPO = "SuperSl6/saudi-eou-model-final"

# Index of the logit treated as COMPLETE (End of Utterance). This follows the
# original author's note that label 1 means COMPLETE; confirm against the
# model's label mapping if the checkpoint changes.
COMPLETE_LABEL_INDEX = 1

# Minimum COMPLETE probability required for the agent to reply.
EOU_THRESHOLD = 0.5

# Single-word inputs that are allowed to reach the model; every other input
# with fewer than two words is answered with WAIT by the safety rule.
# Entries are stored in normalized form (bare Alef, no diacritics).
SHORT_REPLY_WHITELIST = frozenset(
    ["نعم", "لا", "طيب", "سم", "ابشر", "تم", "صحيح", "اكيد"]
)

# Patterns and tables used by normalize_text, built once instead of per call.
_DIACRITICS_RE = re.compile(r'[\u0617-\u061A\u064B-\u0652]')
_ALEF_VARIANTS_RE = re.compile(r'[أإآ]')
_YA_VARIANT_RE = re.compile(r'ى')
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation + '،؛؟')

print(f"Loading Model from {MODEL_REPO}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
    model.eval()
    print("Model loaded successfully.")
except Exception as e:
    # Keep the app bootable (as before), but bind both names so that
    # predict_eou can report the failure explicitly instead of crashing
    # with a NameError on the first request.
    tokenizer = None
    model = None
    print(f"Error loading model: {e}")


# 2. Text Normalization Function
def normalize_text(text) -> str:
    """Return *text* normalized for the classifier.

    The value is converted to ``str`` and then, in order: Arabic diacritics
    are removed, the Alef forms ``أ``/``إ``/``آ`` are replaced by bare ``ا``,
    ``ى`` is replaced by ``ي``, ASCII punctuation plus the Arabic
    comma/semicolon/question mark are removed, and surrounding whitespace is
    stripped.
    """
    text = str(text)
    text = _DIACRITICS_RE.sub('', text)
    text = _ALEF_VARIANTS_RE.sub('ا', text)
    text = _YA_VARIANT_RE.sub('ي', text)
    return text.translate(_PUNCTUATION_TABLE).strip()


# 3. Prediction Logic
def predict_eou(text) -> Tuple[str, str]:
    """Decide whether the user's turn is complete.

    Args:
        text: Raw user input (may be empty or ``None``).

    Returns:
        A ``(decision, detail)`` pair of strings. ``decision`` is the agent
        action (or a prompt/error message); ``detail`` is the COMPLETE
        probability formatted as a percentage, or a short explanation when
        the model was not consulted.
    """
    if not text or not text.strip():
        return "Please enter text...", "0%"

    clean_text = normalize_text(text)

    # Safety rule: inputs with fewer than two words (including text that
    # normalizes to nothing) are treated as incomplete unless they are a
    # known stand-alone reply.
    if len(clean_text.split()) < 2 and clean_text not in SHORT_REPLY_WHITELIST:
        return "WAIT (Turn Incomplete)", "Safety Rule Triggered"

    # Loading failed at start-up; report it rather than raising NameError.
    if tokenizer is None or model is None:
        return "ERROR (Model Unavailable)", "Model failed to load"

    # Prepare input for model
    inputs = tokenizer(
        clean_text, return_tensors="pt", truncation=True, padding=True
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # Only one sequence is passed in, so row 0 holds its class probabilities.
    score_complete = F.softmax(logits, dim=1)[0, COMPLETE_LABEL_INDEX].item()

    if score_complete >= EOU_THRESHOLD:
        decision = "REPLY (Turn Complete)"
    else:
        decision = "WAIT (Turn Incomplete)"

    return decision, f"{score_complete:.1%}"


# 4. Gradio Interface
examples = [
    ["السلام عليكم"],
    ["ياخي ودي أحجز"],
    ["ياخي ودي أحجز موعد عندكم بكرة"],
    ["رقم جوالي صفر خمسة"],
    ["رقم جوالي صفر خمسة خمسة واحد"],
]

iface = gr.Interface(
    fn=predict_eou,
    inputs=gr.Textbox(label="User Speech Input", placeholder="Type here..."),
    outputs=[
        gr.Textbox(label="Agent Decision"),
        gr.Label(label="Confidence Score"),
    ],
    title="Saudi Dialect End-of-Utterance (EOU) Detector",
    description="Final Production Model. Uses SaudiBERT to detect turn-taking in real-time conversation.",
    # 'allow_flagging' is intentionally not passed: it was removed to fix an
    # error raised by gr.Interface.
    examples=examples,
)

# Start the Gradio server.
iface.launch()