# SuperSl6's picture
# Update app.py
# ff02950 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import re
import string
# 1. Model Configuration
MODEL_REPO = "SuperSl6/saudi-eou-model-final"

# Bind both names up front so they always exist at module level, even when
# loading fails part-way through (e.g. the tokenizer loads but the model does
# not). Code that uses them can then test for None instead of hitting a
# NameError.
tokenizer = None
model = None

print(f"Loading Model from {MODEL_REPO}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
    # Put the module in evaluation mode; this app only runs inference.
    model.eval()
except Exception as e:
    # Start-up is deliberately best-effort: report the failure and carry on,
    # but never leave a half-loaded tokenizer/model pair behind.
    tokenizer = None
    model = None
    print(f"Error loading model: {e}")
else:
    print("Model loaded successfully.")
# 2. Text Normalization Function
# The patterns and the punctuation table never change, so they are built once
# at import time rather than on every call.
_DIACRITICS_RE = re.compile(r'[\u0617-\u061A\u064B-\u0652]')
_ALEF_FORMS_RE = re.compile(r'[أإآ]')
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation + '،؛؟')


def normalize_text(text):
    """Return a normalized copy of *text*.

    The steps, applied in order, are:

    1. Coerce the input to ``str``.
    2. Remove Arabic diacritics (U+0617-U+061A and U+064B-U+0652).
    3. Replace the Alef forms أ / إ / آ with bare Alef (ا).
    4. Replace ى with ي.
    5. Delete ASCII punctuation and the Arabic marks ، ؛ ؟.
    6. Strip leading and trailing whitespace.

    Args:
        text: Any object; it is converted with ``str()`` first.

    Returns:
        The normalized string (possibly empty).
    """
    text = str(text)
    text = _DIACRITICS_RE.sub('', text)
    text = _ALEF_FORMS_RE.sub('ا', text)
    # Single-character replacement: a plain str method is enough here.
    text = text.replace('ى', 'ي')
    return text.translate(_PUNCTUATION_TABLE).strip()
# 3. Prediction Logic
# Single-word utterances (already in normalized form) that are accepted as
# candidates for a complete turn despite being only one word long.
_SHORT_REPLY_WHITELIST = frozenset(
    ["نعم", "لا", "طيب", "سم", "ابشر", "تم", "صحيح", "اكيد"]
)
# Minimum probability of the COMPLETE class required to answer.
_EOU_THRESHOLD = 0.5


def predict_eou(text):
    """Decide whether the user's turn is complete (end of utterance).

    Args:
        text: Raw utterance from the UI textbox; may be empty or ``None``.

    Returns:
        A ``(decision, confidence)`` pair of strings:

        * ``("Please enter text...", "0%")`` for empty/blank input.
        * ``("WAIT (Turn Incomplete)", "Safety Rule Triggered")`` when the
          normalized text has fewer than two words and is not whitelisted.
        * ``("ERROR (Model Not Loaded)", "0%")`` when the tokenizer or model
          is unavailable.
        * Otherwise ``"REPLY (Turn Complete)"`` or ``"WAIT (Turn Incomplete)"``
          together with the COMPLETE-class probability formatted as a
          percentage with one decimal.
    """
    if not text or not text.strip():
        return "Please enter text...", "0%"

    clean_text = normalize_text(text)

    # Safety rule: fewer than two words (including nothing left after
    # normalization) counts as an incomplete turn unless whitelisted.
    if len(clean_text.split()) < 2 and clean_text not in _SHORT_REPLY_WHITELIST:
        return "WAIT (Turn Incomplete)", "Safety Rule Triggered"

    # Model loading at import time is best-effort, so the names may be missing
    # or None here. Look them up defensively and report the problem instead of
    # failing with a NameError/TypeError in the middle of a request.
    loaded_tokenizer = globals().get("tokenizer")
    loaded_model = globals().get("model")
    if loaded_tokenizer is None or loaded_model is None:
        return "ERROR (Model Not Loaded)", "0%"

    # Prepare input for model
    inputs = loaded_tokenizer(
        clean_text, return_tensors="pt", truncation=True, padding=True
    )
    with torch.no_grad():
        logits = loaded_model(**inputs).logits

    # Label 1 corresponds to COMPLETE (End of Utterance). There is a single
    # input, so row 0 is the only row; .item() yields a plain Python float.
    score_complete = F.softmax(logits, dim=1)[0, 1].item()

    if score_complete >= _EOU_THRESHOLD:
        decision = "REPLY (Turn Complete)"
    else:
        decision = "WAIT (Turn Incomplete)"
    return decision, f"{score_complete:.1%}"
# 4. Gradio Interface
# One-click sample utterances; each example row carries the single textbox
# value expected by the interface.
examples = [
    [utterance]
    for utterance in (
        "السلام عليكم",
        "ياخي ودي أحجز",
        "ياخي ودي أحجز موعد عندكم بكرة",
        "رقم جوالي صفر خمسة",
        "رقم جوالي صفر خمسة خمسة واحد",
    )
]

# Components are created in the same order as before: input first, then the
# two outputs (decision text, confidence label).
_speech_input = gr.Textbox(label="User Speech Input", placeholder="Type here...")
_decision_output = gr.Textbox(label="Agent Decision")
_confidence_output = gr.Label(label="Confidence Score")

# NOTE: 'allow_flagging' is intentionally not passed; it was removed to fix
# an error.
iface = gr.Interface(
    fn=predict_eou,
    inputs=_speech_input,
    outputs=[_decision_output, _confidence_output],
    title="Saudi Dialect End-of-Utterance (EOU) Detector",
    description="Final Production Model. Uses SaudiBERT to detect turn-taking in real-time conversation.",
    examples=examples,
)

# Start the web UI.
iface.launch()