"""Gradio demo for a Saudi-dialect end-of-utterance (EOU) classifier.

The script loads a sequence-classification model from the Hugging Face Hub,
normalizes the user's Arabic text, and reports whether the turn looks
complete (the agent may reply) or incomplete (the agent should wait).
"""
import re
import string

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# 1. Model Repository Configuration
MODEL_REPO = "SuperSl6/saudi-eou-model-v1"

# Minimum probability of the COMPLETE class required to answer "REPLY".
CONFIDENCE_THRESHOLD = 0.75

print(f"Loading Model from {MODEL_REPO}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
    model.eval()
    print("Model loaded successfully!")
except Exception as e:
    # Best-effort start-up: keep the UI available, but leave explicit
    # sentinels so predict_eou can report the failure instead of hitting a
    # NameError on the first request.
    print(f"Error loading model: {e}")
    tokenizer = None
    model = None


# 2. Text Normalization Function
# Intended to mirror the preprocessing steps used during training, so the
# patterns below must stay in sync with the training pipeline.
# They are built once at import time rather than on every call.
_DIACRITICS_RE = re.compile(r'[\u0617-\u061A\u064B-\u0652]')
_ALEF_VARIANTS_RE = re.compile(r'[أإآ]')
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation + '،؛؟')


def normalize_text(text) -> str:
    """Return *text* normalized for the classifier.

    The input is coerced to ``str``; then Arabic diacritics (tashkeel) are
    removed, Alef variants are unified to bare Alef, Alef Maqsura is mapped
    to Ya, ASCII and Arabic punctuation is stripped, and surrounding
    whitespace is trimmed.
    """
    text = str(text)
    # Remove Arabic Diacritics (Tashkeel)
    text = _DIACRITICS_RE.sub('', text)
    # Normalize Alef forms (unify to bare Alef)
    text = _ALEF_VARIANTS_RE.sub('ا', text)
    # Normalize Ya forms
    text = text.replace('ى', 'ي')
    # Remove Punctuation
    text = text.translate(_PUNCTUATION_TABLE)
    return text.strip()


# 3. Prediction Function
def predict_eou(text):
    """Classify *text* as a complete or incomplete turn.

    Returns a ``(decision, confidence)`` pair of strings. ``decision`` is
    "REPLY (Turn Complete)" when the probability of label 1 reaches
    ``CONFIDENCE_THRESHOLD`` and "WAIT (Turn Incomplete)" otherwise;
    ``confidence`` is that probability formatted as a percentage.
    Blank input, input that is empty after normalization, and a model that
    failed to load all yield an explanatory message with "0%".
    """
    if not text or not text.strip():
        return "Please enter text...", "0%"

    if tokenizer is None or model is None:
        return "Model unavailable (failed to load at startup)", "0%"

    # Clean text before inference
    clean_text = normalize_text(text)
    if not clean_text:
        # The input consisted solely of punctuation/diacritics; there is
        # nothing meaningful left to classify.
        return "Please enter text...", "0%"

    inputs = tokenizer(
        clean_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # Single input, so take row 0 of the softmax over the class dimension.
    probs = F.softmax(logits, dim=1)[0]

    # Label 1 represents COMPLETE (EOU); .item() gives a plain Python float.
    score_complete = probs[1].item()

    if score_complete >= CONFIDENCE_THRESHOLD:
        decision = "REPLY (Turn Complete)"
    else:
        decision = "WAIT (Turn Incomplete)"

    # Return formatted decision and confidence percentage
    return decision, f"{score_complete:.1%}"


# 4. Gradio Interface Setup
examples = [
    ["السلام عليكم"],
    ["ياخي ودي أحجز"],
    ["ياخي ودي أحجز موعد عندكم بكرة"],
    ["رقم جوالي صفر خمسة"],
]

iface = gr.Interface(
    fn=predict_eou,
    inputs=gr.Textbox(label="User Speech Input", placeholder="Type here..."),
    outputs=[
        gr.Textbox(label="Model Decision"),
        gr.Label(label="Confidence Score"),
    ],
    title="Saudi Dialect End-of-Utterance (EOU) Detector",
    description="A fine-tuned SaudiBERT model designed to detect end-of-utterance in real-time Saudi dialect conversations for AI voice agents.",
    examples=examples,
)

# Start the web server only when executed as a script, so the helpers above
# can be imported without launching the UI.
if __name__ == "__main__":
    iface.launch()