# saudi-eou-demo / app.py
# Hugging Face Space by SuperSl6 — commit 26ca26d ("Update app.py", verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import re
import string
# 1. Model Repository Configuration
# Hugging Face Hub repository holding both the tokenizer and the fine-tuned
# sequence-classification model.
MODEL_REPO = "SuperSl6/saudi-eou-model-v1"

print(f"Loading Model from {MODEL_REPO}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
except Exception as e:
    # Report the failure, then fail fast. Swallowing it would leave
    # `tokenizer`/`model` undefined, and every prediction would later die
    # with an unrelated-looking NameError.
    print(f"Error loading model: {e}")
    raise
# Switch to inference mode (e.g. disables dropout) before serving requests.
model.eval()
print("Model loaded successfully!")
# 2. Text Normalization Function
# This matches the preprocessing steps used during training, so the output
# must stay character-for-character identical to what the model expects.

# Arabic diacritics (Tashkeel) to delete: U+0617-U+061A and U+064B-U+0652.
_ARABIC_DIACRITICS = "".join(
    chr(cp) for cp in (*range(0x0617, 0x061B), *range(0x064B, 0x0653))
)
# ASCII punctuation plus the Arabic comma, semicolon and question mark.
_PUNCTUATION = string.punctuation + "،؛؟"
# A single translation table that applies every normalization step in one pass:
#   * unify Alef forms to bare Alef (أ إ آ -> ا),
#   * map Alef Maksura to Ya (ى -> ي),
#   * delete diacritics and punctuation.
# The steps act on disjoint sets of characters, and no replacement character
# is itself deleted. Merging them therefore gives the same result as running
# them one after another. The table is built once at import time, not per call.
_NORMALIZE_TABLE = str.maketrans(
    "أإآى",
    "اااي",
    _ARABIC_DIACRITICS + _PUNCTUATION,
)


def normalize_text(text):
    """Normalize Saudi-dialect Arabic text the same way as during training.

    Removes Arabic diacritics, unifies Alef forms to a bare Alef, replaces
    Alef Maksura with Ya, deletes ASCII and Arabic punctuation, and strips
    surrounding whitespace.

    Args:
        text: The input to normalize. Non-string values are converted with
            ``str()`` first.

    Returns:
        The normalized string (possibly empty).
    """
    return str(text).translate(_NORMALIZE_TABLE).strip()
# 3. Prediction Function
def predict_eou(text, *, threshold=0.75):
    """Decide whether *text* is a complete conversational turn.

    The text is cleaned with ``normalize_text`` and scored by the sequence
    classifier. The softmax probability of label 1 (COMPLETE / end of
    utterance) is then compared against ``threshold``.

    Args:
        text: Raw user utterance from the Gradio textbox.
        threshold: Minimum label-1 probability required to answer "REPLY".
            Keyword-only; defaults to 0.75, the previously hard-coded value.

    Returns:
        A ``(decision, confidence)`` tuple of strings: the decision label and
        the label-1 probability formatted as a percentage (e.g. "82.3%").
        Returns ``("Please enter text...", "0%")`` without running the model
        when the input is empty or whitespace-only, or when normalization
        leaves nothing to classify.
    """
    if not text or not text.strip():
        return "Please enter text...", "0%"

    # Clean text before inference so it matches the training preprocessing.
    clean_text = normalize_text(text)
    if not clean_text:
        # The input held only punctuation/diacritics; scoring an empty
        # sequence would produce a meaningless confidence.
        return "Please enter text...", "0%"

    inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # A single input gives a batch of one; take row 0 of the class probabilities.
    probs = F.softmax(logits, dim=-1)[0]
    # Label 1 represents COMPLETE (EOU).
    score_complete = probs[1].item()

    if score_complete >= threshold:
        decision = "REPLY (Turn Complete)"
    else:
        decision = "WAIT (Turn Incomplete)"
    # Return formatted decision and confidence percentage.
    return decision, f"{score_complete:.1%}"
# 4. Gradio Interface Setup
# Sample utterances offered under the input box. Some are complete turns and
# some are cut off mid-sentence, so both decisions can be demonstrated.
EXAMPLE_UTTERANCES = (
    "السلام عليكم",
    "ياخي ودي أحجز",
    "ياخي ودي أحجز موعد عندكم بكرة",
    "رقم جوالي صفر خمسة",
)
# Gradio expects one inner list per example, holding a value for each input
# component (here just the single textbox).
examples = [[utterance] for utterance in EXAMPLE_UTTERANCES]

# Components, created in the order: input textbox, decision textbox,
# confidence label.
speech_input = gr.Textbox(label="User Speech Input", placeholder="Type here...")
decision_output = gr.Textbox(label="Model Decision")
confidence_output = gr.Label(label="Confidence Score")

iface = gr.Interface(
    fn=predict_eou,
    inputs=speech_input,
    outputs=[decision_output, confidence_output],
    title="Saudi Dialect End-of-Utterance (EOU) Detector",
    description="A fine-tuned SaudiBERT model designed to detect end-of-utterance in real-time Saudi dialect conversations for AI voice agents.",
    examples=examples,
)
# Start the web server (blocking call).
iface.launch()