# -*- coding: utf-8 -*-
"""Multi-label Arabic dialect classifier (MARBERTv2) served through a Gradio UI.

Pipeline: raw text -> Arabic-specific cleaning/normalisation -> HF tokenizer ->
sigmoid multi-label head -> thresholded dialect labels.
"""
import re

import emoji
import gradio as gr
import joblib
import nltk
import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# ── 1. Setup & Pre-processing Environment ──
nltk.download('stopwords', quiet=True)
from nltk.corpus import stopwords  # noqa: E402 — must follow the download above

arabic_stopwords = set(stopwords.words('arabic'))

# Patterns are compiled once at import time so the per-request cleaning
# pipeline does not repeat the pattern lookup/compile on every call.
_URL_RE = re.compile(r'http\S+|www\S+|https\S+', re.MULTILINE)
_MENTION_MARK_RE = re.compile(r'[@#]')
_DIGITS_RE = re.compile(r'\d+')
_WHITESPACE_RE = re.compile(r'\s+')
_ARABIC_PUNCT_RE = re.compile(
    r'[\u0600-\u0605\u060C\u060D\u061B\u061C\u061D\u061E\u061F'
    r'\u0640\u066A-\u066D\u06D4\u200c\u200d\u200e\u200f'
    r'\ufeff\u202a-\u202e،؟؛«»]'
)
_ALEF_VARIANTS_RE = re.compile(r'[أإآ]')
_DIACRITICS_RE = re.compile(r'[\u064B-\u065F\u0670]')
_REPEATED_CHARS_RE = re.compile(r'(.)\1{2,}')
_REPEATED_WORDS_RE = re.compile(r'\b(\w+)(\s+\1){1,}\b')


def clean_text(text):
    """Remove URLs, @/# marker characters and digits; collapse whitespace."""
    text = _URL_RE.sub('', text)
    text = _MENTION_MARK_RE.sub('', text)  # keeps the word, drops only the symbol
    text = _DIGITS_RE.sub('', text)
    return _WHITESPACE_RE.sub(' ', text).strip()


def remove_emojis(text):
    """Strip every emoji; non-string input passes through unchanged."""
    if not isinstance(text, str):
        return text
    return emoji.replace_emoji(text, replace='')


def remove_arabic_punctuation(text):
    """Replace Arabic punctuation/control marks with spaces, then collapse runs."""
    if not isinstance(text, str):
        return text
    text = _ARABIC_PUNCT_RE.sub(' ', text)
    return _WHITESPACE_RE.sub(' ', text).strip()


def normalize_arabic_characters(text):
    """Canonicalise Arabic letter variants and drop diacritics.

    أ/إ/آ -> ا, ى -> ي, ة -> ه; tashkeel (U+064B–U+065F) and the dagger
    alef (U+0670) are removed entirely.
    """
    if not isinstance(text, str):
        text = str(text)
    text = _ALEF_VARIANTS_RE.sub('ا', text)
    text = text.replace('ى', 'ي')
    text = text.replace('ة', 'ه')
    return _DIACRITICS_RE.sub('', text)


def remove_repeated_chars(text):
    """Cap any run of 3+ identical characters at exactly two (e.g. elongation)."""
    if not isinstance(text, str):
        return text
    return _REPEATED_CHARS_RE.sub(r'\1\1', text)


def remove_repeated_words(text):
    """Collapse consecutive duplicate words to a single occurrence."""
    if not isinstance(text, str):
        return text
    return _REPEATED_WORDS_RE.sub(r'\1', text)


def tokenize_text(text):
    """Whitespace tokenisation; empty/None input yields an empty list."""
    return text.split() if text else []


def remove_stopwords(tokens, stopwords_set=arabic_stopwords):
    """Drop tokens present in *stopwords_set* (NLTK Arabic list by default)."""
    return [t for t in tokens if t not in stopwords_set]


def preprocess_arabic_text(text):
    """Run the full cleaning pipeline and return a space-joined token string.

    Returns '' for NaN or non-string input so downstream tokenisation is safe.
    """
    if pd.isna(text) or not isinstance(text, str):
        return ''
    text = clean_text(text)
    text = remove_emojis(text)
    text = remove_arabic_punctuation(text)
    text = normalize_arabic_characters(text)
    text = remove_repeated_chars(text)
    text = remove_repeated_words(text)
    tokens = remove_stopwords(tokenize_text(text))
    return ' '.join(tokens)


# ── 2. Load Model ──
print("Loading model and tokenizer...")
REPO_ID = "mahmoudmohammad/marbertv2-multilabel-dialect"
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = AutoModelForSequenceClassification.from_pretrained(REPO_ID)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Label order must match the model's output head. Prefer the fitted
# MultiLabelBinarizer when present; otherwise fall back to the known
# alphabetical class list.
try:
    mlb = joblib.load('mlb_dialects.pkl')
    class_names = list(mlb.classes_)
except FileNotFoundError:
    class_names = ['Bahraini', 'Egyptian', 'Emirati', 'Jordanian', 'Lebanese',
                   'MSA', 'Palestinian', 'Qatari', 'Saudi', 'Syrian']


# ── 3. Prediction Pipeline ──
def predict_dialects(text, threshold):
    """Predict all dialects whose sigmoid probability exceeds *threshold*.

    Args:
        text: Raw Arabic input from the UI textbox.
        threshold: Minimum probability for a dialect to be labelled.

    Returns:
        Tuple of (comma-joined dialect names, {dialect: probability},
        pre-processed text shown back to the user).
    """
    cleaned_text = preprocess_arabic_text(text)

    # Robustness: the cleaning pipeline can strip the input down to nothing
    # (URL-only / emoji-only input); avoid scoring an empty sequence.
    if not cleaned_text:
        return "No analyzable text after pre-processing", {}, cleaned_text

    inputs = tokenizer(
        cleaned_text,
        return_tensors="pt",
        truncation=True,
        max_length=256
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # squeeze(0) drops only the batch dimension; a bare squeeze() would also
    # collapse the label dimension if the head ever had a single label.
    probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()

    predicted_probs = {
        name: round(float(p), 4) for name, p in zip(class_names, probs)
    }
    predicted_dialects = [
        name for name, p in zip(class_names, probs) if p > threshold
    ]

    # Fallback: always report at least the single most likely dialect.
    if not predicted_dialects:
        predicted_dialects.append(class_names[int(np.argmax(probs))])

    return ", ".join(predicted_dialects), predicted_probs, cleaned_text


# ── 4. Gradio UI with Enforced Dark Mode ──
# The following script is a standard hack to forcefully enable dark mode on app load.
dark_mode_js = """
function() {
    document.body.classList.add('dark');
}
"""

with gr.Blocks(theme=gr.themes.Base(), title="Arabic Multi-Dialect Analyzer") as demo:
    gr.Markdown(
        """
        # 🌙 Multi-Label Arabic Dialect Inference
        Identify overlapping dialects in modern Arabic text seamlessly.
        *Powered by MARBERTv2*
        """
    )

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                lines=5,
                label="Arabic Text Input",
                placeholder="أدخل النص العربي هنا...",
                rtl=True
            )
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                step=0.01,
                value=0.45,
                label="Confidence Threshold",
                info="Determines minimum probability needed to label a dialect."
            )
            submit_btn = gr.Button("Analyze Text", variant="primary")
        with gr.Column():
            dialects_output = gr.Textbox(label="Predicted Dialect(s)")
            prob_output = gr.Label(label="Confidence Probabilities", num_top_classes=10)
            clean_text_output = gr.Textbox(label="Text After Pre-Processing", rtl=True)

    # Optional Examples for fast-testing
    gr.Examples(
        examples=[
            ["شلونك اليوم؟ شو عم تعمل؟", 0.45],
            ["انا رايح الشغل بدري علشان عندي شغل كتير.", 0.45]
        ],
        inputs=[text_input, threshold_slider]
    )

    # Map button click to backend logic
    submit_btn.click(
        fn=predict_dialects,
        inputs=[text_input, threshold_slider],
        outputs=[dialects_output, prob_output, clean_text_output]
    )

    # Inject JS at initialization to Force Dark Theme
    demo.load(None, None, None, js=dark_mode_js)

if __name__ == "__main__":
    demo.launch()