File size: 5,933 Bytes
17f2039
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import torch
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ============================================
# 1. Configuration & Label Mapping
# ============================================
# Hugging Face Hub repository that holds the fine-tuned classifier weights.
MODEL_ID = "mahmoudmohammad/marbertv2_single-label-dialect"

# Class index -> dialect name. Per the original author's note this is the exact
# mapping used during training, so the order of the names below must not change:
# position N in the tuple is the label for logit N.
LABEL_MAP = dict(enumerate((
    'Algerian',
    'Egyptian',
    'Iraqi',
    'Jordanian',
    'Lebanese',
    'Libyan',
    'MSA',
    'Moroccan',
    'Palestinian',
    'Qatari',
    'Saudi',
    'Syrian',
    'Tunisian',
    'Yemeni',
)))

# ============================================
# 2. Caching & Loading Model Locally
# ============================================
# Load the tokenizer and model once, at import time, so every request reuses
# the same in-memory objects instead of reloading them per call.
print(f"Loading {MODEL_ID} from Hugging Face...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    model.eval()  # Inference mode: disables dropout
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Fail fast. Continuing would leave `tokenizer` / `model` undefined, and
    # every prediction request would then die with an unrelated NameError
    # instead of pointing at the real cause logged above.
    raise

# ============================================
# 3. Preprocessing Logic
# ============================================
def preprocess_arabic_dialect(text: str) -> str:
    """Normalize noisy social-media Arabic text before tokenization.

    Per the original author's note, this mirrors the cleaning applied in the
    training script, so the substitutions and their order must stay as they
    are.

    Args:
        text: Raw input text.

    Returns:
        The cleaned text with surrounding whitespace removed, or ``""`` when
        ``text`` is not a string.
    """
    if not isinstance(text, str):
        return ""

    # (pattern, replacement) pairs, applied strictly top to bottom.
    substitutions = (
        (r'http\S+|www\.\S+|<.*?>', ' '),        # URLs and <...> tags
        (r'@\w+', ' '),                          # @mentions
        (r'#', ''),                              # drop '#', keep the hashtag word
        (r'[\u0617-\u061A\u064B-\u0652]', ''),   # tashkeel (diacritic marks)
        (r'\u0640', ''),                         # tatweel (elongation char)
        (r'(.)\1+', r'\1\1'),                    # cap character runs at two
        (r'[^\w\s\u0600-\u06FF]', ' '),          # remaining punctuation/symbols
        (r'\s+', ' '),                           # collapse whitespace runs
    )

    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)

    return cleaned.strip()

# ============================================
# 4. Inference Function
# ============================================
def predict_dialect(text: str):
    """Score ``text`` against every dialect label.

    Args:
        text: Raw user input from the UI textbox.

    Returns:
        A dict mapping each dialect name in ``LABEL_MAP`` to a float score,
        which is the format the Gradio ``Label`` output consumes. Every score
        is ``0.0`` when there is nothing to classify: the input is not a
        string, is blank, or is empty once preprocessing has removed URLs,
        mentions and punctuation.
    """
    # Guard against None / non-string payloads as well as blank input;
    # calling .strip() on None would raise AttributeError.
    if not isinstance(text, str) or not text.strip():
        return {label: 0.0 for label in LABEL_MAP.values()}

    # 1. Clean the incoming text the same way the training data was cleaned
    clean_text = preprocess_arabic_dialect(text)
    if not clean_text:
        # Nothing survived cleaning, so running the model on an empty string
        # would yield scores unrelated to anything the user typed.
        return {label: 0.0 for label in LABEL_MAP.values()}

    # 2. Tokenize (truncate / pad to a fixed length of 128 tokens)
    inputs = tokenizer(
        clean_text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding="max_length",  # As trained in the model script
    )

    # 3. Model inference (no gradient tracking needed)
    with torch.no_grad():
        logits = model(**inputs).logits
        # Softmax over the class dimension; [0] selects the first row of the
        # batch (a single text was passed in).
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]

    # 4. Pair each class index with its label name for the Gradio 'Label' UI
    return {label: float(probs[index]) for index, label in LABEL_MAP.items()}

# ============================================
# 5. UI Application Definition (Dark Mode Native)
# ============================================

# JavaScript function handed to gr.Blocks(js=...) below. It adds the 'dark'
# CSS class to <body>.
# NOTE(review): presumably Gradio runs this on page load and keys its dark
# theme off that class — confirm against the docs for the installed version.
dark_mode_js = """

function() { 

    document.body.classList.add('dark');

}

"""

# Page layout and event wiring. Components are created inside nested `with`
# blocks; keep the statement order as-is, since (per the Gradio Blocks
# convention) creation order determines where each component is rendered.
with gr.Blocks(js=dark_mode_js, theme=gr.themes.Monochrome(primary_hue="purple")) as demo:
    # Page title and short description
    gr.Markdown("# 🌍 Arabic Dialect Detector")
    gr.Markdown("Identify whether text represents **MSA** or one of 13 Regional **Arabic Dialects** (e.g., Egyptian, Saudi, Moroccan, Lebanese...). \n*Powered by a Fine-Tuned MARBERTv2 base model.*")
    
    with gr.Row():
        
        # Left panel: text input, submit button and clickable examples
        # (scale=5 here vs. scale=4 on the right panel)
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="أدخل النص (Enter Arabic Text Here)", 
                placeholder="إزيك يا صاحبي عامل إيه؟",
                lines=5
            )
            submit_btn = gr.Button("Detect Dialect 🔎", variant="primary")
            
            # Sample sentences offered to the user. Each entry is a
            # one-element list because there is a single input component.
            # The trailing comments are the author's annotation of the dialect
            # each sentence is meant to illustrate.
            examples_list = [
                ["إزيك يا صاحبي عامل إيه؟ فينك من زمان"],           # Egyptian
                ["شو أخبارك؟ وين هالغيبة اشتقنالك كتير"],          # Lebanese/Syrian
                ["كيداير لاباس عليك؟ شنو كتدير؟"],                 # Moroccan
                ["وشلونك طال عمرك؟ عساك طيب ومبسوط"],            # Saudi / Gulf
                ["السلام عليكم ورحمة الله وبركاته، كيف حالكم اليوم؟"], # MSA
                ["أنا هسا رايح عالدار بدك اشي؟"],                  # Jordanian/Palestinian
            ]
            gr.Examples(
                examples=examples_list,
                inputs=text_input,
                label="Try these Examples"
            )

        # Right panel: prediction output
        with gr.Column(scale=4):
            # Label component fed by predict_dialect's {label: score} dict;
            # num_top_classes=4 limits the display to the four top classes.
            output_labels = gr.Label(num_top_classes=4, label="Dialect Confidence")
            
            # Static note telling users that the input is cleaned before
            # it reaches the model
            gr.Markdown("*(Internal Text pre-processing strips tags, mentions, tashkeel, repeated letters etc. via REGEX just like the model training before execution!)*")

    # Wire the button: on click, call predict_dialect with the textbox
    # contents and render the returned dict in the label component.
    submit_btn.click(
        fn=predict_dialect,
        inputs=text_input,
        outputs=output_labels
    )
    
# Start the Gradio server only when this file is executed as a script
# (not when it is imported as a module).
if __name__ == "__main__":
    # show_error=True: per Gradio's launch() documentation, exceptions raised
    # by event handlers are surfaced in the browser UI instead of only
    # appearing in the server logs.
    demo.launch(show_error=True)