# app.py — Hugging Face Space: multi-label Arabic dialect classification (MARBERTv2 + Gradio)
# -*- coding: utf-8 -*-
import emoji
import re
import nltk
import torch
import numpy as np
import joblib
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# ── 1. Setup & Pre-processing Environment ──
# Fetch the NLTK stop-word corpus once at startup (no-op if already cached).
nltk.download('stopwords', quiet=True)
from nltk.corpus import stopwords
# Materialize the Arabic stop words as a set for O(1) membership tests
# in remove_stopwords below.
arabic_stopwords = set(stopwords.words('arabic'))
def clean_text(text):
    """Strip URLs, @/# marker characters, digits, and surplus whitespace.

    Assumes *text* is a string; returns the cleaned string.
    """
    # Remove web links first so none of their characters leak into later passes.
    no_links = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
    # Drop mention/hashtag symbols (the word that follows them is kept).
    no_marks = re.sub(r'[@#]', '', no_links)
    no_digits = re.sub(r'\d+', '', no_marks)
    # Collapse whitespace runs and trim the ends.
    return re.sub(r'\s+', ' ', no_digits).strip()
def remove_emojis(text):
    """Return *text* with every emoji removed.

    Non-string inputs are passed through unchanged.
    """
    if isinstance(text, str):
        return emoji.replace_emoji(text, replace='')
    return text
def remove_arabic_punctuation(text):
    """Replace Arabic punctuation and bidi/control marks with spaces.

    Whitespace runs are then collapsed to single spaces and the result is
    trimmed. Non-string inputs are returned unchanged.
    """
    if not isinstance(text, str):
        return text
    # One character class covering Arabic punctuation, tatweel, percent-style
    # signs, ZWJ/ZWNJ, directional marks, BOM, and embedding controls.
    punct_class = (
        r'[\u0600-\u0605\u060C\u060D\u061B\u061C\u061D\u061E\u061F'
        r'\u0640\u066A-\u066D\u06D4\u200c\u200d\u200e\u200f'
        r'\ufeff\u202a-\u202e،؟؛«»]'
    )
    spaced = re.sub(punct_class, ' ', text)
    return re.sub(r'\s+', ' ', spaced).strip()
def normalize_arabic_characters(text):
    """Normalize common Arabic orthographic variants and strip diacritics.

    Alef variants collapse to bare alef, alef maqsura becomes ya, and
    ta marbuta becomes ha. Non-string inputs are coerced via str() first.
    """
    if not isinstance(text, str):
        text = str(text)
    # Single C-level translation pass replaces the chained substitutions.
    variant_map = str.maketrans({'أ': 'ا', 'إ': 'ا', 'آ': 'ا', 'ى': 'ي', 'ة': 'ه'})
    text = text.translate(variant_map)
    # Drop harakat/tanween marks and the superscript alef.
    return re.sub(r'[\u064B-\u065F\u0670]', '', text)
def remove_repeated_chars(text):
    """Collapse any run of three or more identical characters down to two.

    Non-string inputs pass through untouched.
    """
    if isinstance(text, str):
        # Backreference \1 matches the captured character repeated 2+ more times.
        return re.sub(r'(.)\1{2,}', r'\1\1', text)
    return text
def remove_repeated_words(text):
    """Deduplicate immediately-repeated whole words, keeping one copy.

    Non-string inputs pass through untouched.
    """
    if isinstance(text, str):
        # \b anchors ensure only full-word repeats collapse, not prefixes.
        return re.sub(r'\b(\w+)(\s+\1){1,}\b', r'\1', text)
    return text
def tokenize_text(text):
    """Split *text* on whitespace; falsy input yields an empty list."""
    if not text:
        return []
    return text.split()
def remove_stopwords(tokens, stopwords_set=None):
    """Return *tokens* with every stop word removed.

    Args:
        tokens: iterable of token strings.
        stopwords_set: optional collection of words to drop; defaults to the
            module-level ``arabic_stopwords`` set.

    Returns:
        list[str]: the surviving tokens, in their original order.

    The default is bound lazily (None sentinel) instead of in the signature,
    so the function no longer freezes a reference to ``arabic_stopwords`` at
    definition time and does not depend on import order.
    """
    if stopwords_set is None:
        stopwords_set = arabic_stopwords
    return [t for t in tokens if t not in stopwords_set]
def preprocess_arabic_text(text):
    """Run the full Arabic cleaning pipeline and return a normalized string.

    NaN/None or non-string input yields the empty string. Otherwise the text
    passes through URL/digit stripping, emoji and punctuation removal,
    character normalization, repeat collapsing, tokenization, and stop-word
    filtering, and the surviving tokens are re-joined with single spaces.
    """
    if pd.isna(text) or not isinstance(text, str):
        return ''
    # String-to-string passes, applied strictly in this order.
    for step in (
        clean_text,
        remove_emojis,
        remove_arabic_punctuation,
        normalize_arabic_characters,
        remove_repeated_chars,
        remove_repeated_words,
    ):
        text = step(text)
    # Tokenize, drop stop words, re-join.
    return ' '.join(remove_stopwords(tokenize_text(text)))
# ── 2. Load Model ──
# Pull the fine-tuned multi-label dialect classifier from the Hub and move it
# to GPU when available. Runs once at module import.
print("Loading model and tokenizer...")
REPO_ID = "mahmoudmohammad/marbertv2-multilabel-dialect"
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = AutoModelForSequenceClassification.from_pretrained(REPO_ID)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Inference only: disable dropout / use eval-mode batch statistics.
model.eval()
try:
    # Preferred path: recover the label order from the pickled MultiLabelBinarizer
    # used at training time, so indices line up with the model's output head.
    mlb = joblib.load('mlb_dialects.pkl')
    class_names = list(mlb.classes_)
except FileNotFoundError:
    # Fallback: hard-coded label list — assumed to match the training order of
    # the binarizer above (alphabetical here); verify if the pickle is absent.
    class_names = ['Bahraini', 'Egyptian', 'Emirati', 'Jordanian', 'Lebanese', 'MSA', 'Palestinian', 'Qatari', 'Saudi', 'Syrian']
# ── 3. Prediction Pipeline ──
def predict_dialects(text, threshold):
    """Classify *text* into one or more Arabic dialects.

    Args:
        text: raw user input; pre-processed before tokenization.
        threshold: minimum sigmoid probability for a dialect to be selected.

    Returns:
        tuple: (comma-joined dialect names, {dialect: rounded probability},
        the pre-processed text actually fed to the model).
    """
    normalized = preprocess_arabic_text(text)
    encoded = tokenizer(
        normalized,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    )
    # Move every input tensor onto the same device as the model.
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits
    # Independent per-label probabilities (multi-label, hence sigmoid).
    scores = torch.sigmoid(logits).squeeze().cpu().numpy()

    probability_map = {}
    chosen = []
    for idx, score in enumerate(scores):
        label = class_names[idx]
        probability_map[label] = round(float(score), 4)
        if score > threshold:
            chosen.append(label)

    # Fallback: always report at least the single highest-scoring dialect.
    if not chosen:
        chosen.append(class_names[int(np.argmax(scores))])

    return ", ".join(chosen), probability_map, normalized
# ── 4. Gradio UI with Enforced Dark Mode ──
# The following script is a standard hack to forcefully enable dark mode on app load.
# (Injected via demo.load below; the JS string itself must stay verbatim.)
dark_mode_js = """
function() {
document.body.classList.add('dark');
}
"""
# Build the Gradio interface: input column (text + threshold) on the left,
# prediction outputs on the right.
with gr.Blocks(theme=gr.themes.Base(), title="Arabic Multi-Dialect Analyzer") as demo:
    gr.Markdown(
        """
# 🌙 Multi-Label Arabic Dialect Inference
Identify overlapping dialects in modern Arabic text seamlessly.
*Powered by MARBERTv2*
"""
    )
    with gr.Row():
        with gr.Column():
            # Right-to-left textbox for Arabic input.
            text_input = gr.Textbox(
                lines=5,
                label="Arabic Text Input",
                placeholder="أدخل النص العربي هنا...",
                rtl=True
            )
            # Slider feeds the `threshold` argument of predict_dialects.
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                step=0.01,
                value=0.45,
                label="Confidence Threshold",
                info="Determines minimum probability needed to label a dialect."
            )
            submit_btn = gr.Button("Analyze Text", variant="primary")
        with gr.Column():
            dialects_output = gr.Textbox(label="Predicted Dialect(s)")
            # Shows all 10 class probabilities, not just the winners.
            prob_output = gr.Label(label="Confidence Probabilities", num_top_classes=10)
            clean_text_output = gr.Textbox(label="Text After Pre-Processing", rtl=True)
    # Optional Examples for fast-testing
    gr.Examples(
        examples=[
            ["شلونك اليوم؟ شو عم تعمل؟", 0.45],
            ["انا رايح الشغل بدري علشان عندي شغل كتير.", 0.45]
        ],
        inputs=[text_input, threshold_slider]
    )
    # Map button click to backend logic
    submit_btn.click(
        fn=predict_dialects,
        inputs=[text_input, threshold_slider],
        outputs=[dialects_output, prob_output, clean_text_output]
    )
    # Inject JS at initialization to Force Dark Theme
    demo.load(None, None, None, js=dark_mode_js)
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()