Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import torch.nn.functional as F | |
| import re | |
| import string | |
# 1. Model Configuration
MODEL_REPO = "SuperSl6/saudi-eou-model-final"

print(f"Loading Model from {MODEL_REPO}...")
# Pre-bind both names so that a failed load leaves them as None instead of
# undefined. Prediction code can then detect the missing model and respond
# gracefully rather than crashing with a NameError.
tokenizer = None
model = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
    model.eval()  # switch to evaluation (inference) mode
    print("Model loaded successfully.")
except Exception as e:
    # Top-level boundary: keep the app process alive so the failure is
    # visible, but make sure neither half of a partial load (e.g. tokenizer
    # loaded, model failed) is left behind.
    tokenizer = None
    model = None
    print(f"Error loading model: {e}")
# 2. Text Normalization Function
# A single translation table applies every normalization step in one pass.
# Every mapping is a single character with a distinct source character, so
# combining them gives exactly the same result as applying the steps one
# after another.
_NORMALIZE_TABLE = str.maketrans(
    {
        # Arabic diacritics (U+0617-U+061A, U+064B-U+0652) -> removed
        **dict.fromkeys(map(chr, range(0x0617, 0x061B)), None),
        **dict.fromkeys(map(chr, range(0x064B, 0x0653)), None),
        # Alef with hamza above (U+0623), hamza below (U+0625) and
        # madda (U+0622) -> bare alef (U+0627)
        **dict.fromkeys("\u0623\u0625\u0622", "\u0627"),
        # Alef maqsura (U+0649) -> ya (U+064A)
        "\u0649": "\u064a",
        # ASCII punctuation plus Arabic comma (U+060C), semicolon (U+061B)
        # and question mark (U+061F) -> removed
        **dict.fromkeys(string.punctuation + "\u060c\u061b\u061f", None),
    }
)


def normalize_text(text):
    """Normalize an Arabic utterance before it is fed to the classifier.

    Removes Arabic diacritics, maps alef variants to bare alef and alef
    maqsura to ya, deletes ASCII and Arabic punctuation, and strips
    surrounding whitespace.

    Args:
        text: Input value; non-string values are converted with ``str()``.

    Returns:
        The normalized string (possibly empty).
    """
    return str(text).translate(_NORMALIZE_TABLE).strip()
# 3. Prediction Logic
# Single-word replies accepted as complete turns by the safety rule. They are
# stored in normalized form (bare alef, no diacritics) because they are
# compared against normalized text.
SHORT_REPLY_WHITELIST = frozenset(
    {"نعم", "لا", "طيب", "سم", "ابشر", "تم", "صحيح", "اكيد"}
)


def predict_eou(text, threshold=0.5):
    """Decide whether the user's turn is complete (end of utterance).

    Args:
        text: Raw user utterance, e.g. from the Gradio textbox.
        threshold: Minimum probability of the COMPLETE label required to
            reply. Defaults to 0.5.

    Returns:
        A ``(decision, detail)`` tuple of strings. ``detail`` is either the
        COMPLETE probability formatted as a percentage or a short reason
        (empty input, safety rule, model unavailable).
    """
    if not text or not text.strip():
        return "Please enter text...", "0%"

    clean_text = normalize_text(text)

    # Safety rule: a single word (or nothing left after normalization) is
    # treated as incomplete unless it is a known stand-alone reply.
    if len(clean_text.split()) < 2 and clean_text not in SHORT_REPLY_WHITELIST:
        return "WAIT (Turn Incomplete)", "Safety Rule Triggered"

    # Model loading may have failed at startup; report it instead of crashing.
    if tokenizer is None or model is None:
        return "Model unavailable", "Model failed to load"

    inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits

    # Label 1 corresponds to COMPLETE (End of Utterance). .item() yields a
    # Python float regardless of the tensor's device.
    score_complete = F.softmax(logits, dim=-1)[0, 1].item()

    if score_complete >= threshold:
        decision = "REPLY (Turn Complete)"
    else:
        decision = "WAIT (Turn Incomplete)"
    return decision, f"{score_complete:.1%}"
# 4. Gradio Interface
# Demo utterances shown under the input box. Pairs of short and long
# variants illustrate incomplete vs. complete turns.
_EXAMPLE_UTTERANCES = (
    "السلام عليكم",
    "ياخي ودي أحجز",
    "ياخي ودي أحجز موعد عندكم بكرة",
    "رقم جوالي صفر خمسة",
    "رقم جوالي صفر خمسة خمسة واحد",
)
# Gradio expects one row per example, with one value per input component;
# this interface has a single textbox, so each row holds one string.
examples = [[utterance] for utterance in _EXAMPLE_UTTERANCES]
# Web UI: one text input, two outputs (decision text and confidence label)
# matching the (decision, detail) tuple returned by predict_eou.
iface = gr.Interface(
    fn=predict_eou,
    inputs=gr.Textbox(label="User Speech Input", placeholder="Type here..."),
    outputs=[
        gr.Textbox(label="Agent Decision"),
        gr.Label(label="Confidence Score"),
    ],
    title="Saudi Dialect End-of-Utterance (EOU) Detector",
    description="Final Production Model. Uses SaudiBERT to detect turn-taking in real-time conversation.",
    examples=examples,
)

# Start the server only when run as a script (e.g. `python app.py`), so the
# module can be imported, for instance by tests, without blocking on launch().
if __name__ == "__main__":
    iface.launch()