File size: 2,774 Bytes
8d2a3cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff02950
8d2a3cd
 
 
 
 
 
 
 
 
 
 
 
ff02950
 
 
 
8d2a3cd
 
 
 
 
 
 
 
 
 
 
ff02950
 
8d2a3cd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import re
import string

# 1. Model Configuration
# Hugging Face Hub repository holding both the tokenizer and the classifier.
MODEL_REPO = "SuperSl6/saudi-eou-model-final"

print(f"Loading Model from {MODEL_REPO}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
except Exception as e:
    # Fail fast: predict_eou() uses both `tokenizer` and `model`, so carrying on
    # after a failed load would only resurface as a NameError on every request.
    print(f"Error loading model: {e}")
    raise
# Put the model in evaluation mode; this app only runs inference.
model.eval()
print("Model loaded successfully.")

# 2. Text Normalization Function
def _build_normalization_table():
    """Build one ``str.translate`` table that encodes every normalization rule.

    Every rule deletes or substitutes a single character, and no rule's output
    is the input of another rule. Applying them all in one translation pass is
    therefore the same as applying them one after another.
    """
    table = {}
    # Arabic diacritics in U+0617..U+061A and U+064B..U+0652 are deleted.
    for first, last in ((0x0617, 0x061A), (0x064B, 0x0652)):
        table.update(dict.fromkeys(range(first, last + 1)))
    # Alef with hamza above/below and alef with madda become bare alef.
    table.update({ord(ch): 'ا' for ch in 'أإآ'})
    # Alef maqsura becomes ya.
    table[ord('ى')] = 'ي'
    # ASCII punctuation plus the Arabic comma, semicolon and question mark are deleted.
    table.update(dict.fromkeys(map(ord, string.punctuation + '،؛؟')))
    return table


# Computed once at import time and reused by every normalize_text() call.
_NORMALIZATION_TABLE = _build_normalization_table()


def normalize_text(text):
    """Return a normalized form of *text* for the EOU classifier.

    The input is coerced with ``str()``. Arabic diacritics and punctuation are
    removed, alef and ya variants are unified, and surrounding whitespace is
    stripped.
    """
    return str(text).translate(_NORMALIZATION_TABLE).strip()

# 3. Prediction Logic
# Single-word replies that bypass the "too short" safety rule. They are compared
# against normalize_text() output, so entries must be written in normalized form
# (e.g. bare alef, no punctuation).
_SHORT_REPLY_WHITELIST = frozenset(
    ["نعم", "لا", "طيب", "سم", "ابشر", "تم", "صحيح", "اكيد"]
)

# Minimum probability of the COMPLETE class for the agent to take its turn.
DEFAULT_THRESHOLD = 0.5


def predict_eou(text, threshold=DEFAULT_THRESHOLD):
    """Decide whether *text* completes the user's conversational turn.

    Args:
        text: Raw user utterance.
        threshold: Minimum probability of label 1 (COMPLETE / end of utterance)
            needed to answer "REPLY". Defaults to 0.5.

    Returns:
        A ``(decision, confidence)`` pair of strings. ``confidence`` is the
        COMPLETE probability formatted as a percentage, or a short explanatory
        string when the model was not consulted (empty input, safety rule).
    """
    if not text or not text.strip():
        return "Please enter text...", "0%"

    clean_text = normalize_text(text)

    # Safety rule: a single word (or nothing left after normalization) counts as
    # incomplete unless it is a known stand-alone reply.
    if len(clean_text.split()) < 2 and clean_text not in _SHORT_REPLY_WHITELIST:
        return "WAIT (Turn Incomplete)", "Safety Rule Triggered"

    # Prepare input for model
    inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Label 1 corresponds to COMPLETE (End of Utterance). `.item()` yields a plain
    # Python float and, unlike `.numpy()`, does not require a CPU tensor.
    score_complete = F.softmax(logits, dim=-1)[0, 1].item()

    if score_complete >= threshold:
        decision = "REPLY (Turn Complete)"
    else:
        decision = "WAIT (Turn Incomplete)"

    return decision, f"{score_complete:.1%}"

# 4. Gradio Interface
# Sample utterances shown under the input box. Each row holds one value because
# the interface has a single input component. Several come in short/long pairs.
examples = [
    ["السلام عليكم"],
    ["ياخي ودي أحجز"],
    ["ياخي ودي أحجز موعد عندكم بكرة"],
    ["رقم جوالي صفر خمسة"],
    ["رقم جوالي صفر خمسة خمسة واحد"],
]

# Components are created in the same order the original inline arguments were.
_speech_input = gr.Textbox(label="User Speech Input", placeholder="Type here...")
_decision_output = gr.Textbox(label="Agent Decision")
_confidence_output = gr.Label(label="Confidence Score")

# predict_eou returns (decision, confidence), matching the two output components.
iface = gr.Interface(
    fn=predict_eou,
    inputs=_speech_input,
    outputs=[_decision_output, _confidence_output],
    title="Saudi Dialect End-of-Utterance (EOU) Detector",
    description="Final Production Model. Uses SaudiBERT to detect turn-taking in real-time conversation.",
    examples=examples,
)

iface.launch()