# Source: Hugging Face Space by mahmoudmohammad (commit 17f2039, "Upload 2 files")
import gradio as gr
import torch
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ============================================
# 1. Configuration & Label Mapping
# ============================================
MODEL_ID = "mahmoudmohammad/marbertv2_single-label-dialect"

# Index -> dialect name, in the exact order assigned during fine-tuning.
LABEL_MAP = dict(enumerate([
    'Algerian', 'Egyptian', 'Iraqi', 'Jordanian',
    'Lebanese', 'Libyan', 'MSA', 'Moroccan',
    'Palestinian', 'Qatari', 'Saudi', 'Syrian',
    'Tunisian', 'Yemeni',
]))
# ============================================
# 2. Caching & Loading Model Locally
# ============================================
# Loading at module level fetches the weights once during Space spin-up,
# so every subsequent inference call is fast.
print(f"Loading {MODEL_ID} from Hugging Face...")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    model.eval()  # inference mode: disables dropout
    print("✅ Model loaded successfully!")
except Exception as e:
    # Fail fast at startup instead of swallowing the error: the original
    # code left `tokenizer`/`model` undefined, which surfaced later as a
    # confusing NameError on the first prediction request.
    print(f"❌ Error loading model: {e}")
    raise
# ============================================
# 3. Preprocessing Logic
# ============================================
# Patterns compiled once at import time so per-request cleaning does no
# regex compilation or cache-lookup work.
_URL_TAG_RE = re.compile(r'http\S+|www\.\S+|<.*?>')         # URLs and HTML tags
_MENTION_RE = re.compile(r'@\w+')                           # @mentions
_TASHKEEL_RE = re.compile(r'[\u0617-\u061A\u064B-\u0652]')  # Arabic diacritics
_REPEAT_RE = re.compile(r'(.)\1+')                          # runs of one repeated char
_NON_ARABIC_RE = re.compile(r'[^\w\s\u0600-\u06FF]')        # punctuation/symbols
_SPACE_RE = re.compile(r'\s+')                              # whitespace runs


def preprocess_arabic_dialect(text: str) -> str:
    """Clean social-media dialectal Arabic text (mirrors training preprocessing).

    Steps, in order: strip URLs/HTML tags and @mentions, drop '#' signs
    (keeping the hashtag word), remove tashkeel (diacritics) and tatweel,
    squeeze elongated letters to at most two, replace non-Arabic punctuation
    with spaces, and collapse whitespace.

    Args:
        text: Raw input; non-string values yield "".

    Returns:
        The cleaned string (possibly empty).
    """
    if not isinstance(text, str):
        return ""
    text = _URL_TAG_RE.sub(' ', text)
    text = _MENTION_RE.sub(' ', text)
    text = text.replace('#', '')           # keep the hashtag word itself
    text = _TASHKEEL_RE.sub('', text)
    text = text.replace('\u0640', '')      # tatweel (kashida) elongation char
    text = _REPEAT_RE.sub(r'\1\1', text)   # "coool" -> "cool"
    text = _NON_ARABIC_RE.sub(' ', text)
    return _SPACE_RE.sub(' ', text).strip()
# ============================================
# 4. Inference Function
# ============================================
def predict_dialect(text: str):
    """Classify `text` into one of the 14 dialect labels.

    Returns a {label: probability} dict suitable for gr.Label, where the
    probabilities are the softmax over the model logits. Blank input maps
    every label to 0.0.
    """
    if not text.strip():
        # Nothing to classify — render an all-zero confidence chart.
        return dict.fromkeys(LABEL_MAP.values(), 0.0)

    cleaned = preprocess_arabic_dialect(text)

    # Tokenize exactly as during training (max_len 128, padded).
    encoded = tokenizer(
        cleaned,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding="max_length",
    )

    # Forward pass without gradient bookkeeping.
    with torch.no_grad():
        logits = model(**encoded).logits

    scores = torch.nn.functional.softmax(logits, dim=-1)[0]

    # Gradio's Label component turns this mapping into confidence bars.
    return {name: float(scores[idx]) for idx, name in LABEL_MAP.items()}
# ============================================
# 5. UI Application Definition (Dark Mode Native)
# ============================================
# JavaScript run once on page load (passed to gr.Blocks(js=...)) that
# forces Gradio's dark theme by tagging <body> with the 'dark' class.
dark_mode_js = """
function() {
document.body.classList.add('dark');
}
"""
with gr.Blocks(js=dark_mode_js, theme=gr.themes.Monochrome(primary_hue="purple")) as demo:
    gr.Markdown("# 🌍 Arabic Dialect Detector")
    gr.Markdown("Identify whether text represents **MSA** or one of 13 Regional **Arabic Dialects** (e.g., Egyptian, Saudi, Moroccan, Lebanese...). \n*Powered by a Fine-Tuned MARBERTv2 base model.*")

    with gr.Row():
        # Left column: input textbox, trigger button, clickable samples.
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="أدخل النص (Enter Arabic Text Here)",
                placeholder="إزيك يا صاحبي عامل إيه؟",
                lines=5,
            )
            submit_btn = gr.Button("Detect Dialect 🔎", variant="primary")

            # A regional spread of sample sentences shown in the Space.
            examples_list = [
                ["إزيك يا صاحبي عامل إيه؟ فينك من زمان"],  # Egyptian
                ["شو أخبارك؟ وين هالغيبة اشتقنالك كتير"],  # Lebanese/Syrian
                ["كيداير لاباس عليك؟ شنو كتدير؟"],  # Moroccan
                ["وشلونك طال عمرك؟ عساك طيب ومبسوط"],  # Saudi / Gulf
                ["السلام عليكم ورحمة الله وبركاته، كيف حالكم اليوم؟"],  # MSA
                ["أنا هسا رايح عالدار بدك اشي؟"],  # Jordanian/Palestinian
            ]
            gr.Examples(
                examples=examples_list,
                inputs=text_input,
                label="Try these Examples",
            )

        # Right column: confidence bars for the top predictions.
        with gr.Column(scale=4):
            output_labels = gr.Label(num_top_classes=4, label="Dialect Confidence")
            gr.Markdown("*(Internal Text pre-processing strips tags, mentions, tashkeel, repeated letters etc. via REGEX just like the model training before execution!)*")

    # Wire the button to the inference function.
    submit_btn.click(
        fn=predict_dialect,
        inputs=text_input,
        outputs=output_labels,
    )
# Boot Gradio Application
if __name__ == "__main__":
    # show_error surfaces Python tracebacks in the browser UI, which is
    # helpful when debugging a hosted Space.
    demo.launch(show_error=True)