| import gradio as gr |
| import joblib |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import pandas as pd |
| from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification |
| from sklearn.preprocessing import LabelEncoder |
| from huggingface_hub import hf_hub_download |
| import re |
| import pyarabic.araby as araby |
|
|
| |
# Pretrained AraBERT checkpoint used for the tokenizer and both neural encoders.
MODEL_NAME = "aubmindlab/bert-base-arabertv02"
# Hugging Face Hub repository that hosts the fine-tuned .pt checkpoints.
MODEL_HUB = "batool0/arabic-speech-act-models"
# Sequence length the neural models' inputs are padded/truncated to.
MAX_LEN = 64

# Speech-act labels, kept in alphabetical order: that is the order a
# scikit-learn LabelEncoder stores in `classes_` when fitted on this list.
CLASSES = ['Assertion', 'Expression', 'Question', 'Recommendation', 'Request']

# Arabic display name for each label, paired positionally with CLASSES.
CLASS_AR = dict(zip(CLASSES, ['تأكيد', 'تعبير', 'سؤال', 'توصية', 'طلب']))

# Known confusion patterns, keyed by (true label, predicted label).
# The value is a short human-readable explanation shown in the UI.
ERROR_PATTERNS = {
    ('Expression', 'Assertion'):
        "Expression/Assertion boundary: tweets describing events with emotional tone are often misclassified as Assertion.",
    ('Assertion', 'Expression'):
        "Expression/Assertion boundary: factual tweets with emotional vocabulary are sometimes misclassified as Expression.",
    ('Question', 'Expression'):
        "Implicit question: this question lacks an explicit interrogative particle (هل، ماذا), causing it to resemble an Expression.",
    ('Expression', 'Question'):
        "Implicit question: emotionally phrased tweet contains question-like vocabulary.",
    ('Request', 'Assertion'):
        "Analytical request: the request is framed as a logical argument, resembling an Assertion.",
    ('Request', 'Recommendation'):
        "Request vs Recommendation: the boundary between requesting and recommending is thin in Arabic.",
    ('Recommendation', 'Expression'):
        "Sarcastic recommendation misread as emotional Expression.",
    ('Assertion', 'Question'):
        "Assertion/Question boundary: the tweet may contain an implicit question structure without explicit particles.",
    ('Question', 'Assertion'):
        "Question/Assertion boundary: the question is phrased as a statement, common in Arabic rhetorical questions.",
}

# Predicate completing the sentence "this tweet ..." for each label.
CLASS_DESCRIPTIONS = {
    'Assertion': "states a fact or conveys information objectively.",
    'Expression': "expresses an opinion, emotion, or personal feeling.",
    'Question': "asks for information or seeks clarification.",
    'Recommendation': "suggests or advises a course of action.",
    'Request': "asks someone to do something or take action.",
}
|
|
| |
def clean_text(text):
    """Normalise a raw Arabic tweet into the form the classifiers expect.

    Steps, applied in this order:
      1. drop URLs, @mentions, #hashtags (whole token), digits, and any
         character that is not a word char, whitespace, or in U+0600-U+06FF;
      2. strip Arabic diacritics (tashkeel) and elongation (tatweel);
      3. unify alef variants to bare alef, ta marbuta to ha, alef maqsura to ya;
      4. collapse whitespace runs to one space and trim the ends.
    """
    removals = (
        r'http\S+|www\S+',
        r'@\w+',
        r'#\w+',
        r'\d+',
        r'[^\w\s\u0600-\u06FF]',
    )
    for pattern in removals:
        text = re.sub(pattern, '', text)

    text = araby.strip_tatweel(araby.strip_tashkeel(text))

    normalisations = (
        (r'[إأآا]', 'ا'),
        (r'ة', 'ه'),
        (r'ى', 'ي'),
    )
    for pattern, replacement in normalisations:
        text = re.sub(pattern, replacement, text)

    return re.sub(r'\s+', ' ', text).strip()
|
|
| |
class AraBERTBiLSTM(nn.Module):
    """Frozen AraBERT encoder -> BiLSTM -> masked mean pooling -> linear head.

    The submodule attribute names (bert, bilstm, dropout, classifier) are the
    keys of the checkpoint loaded with load_state_dict(), so they must not be
    renamed.
    """

    def __init__(self, bert_model_name, hidden_dim, num_layers, num_classes, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        # Freeze every encoder parameter; forward() also runs it under no_grad.
        self.bert.requires_grad_(False)

        # nn.LSTM only applies dropout between stacked layers, so it is
        # disabled for a single-layer LSTM.
        inter_layer_dropout = 0.0 if num_layers <= 1 else dropout
        self.bilstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=inter_layer_dropout,
        )
        self.dropout = nn.Dropout(dropout)
        # Forward and backward LSTM outputs are concatenated -> 2 * hidden_dim.
        self.classifier = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits.

        Assumes input_ids / attention_mask are (batch, seq) tensors, as produced
        by the tokenizer with return_tensors='pt' — TODO confirm for other callers.

        NOTE(review): the LSTM also consumes the pad positions (callers pad to
        MAX_LEN), so the backward direction sees pad embeddings before the real
        tokens; the mask only affects pooling. Changing this would alter outputs
        relative to the trained checkpoint.
        """
        with torch.no_grad():
            encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        sequence_out, _ = self.bilstm(encoded)

        # Mean over real (unmasked) tokens; the clamp avoids division by zero.
        token_weights = attention_mask.unsqueeze(-1).float()
        real_token_count = token_weights.sum(dim=1).clamp(min=1e-9)
        pooled = (sequence_out * token_weights).sum(dim=1) / real_token_count

        return self.classifier(self.dropout(pooled))
|
|
| |
class AraBERTClassifier(nn.Module):
    """AraBERT encoder with dropout and a linear head on the first token.

    The submodule attribute names (bert, dropout, classifier) are the keys of
    the checkpoint loaded with load_state_dict(), so they must not be renamed.
    """

    def __init__(self, model_name, num_classes, dropout=0.3):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits from the position-0 hidden state.

        For BERT-style tokenizers, position 0 is the [CLS] token.
        """
        hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        first_token = hidden_states[:, 0, :]
        return self.classifier(self.dropout(first_token))
|
|
| |
# ---- One-time model loading at import time --------------------------------
print("Loading models...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# LabelEncoder stores its classes sorted; CLASSES is already alphabetical, so
# le.classes_ == CLASSES. Presumably this is the label order the neural
# checkpoints were trained with — confirm against the training script.
le = LabelEncoder()
le.fit(CLASSES)

# Local TF-IDF + SVM pipeline. joblib uses pickle: only load trusted files.
svm_model = joblib.load('svm_model.pkl')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# The checkpoints are passed straight to load_state_dict(), so they are plain
# state_dicts of tensors. weights_only=True prevents torch.load from
# unpickling arbitrary objects from files downloaded off the Hub.
bilstm_path = hf_hub_download(repo_id=MODEL_HUB, filename='best_bilstm_arabert.pt')
bilstm_model = AraBERTBiLSTM(MODEL_NAME, hidden_dim=128, num_layers=2, num_classes=len(CLASSES))
bilstm_model.load_state_dict(torch.load(bilstm_path, map_location=device, weights_only=True))
bilstm_model.to(device)
bilstm_model.eval()

arabert_path = hf_hub_download(repo_id=MODEL_HUB, filename='best_arabert.pt')
arabert_model = AraBERTClassifier(MODEL_NAME, num_classes=len(CLASSES), dropout=0.1).to(device)
arabert_model.load_state_dict(torch.load(arabert_path, map_location=device, weights_only=True))
arabert_model.eval()

# Labelled test split; get_ground_truth() matches on its 'text' column and
# reads its 'label' column.
test_df = pd.read_csv('test_with_labels.csv')
print("All models loaded")
|
|
| |
def _softmax(scores):
    """Return a numerically stable softmax of a 1-D sequence of scores."""
    shifted = np.asarray(scores, dtype=float)
    shifted = shifted - shifted.max()  # subtract the max so exp() cannot overflow
    weights = np.exp(shifted)
    return weights / weights.sum()


def predict_svm(text):
    """Classify `text` with the TF-IDF + SVM pipeline.

    Args:
        text: tweet text (classify() passes the clean_text() output).

    Returns:
        (pred_class, scores): the predicted label, and a dict mapping every
        label in svm_model.classes_ to a score in (0, 1). The scores sum to 1.

    The scores are a softmax over the SVM decision values. They are an
    uncalibrated heuristic, not true probabilities. Min-max scaling is not
    used because it always maps the winning class to exactly 1.0, which would
    display a constant "100% confidence". Softmax is monotonic, so the argmax
    (and therefore pred_class) is the class with the largest decision value.
    """
    margins = svm_model.decision_function([text])[0]
    scores = _softmax(margins)
    pred_class = svm_model.classes_[scores.argmax()]
    return pred_class, dict(zip(svm_model.classes_, scores.tolist()))
|
|
def predict_bilstm(text):
    """Classify `text` with the AraBERT + BiLSTM model.

    Returns:
        (pred_class, probs): the argmax label from le.classes_, and a dict
        mapping each label to its softmax probability.
    """
    batch = tokenizer(text, max_length=MAX_LEN, padding='max_length',
                      truncation=True, return_tensors='pt')
    ids = batch['input_ids'].to(device)
    mask = batch['attention_mask'].to(device)

    with torch.no_grad():
        logits = bilstm_model(ids, mask)

    # Single-item batch: take row 0 of the softmax output.
    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]
    best = probabilities.argmax()
    return le.classes_[best], dict(zip(le.classes_, probabilities.tolist()))
|
|
def predict_arabert(text):
    """Classify `text` with the fine-tuned AraBERT classifier.

    Returns:
        (pred_class, probs): the argmax label from le.classes_, and a dict
        mapping each label to its softmax probability.
    """
    batch = tokenizer(text, max_length=MAX_LEN, padding='max_length',
                      truncation=True, return_tensors='pt')
    model_inputs = {
        'input_ids': batch['input_ids'].to(device),
        'attention_mask': batch['attention_mask'].to(device),
    }

    with torch.no_grad():
        logits = arabert_model(**model_inputs)

    # Single-item batch: take row 0 of the softmax output.
    probabilities = torch.softmax(logits, dim=1).cpu().numpy()[0]
    best = probabilities.argmax()
    return le.classes_[best], dict(zip(le.classes_, probabilities.tolist()))
|
|
| |
def get_ground_truth(text):
    """Return the test-set label for `text`, or None if the tweet is not in it.

    `text` is passed through clean_text() and compared for exact equality
    against test_df['text']. The label of the first matching row is returned.

    NOTE(review): classify() already passes cleaned text, so it is cleaned a
    second time here. That is harmless only if clean_text is idempotent on
    its own output — confirm.
    """
    key = clean_text(text)
    labels = test_df.loc[test_df['text'] == key, 'label']
    if labels.empty:
        return None
    return labels.iloc[0]
|
|
| |
def get_error_analysis(true_label, pred_label):
    """Explain why a model may have predicted `pred_label` for a `true_label` tweet.

    Uses the curated ERROR_PATTERNS text when the (true, predicted) pair is
    known, otherwise a generic lexical-overlap explanation.
    """
    generic = (
        f"The model predicted {pred_label} instead of {true_label}. "
        "This may reflect lexical overlap between these classes in Arabic social media text."
    )
    return ERROR_PATTERNS.get((true_label, pred_label), generic)
|
|
| |
def get_smart_analysis(svm_pred, bilstm_pred, ara_pred, svm_probs, bilstm_probs, ara_probs):
    """Summarise how the three models agree on a tweet with no known label.

    Returns a dict with:
      'type':    'agree_high' | 'agree_low' | 'partial_disagree' | 'full_disagree'
      'message': an HTML snippet explaining the situation.

    For unanimous predictions, the AraBERT probability of the shared label
    decides 'agree_high' (>= 80%) versus 'agree_low'. svm_probs and
    bilstm_probs are accepted for interface symmetry but are not used.
    """
    def bold(value):
        return f"<b style='color:inherit;'>{value}</b>"

    votes = [("SVM", svm_pred), ("BiLSTM", bilstm_pred), ("AraBERT", ara_pred)]
    labels = [label for _, label in votes]
    distinct = len(set(labels))

    if distinct == 3:
        return {
            'type': 'full_disagree',
            'message': (
                f"All 3 models disagree — SVM: {bold(svm_pred)}, BiLSTM: {bold(bilstm_pred)}, "
                f"AraBERT: {bold(ara_pred)}. This tweet is inherently ambiguous, likely due to "
                "mixed communicative intent, sarcasm, or dialectal phrasing."
            ),
        }

    if distinct == 2:
        # Two of three agree: the majority label appears twice, the other once.
        majority = max(labels, key=labels.count)
        minority_model, minority_pred = next(
            (name, label) for name, label in votes if label != majority
        )
        explanation = ERROR_PATTERNS.get(
            (majority, minority_pred),
            f"The boundary between {majority} and {minority_pred} can be ambiguous in Arabic social media text.",
        )
        return {
            'type': 'partial_disagree',
            'message': (
                f"2 models agree on {bold(majority)} while {minority_model} predicts "
                f"{bold(minority_pred)}. {explanation}"
            ),
        }

    # Unanimous: judge certainty by AraBERT's probability for the shared label.
    shared = ara_pred
    ara_conf = round(ara_probs.get(shared, 0) * 100)
    if ara_conf < 80:
        return {
            'type': 'agree_low',
            'message': (
                f"All 3 models agree on {bold(shared)}, but with moderate confidence ({ara_conf}%). "
                "The tweet may have overlapping features with other classes."
            ),
        }
    return {
        'type': 'agree_high',
        'message': (
            f"All 3 models confidently agree: this tweet {bold(CLASS_DESCRIPTIONS.get(shared, ''))} "
            f"({ara_conf}% confidence by AraBERT). This is an unambiguous case."
        ),
    }
|
|
| |
def get_top_features(text, pred_class):
    """Return up to five vocabulary terms that pushed the SVM towards `pred_class`.

    Each term's contribution is its TF-IDF weight in `text` multiplied by the
    SVM coefficient for `pred_class`. Terms are ranked highest first, and only
    strictly positive contributions are kept.

    NOTE(review): assumes clf.coef_ has one dense row per class, in
    svm_model.classes_ order (true for one-vs-rest linear models such as
    LinearSVC; not for a one-vs-one SVC) — confirm for the trained estimator.
    """
    steps = svm_model.named_steps
    tfidf, linear_svm = steps['tfidf'], steps['svm']

    vocabulary = tfidf.get_feature_names_out()
    tfidf_weights = tfidf.transform([text]).toarray()[0]
    row = list(svm_model.classes_).index(pred_class)
    contributions = tfidf_weights * linear_svm.coef_[row]

    ranked = contributions.argsort()[-5:][::-1]
    return [vocabulary[i] for i in ranked if contributions[i] > 0]
|
|
| |
# Colour palette for the result panel. Colours are hard-coded (light theme),
# so the HTML reads the same under any Gradio theme.
_CARD_BG = "#ffffff"
_CARD_TEXT = "#111111"
_CARD_SUB = "#555555"
_CARD_CONF = "#333333"
_FEATURES_BG = "#d4edda"
_FEATURES_TEXT = "#155724"
_FEATURES_TITLE = "#0a4520"
_BAR_BG = "#dddddd"
_BAR_FILL = "#2563eb"
_BAR_LABEL = "#222222"
_BAR_PCT = "#444444"
_BREAKDOWN_TITLE = "#333333"
_CORRECT_COLOR = "#15803d"
_WRONG_COLOR = "#dc2626"

# (background, body text, title, left border) for each get_smart_analysis() type.
_SMART_STYLES = {
    'agree_high': ('#dcfce7', '#15803d', '#16a34a', '#16a34a'),
    'agree_low': ('#fef9c3', '#854d0e', '#a16207', '#a16207'),
    'partial_disagree': ('#ffedd5', '#9a3412', '#c2410c', '#c2410c'),
    'full_disagree': ('#fde8e8', '#991b1b', '#dc2626', '#dc2626'),
}
_SMART_FALLBACK_STYLE = ('#f5f5f5', '#333', '#555', '#888')

_PROMPT_HTML = "<p style='color:#555;font-family:sans-serif;'>Please enter an Arabic tweet.</p>"
_NOTHING_LEFT_HTML = (
    "<p style='color:#555;font-family:sans-serif;'>Nothing is left to classify after cleaning "
    "(links, mentions, hashtags, digits and symbols are removed). "
    "Please enter a tweet that contains words.</p>"
)


def _pct(probs, cls):
    """Return probs[cls] as a rounded integer percentage (0 when cls is missing)."""
    return round(probs.get(cls, 0) * 100)


def _verdict(pred, gt):
    """Return "Correct" / "Wrong — True: <gt>", or "" when no ground truth is known."""
    if gt is None:
        return ""
    return "Correct" if pred == gt else f"Wrong — True: {gt}"


def _model_card(title, title_color, border, pred, probs, verdict_text):
    """Render one model's prediction card (label, Arabic name, confidence, verdict)."""
    verdict_color = _CORRECT_COLOR if "Correct" in verdict_text else _WRONG_COLOR
    return f"""
    <div style='background:{_CARD_BG};border:{border};border-radius:10px;padding:14px;box-shadow:0 1px 4px rgba(0,0,0,0.08);'>
      <p style='font-size:11px;color:{title_color};margin:0 0 6px;font-weight:700;letter-spacing:0.3px;'>{title}</p>
      <p style='font-size:17px;font-weight:700;margin:0;color:{_CARD_TEXT};'>{pred}</p>
      <p style='font-size:11px;color:{_CARD_SUB};margin:2px 0 6px;'>{CLASS_AR.get(pred, '')}</p>
      <p style='font-size:12px;margin:0;color:{_CARD_CONF};'>{_pct(probs, pred)}% confidence</p>
      <p style='font-size:11px;margin:5px 0 0;font-weight:600;color:{verdict_color};'>{verdict_text}</p>
    </div>
"""


def _features_html(top_features):
    """Render the SVM's top contributing terms as pills (a dash when there are none)."""
    pills = " ".join(
        f"<span style='background:{_FEATURES_BG};color:{_FEATURES_TEXT};padding:3px 10px;"
        f"border-radius:20px;font-size:13px;font-weight:500;'>{word}</span>"
        for word in top_features
    )
    return pills or "<span style='color:#555;'>—</span>"


def _ground_truth_html(ground_truth, named_preds):
    """Explain each wrong model against the known test-set label, or report all correct."""
    errors = [
        f"<b style='color:#111111;'>{name}:</b> "
        f"<span style='color:#111111;'>{get_error_analysis(ground_truth, pred)}</span>"
        for name, pred in named_preds
        if pred != ground_truth
    ]
    if not errors:
        return f"""
  <div style='margin-top:16px;padding:14px;background:#dcfce7;border-left:4px solid #16a34a;border-radius:8px;'>
    <b style='color:#15803d;font-size:13px;'>All models correct</b>
    <span style='color:#111111;font-size:13px;'> — ground truth: <b style='color:#111111;'>{ground_truth}</b>. Unambiguous tweet, all 3 models agree.</span>
  </div>"""
    error_html = "<br>".join(errors)
    return f"""
  <div style='margin-top:16px;padding:14px;background:#fde8e8;border-left:4px solid #dc2626;border-radius:8px;'>
    <b style='color:#991b1b;font-size:13px;'>Error analysis — ground truth: <span style='color:#111111;'>{ground_truth}</span></b><br>
    <span style='font-size:13px;color:#111111;line-height:1.6;'>{error_html}</span>
  </div>"""


def _smart_html(svm_pred, bilstm_pred, ara_pred, svm_probs, bilstm_probs, ara_probs):
    """Render the agreement-based analysis used when the tweet has no known label."""
    smart = get_smart_analysis(svm_pred, bilstm_pred, ara_pred, svm_probs, bilstm_probs, ara_probs)
    bg, text_col, title_col, border_col = _SMART_STYLES.get(smart['type'], _SMART_FALLBACK_STYLE)
    return f"""
  <div style='margin-top:16px;padding:14px;background:{bg};border-left:4px solid {border_col};border-radius:8px;'>
    <b style='color:{title_col};font-size:13px;'>Model analysis — new tweet</b><br>
    <span style='font-size:13px;color:{text_col};line-height:1.6;'>{smart['message']}</span>
  </div>"""


def _confidence_bars(probs):
    """Render one horizontal bar per class in CLASSES for the given probabilities."""
    rows = []
    for cls in CLASSES:
        pct = _pct(probs, cls)
        rows.append(f"""
    <div style='display:flex;align-items:center;gap:10px;margin-bottom:7px;'>
      <span style='font-size:12px;width:130px;color:{_BAR_LABEL};font-weight:500;'>{cls}</span>
      <div style='flex:1;background:{_BAR_BG};border-radius:4px;height:7px;'>
        <div style='background:{_BAR_FILL};width:{pct}%;height:7px;border-radius:4px;'></div>
      </div>
      <span style='font-size:12px;width:36px;text-align:right;color:{_BAR_PCT};font-weight:600;'>{pct}%</span>
    </div>""")
    return "".join(rows)


def classify(text):
    """Gradio callback: classify one tweet with all three models and render HTML.

    Args:
        text: raw tweet text from the textbox (may be None or blank).

    Returns:
        An HTML string. Blank input, or input that is empty after clean_text(),
        gets a short notice instead of predictions. Otherwise the HTML holds one
        card per model, the SVM's top terms, AraBERT's per-class confidence
        bars, and either a ground-truth error analysis (tweet found in the test
        set) or an agreement-based analysis.
    """
    if not text or not text.strip():
        return _PROMPT_HTML

    cleaned = clean_text(text)
    if not cleaned:
        # Only links/mentions/hashtags/digits/symbols: nothing the models can use.
        return _NOTHING_LEFT_HTML

    svm_pred, svm_probs = predict_svm(cleaned)
    bilstm_pred, bilstm_probs = predict_bilstm(cleaned)
    ara_pred, ara_probs = predict_arabert(cleaned)
    ground_truth = get_ground_truth(cleaned)
    top_features = get_top_features(cleaned, svm_pred)

    named_preds = [("SVM", svm_pred), ("BiLSTM", bilstm_pred), ("AraBERT", ara_pred)]
    # Same "label known" test as _verdict(), so the cards and the analysis
    # panel always agree on whether a ground truth exists.
    if ground_truth is not None:
        analysis_section = _ground_truth_html(ground_truth, named_preds)
    else:
        analysis_section = _smart_html(svm_pred, bilstm_pred, ara_pred,
                                       svm_probs, bilstm_probs, ara_probs)

    cards = "".join([
        _model_card("SVM + TF-IDF", "#059669", "1.5px solid #6ee7b7",
                    svm_pred, svm_probs, _verdict(svm_pred, ground_truth)),
        _model_card("BiLSTM", "#4f46e5", "2px solid #a5b4fc",
                    bilstm_pred, bilstm_probs, _verdict(bilstm_pred, ground_truth)),
        _model_card("AraBERT", "#1d4ed8", "1.5px solid #93c5fd",
                    ara_pred, ara_probs, _verdict(ara_pred, ground_truth)),
    ])

    return f"""
<div style='font-family:sans-serif;max-width:660px;'>

  <div style='display:grid;grid-template-columns:repeat(3,1fr);gap:10px;margin-bottom:16px;'>
{cards}
  </div>

  <div style='padding:14px;background:{_FEATURES_BG};border-radius:8px;margin-bottom:14px;'>
    <p style='font-size:12px;color:{_FEATURES_TITLE};margin:0 0 8px;font-weight:700;'>Top signals </p>
    {_features_html(top_features)}
  </div>

  <div style='margin-bottom:14px;background:#ffffff;padding:12px;border-radius:8px;box-shadow:0 1px 4px rgba(0,0,0,0.06);'>
    <p style='font-size:12px;color:{_BREAKDOWN_TITLE};margin:0 0 10px;font-weight:700;'>All classes — confidence breakdown </p>
    {_confidence_bars(ara_probs)}
  </div>

  {analysis_section}
</div>
"""
|
|
| |
# Gradio UI: one RTL textbox in, the rendered HTML report from classify() out.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(
        label="Enter an Arabic tweet",
        placeholder="مثال: الـ AraBERT فهم التغريدات العربية احسن مني انا ههههههههه",
        rtl=True,
        lines=3,
    ),
    outputs=gr.HTML(label="Results"),
    title="Arabic Dialogue/Speech Act Classifier",
    description="AI 445 — NLP Project | Jordan University of Science and Technology",
    # Sample tweets; an empty example row is intentionally not included
    # (it would only ever show the "Please enter" notice).
    examples=[
        ["الأكل كان رائع جداً!"],
        ["مين المسؤول ان الصوت بيقطع و مش ماشي مع كلام الرئيس اودام العالم كله"],
        ["رئيس الجمهوريه التونسيه حاضرا مباراه بلاده في تصفيات كاس العالم"],
        ["المشروع سينتهي غداً"],
        ["ليش ال Recommendation غالبا بفشل؟ لانه مثلي ماحدا بسمعه"],
        ["ماذا قال محمد صلاح عن اداء وتاهل تونس والمغرب الي المونديال"],
        ["عندي اقتراح للشيخ عزمي بشاره بما ان رايه صائب الي هذه الدرجه ان يجلس مع الشيخ تميم ويوضعوا خطه محكمه لاعاده فلسطين او اعاده القدس ويتركوا الربيع العربي مؤقتا"],
        ["رياضه محمد صلاح ينافس نجوم علي جائزه BBC للافضل في افريقيا"],
    ],
    flagging_mode="never",
)

# Only start the server when run as a script (`python app.py`), not when the
# module is imported, e.g. by tests or by tooling that looks up `demo`.
if __name__ == "__main__":
    demo.launch()
|
|