SuperSl6 commited on
Commit
8d2a3cd
·
verified ·
1 Parent(s): 4358f52

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import re
6
+ import string
7
+
8
+ # 1. Model Configuration
9
+ # Updated to point to the final production model
10
+ MODEL_REPO = "SuperSl6/saudi-eou-model-final"
11
+
12
+ print(f"Loading Model from {MODEL_REPO}...")
13
+ try:
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
15
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
16
+ model.eval()
17
+ print("Model loaded successfully.")
18
+ except Exception as e:
19
+ print(f"Error loading model: {e}")
20
+
21
+ # 2. Text Normalization Function
22
+ # Matches the preprocessing steps used during training (removing diacritics, normalizing characters)
23
+ def normalize_text(text):
24
+ text = str(text)
25
+ # Remove Arabic Diacritics
26
+ text = re.sub(r'[\u0617-\u061A\u064B-\u0652]', '', text)
27
+ # Normalize Alef forms to bare Alef
28
+ text = re.sub(r'[أإآ]', 'ا', text)
29
+ # Normalize Ya forms
30
+ text = re.sub(r'ى', 'ي', text)
31
+ # Remove Punctuation
32
+ translator = str.maketrans('', '', string.punctuation + '،؛؟')
33
+ return text.translate(translator).strip()
34
+
35
+ # 3. Prediction Logic
36
+ def predict_eou(text):
37
+ if not text or not text.strip():
38
+ return "Please enter text...", "0%"
39
+
40
+ clean_text = normalize_text(text)
41
+
42
+ # Safety Rules
43
+ # If the input is less than 2 words, default to WAIT unless it's a specific keyword.
44
+ if len(clean_text.split()) < 2:
45
+ whitelist = ["نعم", "لا", "طيب", "سم", "ابشر", "تم", "صحيح", "اكيد"]
46
+ if clean_text not in whitelist:
47
+ return "WAIT (Turn Incomplete)", "Safety Rule Triggered"
48
+
49
+ # Prepare input for model
50
+ inputs = tokenizer(clean_text, return_tensors="pt", truncation=True, padding=True)
51
+
52
+ with torch.no_grad():
53
+ logits = model(**inputs).logits
54
+ probs = F.softmax(logits, dim=1).numpy()[0]
55
+
56
+ # Label 1 corresponds to COMPLETE (End of Utterance)
57
+ score_complete = probs[1]
58
+
59
+ # Threshold set to 0.5 as the model was trained on balanced data
60
+ threshold = 0.5
61
+
62
+ if score_complete >= threshold:
63
+ decision = "REPLY (Turn Complete)"
64
+ else:
65
+ decision = "WAIT (Turn Incomplete)"
66
+
67
+ return decision, f"{score_complete:.1%}"
68
+
69
+ # 4. Gradio Interface
70
+ examples = [
71
+ ["السلام عليكم"],
72
+ ["ياخي ودي أحجز عند دكتور"],
73
+ ["ياخي ودي أحجز موعد عندكم بكرة"]
74
+ ]
75
+
76
+ iface = gr.Interface(
77
+ fn=predict_eou,
78
+ inputs=gr.Textbox(label="User Speech Input", placeholder="Type here..."),
79
+ outputs=[
80
+ gr.Textbox(label="Agent Decision"),
81
+ gr.Label(label="Confidence Score")
82
+ ],
83
+ title="Saudi Dialect End-of-Utterance (EOU) Detector",
84
+ description="Final Production Model. Uses SaudiBERT to detect turn-taking in real-time conversation.",
85
+ examples=examples,
86
+ allow_flagging="never"
87
+ )
88
+
89
+ iface.launch()