Spaces:
Sleeping
Sleeping
File size: 2,774 Bytes
8d2a3cd ff02950 8d2a3cd ff02950 8d2a3cd ff02950 8d2a3cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import re
import string
# 1. Model Configuration
# Hugging Face Hub repository that holds both the tokenizer and the
# sequence-classification weights used by the prediction function below.
MODEL_REPO = "SuperSl6/saudi-eou-model-final"

print(f"Loading Model from {MODEL_REPO}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
except Exception as e:
    print(f"Error loading model: {e}")
    # Fail fast: if loading is swallowed here, `tokenizer` and `model` are
    # never bound and every later prediction dies with a NameError that
    # hides the real cause. Re-raise so the startup log shows the failure.
    raise
else:
    # Switch the model to evaluation mode before serving predictions.
    model.eval()
    print("Model loaded successfully.")
# 2. Text Normalization Function
# The patterns and the punctuation table never change, so they are built once
# at import time instead of on every call.
_DIACRITICS_RE = re.compile(r'[\u0617-\u061A\u064B-\u0652]')
_ALEF_VARIANTS_RE = re.compile(r'[أإآ]')
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation + '،؛؟')


def normalize_text(text: object) -> str:
    """Return a normalized copy of *text* for the classifier.

    The input is coerced with ``str()``, then, in order:

    1. Arabic diacritics (U+0617-U+061A, U+064B-U+0652) are removed.
    2. Alef with hamza above/below or madda is replaced by bare Alef.
    3. Alef maqsura is replaced by Ya.
    4. ASCII punctuation and the Arabic comma, semicolon and question mark
       are removed.
    5. Leading and trailing whitespace is stripped.

    Args:
        text: Value to normalize; any object is accepted and stringified.

    Returns:
        The normalized string (possibly empty).
    """
    text = str(text)
    # Remove Arabic diacritics
    text = _DIACRITICS_RE.sub('', text)
    # Normalize Alef forms to bare Alef
    text = _ALEF_VARIANTS_RE.sub('ا', text)
    # Normalize Ya forms
    text = text.replace('ى', 'ي')
    # Remove punctuation, then trim surrounding whitespace
    return text.translate(_PUNCTUATION_TABLE).strip()
# 3. Prediction Logic
# Single-word replies, spelled in their normalized form, that are accepted as
# a complete turn on their own. A frozenset gives O(1) membership tests and
# is built once rather than on every request.
_COMPLETE_SINGLE_WORDS = frozenset(
    ["نعم", "لا", "طيب", "سم", "ابشر", "تم", "صحيح", "اكيد"]
)

# Minimum probability of the COMPLETE class required to answer "REPLY".
_EOU_THRESHOLD = 0.5


def predict_eou(text):
    """Decide whether *text* is a finished user turn.

    Args:
        text: Raw user input from the UI; may be empty or ``None``.

    Returns:
        A ``(decision, detail)`` pair of strings:

        * ``("Please enter text...", "0%")`` for empty or whitespace-only
          input;
        * ``("WAIT (Turn Incomplete)", "Safety Rule Triggered")`` when the
          normalized text has fewer than two words and is not one of the
          whitelisted single-word replies;
        * otherwise the model's decision and the probability of the
          COMPLETE class formatted as a percentage with one decimal.
    """
    if not text or not text.strip():
        return "Please enter text...", "0%"

    clean_text = normalize_text(text)

    # Safety rule: an input shorter than two words is treated as incomplete
    # without consulting the model, unless it is a whitelisted reply.
    is_short = len(clean_text.split()) < 2
    if is_short and clean_text not in _COMPLETE_SINGLE_WORDS:
        return "WAIT (Turn Incomplete)", "Safety Rule Triggered"

    # Prepare input for the model
    inputs = tokenizer(
        clean_text, return_tensors="pt", truncation=True, padding=True
    )
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = F.softmax(logits, dim=1).numpy()[0]

    # Label 1 corresponds to COMPLETE (End of Utterance). Convert the numpy
    # scalar to a plain float before comparing and formatting.
    score_complete = float(probs[1])

    if score_complete >= _EOU_THRESHOLD:
        decision = "REPLY (Turn Complete)"
    else:
        decision = "WAIT (Turn Incomplete)"
    return decision, f"{score_complete:.1%}"
# 4. Gradio Interface
# Sample inputs offered in the UI; each inner list holds the value for the
# single text input.
examples = [
    ["السلام عليكم"],
    ["ياخي ودي أحجز"],
    ["ياخي ودي أحجز موعد عندكم بكرة"],
    ["رقم جوالي صفر خمسة"],
    ["رقم جوالي صفر خمسة خمسة واحد"],
]

# One text input mapped to two outputs, matching the (decision, detail) pair
# returned by predict_eou.
iface = gr.Interface(
    fn=predict_eou,
    inputs=gr.Textbox(label="User Speech Input", placeholder="Type here..."),
    outputs=[
        gr.Textbox(label="Agent Decision"),
        gr.Label(label="Confidence Score"),
    ],
    title="Saudi Dialect End-of-Utterance (EOU) Detector",
    description="Final Production Model. Uses SaudiBERT to detect turn-taking in real-time conversation.",
    examples=examples,
    # NOTE: 'allow_flagging' is intentionally not passed; it was removed
    # because it caused an error.
)

iface.launch()