File size: 6,315 Bytes
8252e5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# -*- coding: utf-8 -*-
import emoji
import re
import nltk
import torch
import numpy as np
import joblib
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ── 1. Setup & Pre-processing Environment ──
# Fetch the NLTK 'stopwords' corpus at import time; this has to happen before
# the corpus is read two lines below.
nltk.download('stopwords', quiet=True)
from nltk.corpus import stopwords
# Arabic stopword list held as a set; it is the default filter used by
# remove_stopwords().
arabic_stopwords = set(stopwords.words('arabic'))

def clean_text(text):
    """Strip URLs, '@'/'#' symbols and digits, then collapse whitespace.

    The substitutions run in a fixed order: URL-like tokens first, then the
    mention/hashtag markers (only the symbol is dropped, the word is kept),
    then every run of digits, and finally any whitespace run is squeezed to a
    single space. The result has no leading or trailing whitespace.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string.
    """
    # (pattern, replacement, flags) applied one after another.
    substitutions = (
        (r'http\S+|www\S+|https\S+', '', re.MULTILINE),
        (r'[@#]', '', 0),
        (r'\d+', '', 0),
        (r'\s+', ' ', 0),
    )
    for pattern, replacement, flags in substitutions:
        text = re.sub(pattern, replacement, text, flags=flags)
    return text.strip()

def remove_emojis(text):
    """Delete every emoji from *text*.

    Args:
        text: Value to process. Anything that is not a ``str`` is handed back
            untouched.

    Returns:
        The string with emojis replaced by the empty string, or the original
        object when it is not a string.
    """
    if isinstance(text, str):
        return emoji.replace_emoji(text, replace='')
    return text

def remove_arabic_punctuation(text):
    """Replace Arabic punctuation and invisible control marks with spaces.

    The character class covers Arabic-block signs and punctuation, the tatweel
    (U+0640), zero-width and bidirectional control characters, the BOM and the
    guillemets. Each match becomes a single space, after which whitespace runs
    are collapsed and the ends trimmed. Because tatweel is replaced by a space
    (not deleted), an elongated word is split at the tatweel.

    Args:
        text: Value to process. Non-string values are returned unchanged.

    Returns:
        The cleaned string, or the original object when it is not a string.
    """
    if isinstance(text, str):
        punctuation_class = (
            r'[\u0600-\u0605\u060C\u060D\u061B\u061C\u061D\u061E\u061F'
            r'\u0640\u066A-\u066D\u06D4\u200c\u200d\u200e\u200f'
            r'\ufeff\u202a-\u202e،؟؛«»]'
        )
        spaced = re.sub(punctuation_class, ' ', text)
        return re.sub(r'\s+', ' ', spaced).strip()
    return text

def normalize_arabic_characters(text):
    """Fold common Arabic spelling variants and strip diacritics.

    Hamza/madda forms of alef become bare alef, alef maqsura becomes yeh,
    teh marbuta becomes heh, and the combining marks U+064B-U+065F plus the
    superscript alef U+0670 are removed.

    Args:
        text: Value to normalise. A non-string is first converted with
            ``str()``.

    Returns:
        The normalised string.
    """
    value = text if isinstance(text, str) else str(text)
    # Unify alef variants.
    value = re.sub(r'[أإآ]', 'ا', value)
    # Unify word-final letter variants.
    value = value.replace('ى', 'ي').replace('ة', 'ه')
    # Drop tashkeel (diacritics).
    return re.sub(r'[\u064B-\u065F\u0670]', '', value)

def remove_repeated_chars(text):
    """Shorten any run of three or more identical characters to two.

    Args:
        text: Value to process. Non-string values are returned unchanged.

    Returns:
        The string with long character runs capped at two, or the original
        object when it is not a string.
    """
    if isinstance(text, str):
        # Group 1 is the repeated character; keep exactly two copies of it.
        return re.sub(r'(.)\1{2,}', r'\1\1', text)
    return text

def remove_repeated_words(text):
    """Collapse a word repeated back-to-back into a single occurrence.

    Args:
        text: Value to process. Non-string values are returned unchanged.

    Returns:
        The string with consecutive duplicate words reduced to one, or the
        original object when it is not a string.
    """
    if isinstance(text, str):
        # Group 1 is the word; the second group consumes each further copy
        # together with the whitespace in front of it.
        return re.sub(r'\b(\w+)(\s+\1){1,}\b', r'\1', text)
    return text

def tokenize_text(text):
    """Split *text* on whitespace.

    Args:
        text: String to split; any falsy value (``''``, ``None``) yields no
            tokens.

    Returns:
        A list of whitespace-separated tokens, empty for falsy input.
    """
    if not text:
        return []
    return text.split()

def remove_stopwords(tokens, stopwords_set=arabic_stopwords):
    """Drop every token that appears in *stopwords_set*.

    Args:
        tokens: Iterable of tokens to filter.
        stopwords_set: Container of words to discard; defaults to the
            module-level Arabic stopword set.

    Returns:
        A new list with the surviving tokens in their original order.
    """
    kept = []
    for token in tokens:
        if token in stopwords_set:
            continue
        kept.append(token)
    return kept

def preprocess_arabic_text(text):
    """Run the full Arabic cleaning pipeline on *text*.

    Steps, in order: URL/symbol/digit cleanup, emoji removal, punctuation
    removal, character normalisation, repeated-character and repeated-word
    collapsing, whitespace tokenisation and stopword filtering.

    Args:
        text: Raw input. Anything that is not a ``str`` (``None``, NaN,
            numbers, lists, arrays, ...) is treated as missing.

    Returns:
        The surviving tokens joined by single spaces, or ``''`` for missing
        or non-string input.
    """
    # A plain isinstance check already rejects None/NaN/NA (none of them are
    # str). The former `pd.isna(text) or ...` test evaluated the truth value
    # of an array for list/array input and raised ValueError.
    if not isinstance(text, str):
        return ''
    text = clean_text(text)
    text = remove_emojis(text)
    text = remove_arabic_punctuation(text)
    text = normalize_arabic_characters(text)
    text = remove_repeated_chars(text)
    text = remove_repeated_words(text)
    # NOTE(review): stopwords are matched after normalisation; if the stopword
    # list holds un-normalised spellings those entries can no longer match.
    # Confirm against the preprocessing used at training time before changing.
    tokens = tokenize_text(text)
    tokens = remove_stopwords(tokens)
    return ' '.join(tokens)

# ── 2. Load Model ──
# Load the tokenizer and the sequence-classification model identified by
# REPO_ID. This happens at import time, before the UI is built.
print("Loading model and tokenizer...")
REPO_ID = "mahmoudmohammad/marbertv2-multilabel-dialect"
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = AutoModelForSequenceClassification.from_pretrained(REPO_ID)

# Use CUDA when available, otherwise fall back to the CPU; predict_dialects()
# moves its input tensors to this same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Put the model in evaluation mode for inference.
model.eval()

# class_names[i] is used as the label for output index i in
# predict_dialects(), so its order must match the model's label order.
# NOTE(review): joblib.load unpickles the file; only ship a trusted
# 'mlb_dialects.pkl' alongside the app.
try:
    # presumably a fitted multi-label binarizer exposing `classes_` — confirm
    mlb = joblib.load('mlb_dialects.pkl')
    class_names = list(mlb.classes_)
except FileNotFoundError:
    # No pickle next to the app: fall back to a hard-coded label list.
    class_names = ['Bahraini', 'Egyptian', 'Emirati', 'Jordanian', 'Lebanese', 'MSA', 'Palestinian', 'Qatari', 'Saudi', 'Syrian']


# ── 3. Prediction Pipeline ──
def predict_dialects(text, threshold):
    """Predict which dialect labels apply to *text*.

    The text is cleaned with ``preprocess_arabic_text``, tokenised (truncated
    to 256 tokens), scored by the model, and each logit is passed through a
    sigmoid so every label gets an independent probability.

    Args:
        text: Raw user input.
        threshold: A label is selected when its probability is strictly
            greater than this value.

    Returns:
        A 3-tuple ``(dialects_out, predicted_probs, cleaned_text)``:
        the selected label names joined by ``", "`` (the single most probable
        label when none clears the threshold), a dict mapping every label to
        its probability rounded to 4 decimals, and the preprocessed text that
        was actually fed to the model.
    """
    cleaned_text = preprocess_arabic_text(text)

    inputs = tokenizer(
        cleaned_text,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    )
    # The tensors must live on the same device as the model.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Remove only the batch axis. A bare squeeze() would also collapse the
    # label axis of a single-label head into a 0-d array that cannot be
    # iterated below.
    probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()

    predicted_probs = {
        class_names[i]: round(float(prob), 4) for i, prob in enumerate(probs)
    }
    predicted_dialects = [
        class_names[i] for i in np.flatnonzero(probs > threshold)
    ]

    # Nothing cleared the threshold: report the single most likely label.
    if not predicted_dialects:
        predicted_dialects.append(class_names[np.argmax(probs)])

    dialects_out = ", ".join(predicted_dialects)
    return dialects_out, predicted_probs, cleaned_text


# ── 4. Gradio UI with Enforced Dark Mode ──

# JavaScript function passed to demo.load(..., js=...) further down: it adds the
# 'dark' class to <body> when the page loads, as a workaround to force the
# dark theme (presumably the class Gradio's dark styling keys on — confirm
# against the installed Gradio version).
dark_mode_js = """

function() {

    document.body.classList.add('dark');

}

"""

# Build the UI: a header, an input column (text, threshold, button) next to an
# output column (labels, probabilities, cleaned text), plus example inputs.
with gr.Blocks(theme=gr.themes.Base(), title="Arabic Multi-Dialect Analyzer") as demo:
    # Page header.
    gr.Markdown(
        """

        # 🌙 Multi-Label Arabic Dialect Inference

        Identify overlapping dialects in modern Arabic text seamlessly.

        *Powered by MARBERTv2*

        """
    )
    
    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            text_input = gr.Textbox(
                lines=5, 
                label="Arabic Text Input", 
                placeholder="أدخل النص العربي هنا...", 
                rtl=True
            )
            # Passed to predict_dialects() as `threshold`.
            threshold_slider = gr.Slider(
                minimum=0.1, 
                maximum=0.9, 
                step=0.01, 
                value=0.45, 
                label="Confidence Threshold",
                info="Determines minimum probability needed to label a dialect."
            )
            submit_btn = gr.Button("Analyze Text", variant="primary")
            
        # Right column: one widget per element of predict_dialects()'s
        # return tuple, in the same order.
        with gr.Column():
            dialects_output = gr.Textbox(label="Predicted Dialect(s)")
            prob_output = gr.Label(label="Confidence Probabilities", num_top_classes=10)
            clean_text_output = gr.Textbox(label="Text After Pre-Processing", rtl=True)
    
    # Sample inputs for quick testing; each row is [text, threshold] and fills
    # the two input widgets.
    gr.Examples(
        examples=[
            ["شلونك اليوم؟ شو عم تعمل؟", 0.45],
            ["انا رايح الشغل بدري علشان عندي شغل كتير.", 0.45]
        ],
        inputs=[text_input, threshold_slider]
    )
    
    # Clicking the button runs predict_dialects(text, threshold) and routes
    # its three return values to the three output widgets.
    submit_btn.click(
        fn=predict_dialects, 
        inputs=[text_input, threshold_slider], 
        outputs=[dialects_output, prob_output, clean_text_output]
    )

    # On page load, run the dark-mode JavaScript in the browser; no Python
    # function, inputs or outputs are involved.
    demo.load(None, None, None, js=dark_mode_js)

# Start the Gradio app only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    demo.launch()