# -*- coding: utf-8 -*-
import emoji
import re
import nltk
import torch
import numpy as np
import joblib
import pandas as pd
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ── 1. Setup & Pre-processing Environment ──
nltk.download('stopwords', quiet=True)
from nltk.corpus import stopwords

arabic_stopwords = set(stopwords.words('arabic'))
# Strip URLs, @/# symbols, digits, and redundant whitespace
def clean_text(text):
    text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
    text = re.sub(r'[@#]', '', text)
    text = re.sub(r'\d+', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def remove_emojis(text):
    if not isinstance(text, str):
        return text
    return emoji.replace_emoji(text, replace='')

def remove_arabic_punctuation(text):
    if not isinstance(text, str):
        return text
    # Arabic punctuation, tatweel, and zero-width/direction formatting characters
    arabic_punct = (
        r'[\u0600-\u0605\u060C\u060D\u061B\u061C\u061D\u061E\u061F'
        r'\u0640\u066A-\u066D\u06D4\u200c\u200d\u200e\u200f'
        r'\ufeff\u202a-\u202e،؟؛«»]'
    )
    text = re.sub(arabic_punct, ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# Unify alef variants, normalize ى/ة, and strip Arabic diacritics
def normalize_arabic_characters(text):
    if not isinstance(text, str):
        text = str(text)
    text = re.sub(r'[أإآ]', 'ا', text)
    text = text.replace('ى', 'ي')
    text = text.replace('ة', 'ه')
    text = re.sub(r'[\u064B-\u065F\u0670]', '', text)
    return text

# Collapse any character repeated 3+ times down to two (tames elongated letters)
def remove_repeated_chars(text):
    if not isinstance(text, str):
        return text
    return re.sub(r'(.)\1{2,}', r'\1\1', text)

# Collapse consecutive duplicate words into a single occurrence
def remove_repeated_words(text):
    if not isinstance(text, str):
        return text
    return re.sub(r'\b(\w+)(\s+\1){1,}\b', r'\1', text)

def tokenize_text(text):
    return text.split() if text else []

def remove_stopwords(tokens, stopwords_set=arabic_stopwords):
    return [t for t in tokens if t not in stopwords_set]
def preprocess_arabic_text(text):
    if pd.isna(text) or not isinstance(text, str):
        return ''
    text = clean_text(text)
    text = remove_emojis(text)
    text = remove_arabic_punctuation(text)
    text = normalize_arabic_characters(text)
    text = remove_repeated_chars(text)
    text = remove_repeated_words(text)
    tokens = tokenize_text(text)
    tokens = remove_stopwords(tokens)
    return ' '.join(tokens)
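
# Illustrative call (the exact output is an assumption; it depends on the NLTK
# Arabic stopword list installed at runtime):
#   preprocess_arabic_text("جمييييل 😍 http://t.co/abc")
#   -> roughly "جمييل" (URL and emoji removed, elongation collapsed to two characters)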

# ── 2. Load Model ──
print("Loading model and tokenizer...")
REPO_ID = "mahmoudmohammad/marbertv2-multilabel-dialect"
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = AutoModelForSequenceClassification.from_pretrained(REPO_ID)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

# Recover label names from the saved label binarizer; fall back to a
# hard-coded label list if the pickle is missing.
try:
    mlb = joblib.load('mlb_dialects.pkl')
    class_names = list(mlb.classes_)
except FileNotFoundError:
    class_names = ['Bahraini', 'Egyptian', 'Emirati', 'Jordanian', 'Lebanese',
                   'MSA', 'Palestinian', 'Qatari', 'Saudi', 'Syrian']

# ── 3. Prediction Pipeline ──
def predict_dialects(text, threshold):
    cleaned_text = preprocess_arabic_text(text)
    inputs = tokenizer(
        cleaned_text,
        return_tensors="pt",
        truncation=True,
        max_length=256
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probs = torch.sigmoid(logits).squeeze().cpu().numpy()
    # Multi-label decision: keep every class whose sigmoid probability clears the threshold
    predictions = (probs > threshold).astype(int)
    predicted_dialects = []
    predicted_probs = {}
    for i, pred in enumerate(predictions):
        dialect = class_names[i]
        predicted_probs[dialect] = round(float(probs[i]), 4)
        if pred == 1:
            predicted_dialects.append(dialect)
    # Fallback: if no class clears the threshold, return the single most likely dialect
    if len(predicted_dialects) == 0:
        max_idx = np.argmax(probs)
        predicted_dialects.append(class_names[max_idx])
    dialects_out = ", ".join(predicted_dialects)
    return dialects_out, predicted_probs, cleaned_text
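
# Example call (illustrative only; the labels shown are assumptions and the real
# output depends on the fine-tuned model weights):
#   dialects, probs, cleaned = predict_dialects("شلونك اليوم؟", 0.45)
#   dialects -> e.g. "Emirati, Qatari"
#   probs    -> {"Bahraini": 0.1234, "Egyptian": 0.0211, ...}
#   cleaned  -> the pre-processed text actually fed to the tokenizer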

# ── 4. Gradio UI with Enforced Dark Mode ──
# The JS snippet below is a common hack to force dark mode when the app loads.
dark_mode_js = """
function() {
    document.body.classList.add('dark');
}
"""

with gr.Blocks(theme=gr.themes.Base(), title="Arabic Multi-Dialect Analyzer") as demo:
    gr.Markdown(
        """
# 🌙 Multi-Label Arabic Dialect Inference
Identify overlapping dialects in modern Arabic text.
*Powered by MARBERTv2*
        """
    )
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                lines=5,
                label="Arabic Text Input",
                placeholder="أدخل النص العربي هنا...",
                rtl=True
            )
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                step=0.01,
                value=0.45,
                label="Confidence Threshold",
                info="Minimum probability a dialect needs to be included in the prediction."
            )
            submit_btn = gr.Button("Analyze Text", variant="primary")
        with gr.Column():
            dialects_output = gr.Textbox(label="Predicted Dialect(s)")
            prob_output = gr.Label(label="Confidence Probabilities", num_top_classes=10)
            clean_text_output = gr.Textbox(label="Text After Pre-Processing", rtl=True)

    # Example inputs for quick testing
    gr.Examples(
        examples=[
            ["شلونك اليوم؟ شو عم تعمل؟", 0.45],
            ["انا رايح الشغل بدري علشان عندي شغل كتير.", 0.45]
        ],
        inputs=[text_input, threshold_slider]
    )

    # Map the button click to the prediction pipeline
    submit_btn.click(
        fn=predict_dialects,
        inputs=[text_input, threshold_slider],
        outputs=[dialects_output, prob_output, clean_text_output]
    )

    # Inject the JS at initialization to force the dark theme
    demo.load(None, None, None, js=dark_mode_js)

if __name__ == "__main__":
    demo.launch()