mahmoudmohammad commited on
Commit
8252e5e
·
verified ·
1 Parent(s): ac067fa

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +194 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import emoji
3
+ import re
4
+ import nltk
5
+ import torch
6
+ import numpy as np
7
+ import joblib
8
+ import pandas as pd
9
+ import gradio as gr
10
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
+
12
+ # ── 1. Setup & Pre-processing Environment ──
13
+ nltk.download('stopwords', quiet=True)
14
+ from nltk.corpus import stopwords
15
+ arabic_stopwords = set(stopwords.words('arabic'))
16
+
17
+ def clean_text(text):
18
+ text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
19
+ text = re.sub(r'[@#]', '', text)
20
+ text = re.sub(r'\d+', '', text)
21
+ text = re.sub(r'\s+', ' ', text).strip()
22
+ return text
23
+
24
+ def remove_emojis(text):
25
+ if not isinstance(text, str):
26
+ return text
27
+ return emoji.replace_emoji(text, replace='')
28
+
29
+ def remove_arabic_punctuation(text):
30
+ if not isinstance(text, str):
31
+ return text
32
+ arabic_punct = (
33
+ r'[\u0600-\u0605\u060C\u060D\u061B\u061C\u061D\u061E\u061F'
34
+ r'\u0640\u066A-\u066D\u06D4\u200c\u200d\u200e\u200f'
35
+ r'\ufeff\u202a-\u202e،؟؛«»]'
36
+ )
37
+ text = re.sub(arabic_punct, ' ', text)
38
+ text = re.sub(r'\s+', ' ', text).strip()
39
+ return text
40
+
41
+ def normalize_arabic_characters(text):
42
+ if not isinstance(text, str):
43
+ text = str(text)
44
+ text = re.sub(r'[أإآ]', 'ا', text)
45
+ text = text.replace('ى', 'ي')
46
+ text = text.replace('ة', 'ه')
47
+ text = re.sub(r'[\u064B-\u065F\u0670]', '', text)
48
+ return text
49
+
50
+ def remove_repeated_chars(text):
51
+ if not isinstance(text, str):
52
+ return text
53
+ return re.sub(r'(.)\1{2,}', r'\1\1', text)
54
+
55
+ def remove_repeated_words(text):
56
+ if not isinstance(text, str):
57
+ return text
58
+ return re.sub(r'\b(\w+)(\s+\1){1,}\b', r'\1', text)
59
+
60
+ def tokenize_text(text):
61
+ return text.split() if text else []
62
+
63
+ def remove_stopwords(tokens, stopwords_set=arabic_stopwords):
64
+ return [t for t in tokens if t not in stopwords_set]
65
+
66
+ def preprocess_arabic_text(text):
67
+ if pd.isna(text) or not isinstance(text, str):
68
+ return ''
69
+ text = clean_text(text)
70
+ text = remove_emojis(text)
71
+ text = remove_arabic_punctuation(text)
72
+ text = normalize_arabic_characters(text)
73
+ text = remove_repeated_chars(text)
74
+ text = remove_repeated_words(text)
75
+ tokens = tokenize_text(text)
76
+ tokens = remove_stopwords(tokens)
77
+ return ' '.join(tokens)
78
+
79
+ # ── 2. Load Model ──
80
+ print("Loading model and tokenizer...")
81
+ REPO_ID = "mahmoudmohammad/marbertv2-multilabel-dialect"
82
+ tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
83
+ model = AutoModelForSequenceClassification.from_pretrained(REPO_ID)
84
+
85
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
86
+ model.to(device)
87
+ model.eval()
88
+
89
+ try:
90
+ mlb = joblib.load('mlb_dialects.pkl')
91
+ class_names = list(mlb.classes_)
92
+ except FileNotFoundError:
93
+ class_names = ['Bahraini', 'Egyptian', 'Emirati', 'Jordanian', 'Lebanese', 'MSA', 'Palestinian', 'Qatari', 'Saudi', 'Syrian']
94
+
95
+
96
+ # ── 3. Prediction Pipeline ──
97
+ def predict_dialects(text, threshold):
98
+ cleaned_text = preprocess_arabic_text(text)
99
+
100
+ inputs = tokenizer(
101
+ cleaned_text,
102
+ return_tensors="pt",
103
+ truncation=True,
104
+ max_length=256
105
+ )
106
+ inputs = {k: v.to(device) for k, v in inputs.items()}
107
+
108
+ with torch.no_grad():
109
+ outputs = model(**inputs)
110
+ logits = outputs.logits
111
+
112
+ probs = torch.sigmoid(logits).squeeze().cpu().numpy()
113
+ predictions = (probs > threshold).astype(int)
114
+
115
+ predicted_dialects = []
116
+ predicted_probs = {}
117
+
118
+ for i, pred in enumerate(predictions):
119
+ dialect = class_names[i]
120
+ predicted_probs[dialect] = round(float(probs[i]), 4)
121
+ if pred == 1:
122
+ predicted_dialects.append(dialect)
123
+
124
+ # Fallback to single highest
125
+ if len(predicted_dialects) == 0:
126
+ max_idx = np.argmax(probs)
127
+ predicted_dialects.append(class_names[max_idx])
128
+
129
+ dialects_out = ", ".join(predicted_dialects)
130
+ return dialects_out, predicted_probs, cleaned_text
131
+
132
+
133
+ # ── 4. Gradio UI with Enforced Dark Mode ──
134
+
135
+ # The following script is a standard hack to forcefully enable dark mode on app load.
136
+ dark_mode_js = """
137
+ function() {
138
+ document.body.classList.add('dark');
139
+ }
140
+ """
141
+
142
+ with gr.Blocks(theme=gr.themes.Base(), title="Arabic Multi-Dialect Analyzer") as demo:
143
+ gr.Markdown(
144
+ """
145
+ # 🌙 Multi-Label Arabic Dialect Inference
146
+ Identify overlapping dialects in modern Arabic text seamlessly.
147
+ *Powered by MARBERTv2*
148
+ """
149
+ )
150
+
151
+ with gr.Row():
152
+ with gr.Column():
153
+ text_input = gr.Textbox(
154
+ lines=5,
155
+ label="Arabic Text Input",
156
+ placeholder="أدخل النص العربي هنا...",
157
+ rtl=True
158
+ )
159
+ threshold_slider = gr.Slider(
160
+ minimum=0.1,
161
+ maximum=0.9,
162
+ step=0.01,
163
+ value=0.45,
164
+ label="Confidence Threshold",
165
+ info="Determines minimum probability needed to label a dialect."
166
+ )
167
+ submit_btn = gr.Button("Analyze Text", variant="primary")
168
+
169
+ with gr.Column():
170
+ dialects_output = gr.Textbox(label="Predicted Dialect(s)")
171
+ prob_output = gr.Label(label="Confidence Probabilities", num_top_classes=10)
172
+ clean_text_output = gr.Textbox(label="Text After Pre-Processing", rtl=True)
173
+
174
+ # Optional Examples for fast-testing
175
+ gr.Examples(
176
+ examples=[
177
+ ["شلونك اليوم؟ شو عم تعمل؟", 0.45],
178
+ ["انا رايح الشغل بدري علشان عندي شغل كتير.", 0.45]
179
+ ],
180
+ inputs=[text_input, threshold_slider]
181
+ )
182
+
183
+ # Map button click to backend logic
184
+ submit_btn.click(
185
+ fn=predict_dialects,
186
+ inputs=[text_input, threshold_slider],
187
+ outputs=[dialects_output, prob_output, clean_text_output]
188
+ )
189
+
190
+ # Inject JS at initialization to Force Dark Theme
191
+ demo.load(None, None, None, js=dark_mode_js)
192
+
193
+ if __name__ == "__main__":
194
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ emoji
5
+ nltk
6
+ pandas
7
+ joblib
8
+ numpy