| import os |
| import re |
| import requests |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import gradio as gr |
| from transformers import AutoTokenizer, AutoModel |
| from tqdm import tqdm |
|
|
| |
| |
| |
# --- Configuration ----------------------------------------------------------
# Fine-tuned multi-task checkpoint hosted on the Hugging Face Hub.
MODEL_URL = "https://huggingface.co/datasets/mahmoudmohammad/Propaganda_Detection/resolve/main/paper_arch_asl_uw_marbertv2_raw-data.bin"
MODEL_FILENAME = "paper_arch_asl_uw_marbertv2_raw-data.bin"
# Backbone encoder the checkpoint was fine-tuned from.
MODEL_NAME = "UBC-NLP/MARBERTv2"
MAX_LEN = 256          # tokenizer truncation / padding length
TASK_EMBED_DIM = 128   # size of the learned task-ID embedding
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# 20 propaganda-technique labels for the multi-label propaganda head (task 0).
# NOTE(review): index order must match the label order used at training time —
# verify against the training script before editing.
PROP_CLASSES = [
    'Appeal to authority', 'Appeal to fear/prejudice', 'Appeal to time',
    'Bandwagon', 'Black-and-white Fallacy/Dictatorship',
    'Causal Oversimplification', 'Doubt', 'Exaggeration/Minimisation',
    'Flag-waving', 'Glittering generalities (Virtue)', 'Loaded Language',
    "Misrepresentation of Someone's Position (Straw Man)",
    'Name calling/Labeling', 'Obfuscation, Intentional vagueness, Confusion',
    'Presenting Irrelevant Data (Red Herring)', 'Repetition', 'Slogans',
    'Smears', 'Thought-terminating cliché', 'Whataboutism'
]

# 20 emotion labels for the multi-label emotion head (task 1);
# same ordering caveat as above.
EMO_CLASSES = [
    'anger', 'annoyance', 'anticipation', 'anxiety', 'confusion', 'denial',
    'disgust', 'empathy', 'fear', 'gratitude', 'humor', 'joy', 'love',
    'neutral', 'optimism', 'pessimism', 'sadness', 'surprise',
    'sympathy', 'trust'
]
|
|
| |
| |
| |
def download_model_if_missing():
    """Fetch the fine-tuned checkpoint from the Hub if it is not on disk.

    Streams the download into a temporary ``.part`` file and renames it on
    success, so a partially-downloaded file can never be mistaken for a
    complete checkpoint on a later run.

    Raises:
        requests.RequestException: if the HTTP download fails (re-raised
            after the partial file has been removed).
    """
    if os.path.exists(MODEL_FILENAME):
        print(f"✅ Model file '{MODEL_FILENAME}' already exists.")
        return

    print("📥 Model file not found. Downloading from Hugging Face...")
    print(f"   URL: {MODEL_URL}")

    tmp_path = MODEL_FILENAME + ".part"
    try:
        # timeout=(connect, read) so a dead server cannot hang the app forever.
        response = requests.get(MODEL_URL, stream=True, timeout=(10, 60))
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024

        with open(tmp_path, "wb") as file, tqdm(
            desc=MODEL_FILENAME,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(block_size):
                bar.update(file.write(data))

        # Atomic rename: MODEL_FILENAME only exists once the download finished.
        os.replace(tmp_path, MODEL_FILENAME)
        print("\n✅ Download complete.")

    except Exception as e:
        # Remove the partial file so the next attempt starts clean.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        print(f"\n❌ Failed to download model: {e}")
        raise
|
|
| |
| |
| |
def preprocess_text(text):
    """Normalise raw (Arabic) text before tokenization.

    Steps: replace URLs / @-mentions with placeholder tokens, drop Latin
    letters, unify alef variants, map teh-marbuta -> heh and alef-maksura
    -> yeh, and strip Arabic diacritics (tashkeel).

    Args:
        text: arbitrary input; non-string values yield "".

    Returns:
        The cleaned, whitespace-stripped string.
    """
    if not isinstance(text, str):
        return ""
    # Bug fix: insert private-use sentinels first — the Latin-letter strip
    # below used to reduce the freshly inserted "[URL]"/"[USER]" to "[]".
    text = re.sub(r'http\S+|www\S+', '\ue000', text)
    text = re.sub(r'@\w+', '\ue001', text)
    text = re.sub(r'[a-zA-Z]', '', text)                       # drop Latin script
    text = text.replace('\ue000', '[URL]').replace('\ue001', '[USER]')
    text = re.sub("[إأآ]", "ا", text)                           # unify alef variants
    text = re.sub("ة", "ه", text)                               # teh marbuta -> heh
    text = re.sub("ى", "ي", text)                               # alef maksura -> yeh
    text = re.sub(r'[\u0617-\u061A\u064B-\u0652]', '', text)   # strip tashkeel
    return text.strip()
|
|
| |
| |
| |
| class AlHenakiMTLModel(nn.Module): |
| def __init__(self, n_propaganda, n_emotion): |
| super(AlHenakiMTLModel, self).__init__() |
| self.arabert = AutoModel.from_pretrained(MODEL_NAME) |
| self.hidden_size = 768 |
| self.task_embedding = nn.Embedding(num_embeddings=2, embedding_dim=TASK_EMBED_DIM) |
| self.head_input_dim = self.hidden_size + TASK_EMBED_DIM |
| self.prop_head = nn.Linear(self.head_input_dim, n_propaganda) |
| self.emo_head = nn.Linear(self.head_input_dim, n_emotion) |
| |
| self.log_sigma_prop = nn.Parameter(torch.zeros(1)) |
| self.log_sigma_emo = nn.Parameter(torch.zeros(1)) |
|
|
| def forward(self, input_ids, attention_mask, task_ids): |
| outputs = self.arabert(input_ids, attention_mask=attention_mask) |
| pooled_output = outputs.pooler_output |
| t_embed = self.task_embedding(task_ids) |
| z = torch.cat((pooled_output, t_embed), dim=1) |
| |
| current_task = task_ids[0].item() |
| if current_task == 0: |
| return self.prop_head(z), self.log_sigma_prop |
| elif current_task == 1: |
| return self.emo_head(z), self.log_sigma_emo |
| else: |
| raise ValueError("Unknown Task ID") |
|
|
| |
| |
| |
| |
# ---------------------------------------------------------------------------
# Module-level startup: fetch checkpoint, build tokenizer + model shell, then
# load the fine-tuned weights.
# NOTE(review): this runs at import time with network and disk side effects.
# ---------------------------------------------------------------------------
download_model_if_missing()


# Tokenizer plus an untrained model shell; weights are loaded just below.
print("⏳ Loading Tokenizer & Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AlHenakiMTLModel(len(PROP_CLASSES), len(EMO_CLASSES))


# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — only load checkpoints from a trusted source.
try:
    state_dict = torch.load(MODEL_FILENAME, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    print("✅ Model loaded successfully on", DEVICE)
except Exception as e:
    # NOTE(review): on a failed load the app keeps serving with randomly
    # initialised weights — consider re-raising here instead.
    print(f"❌ Critical Error Loading Model: {e}")
|
|
| |
| |
| |
def predict_fn(text, threshold):
    """Run both task heads on one input text and filter labels by threshold.

    Args:
        text: raw user input (Arabic expected).
        threshold: minimum sigmoid probability for a label to be reported.

    Returns:
        Tuple ``(propaganda_scores, emotion_scores, status_message)`` where
        the score dicts map label -> probability for labels above threshold.
    """
    clean_text = preprocess_text(text)
    if not clean_text.strip():
        return {}, {}, "Please enter Arabic text."

    inputs = tokenizer(
        clean_text,
        return_tensors="pt",
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True
    ).to(DEVICE)

    input_ids = inputs['input_ids']
    attn_mask = inputs['attention_mask']
    batch_size = input_ids.shape[0]

    def _run_task(task_id):
        # One forward pass per task; the model dispatches on task_ids[0].
        task_ids = torch.full((batch_size,), task_id, dtype=torch.long, device=DEVICE)
        logits, _ = model(input_ids, attn_mask, task_ids)
        return torch.sigmoid(logits).cpu().numpy()[0]

    with torch.no_grad():
        probs_p = _run_task(0)  # propaganda head
        probs_e = _run_task(1)  # emotion head

    # Pair labels with probabilities directly instead of indexing by range().
    prop_results = {
        label: float(p) for label, p in zip(PROP_CLASSES, probs_p) if p > threshold
    }
    emo_results = {
        label: float(p) for label, p in zip(EMO_CLASSES, probs_e) if p > threshold
    }

    return prop_results, emo_results, f"Processed: {len(clean_text)} chars"
|
|
| |
| |
| |
# --- Gradio UI --------------------------------------------------------------
custom_css = """
.container { max-width: 900px; margin: auto; padding-top: 20px; }
"""


# Two-column layout: input + controls on the left, label outputs on the right.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AraProp Detector", js="() => document.body.classList.add('dark')") as demo:
    # Fixed: header previously said "AraBERT-v02", but the backbone is
    # MARBERTv2 (see MODEL_NAME and the footer below).
    gr.Markdown(
        """
        # 🕵️♂️ Multi-Task Arabic Propaganda & Emotion Detector
        ### Based on MARBERTv2 | SOTA Reproduction
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=5,
                placeholder="أدخل النص هنا للتحليل...",
                label="Input Arabic Text",
                value="يا له من عار! هذا السياسي يدمر البلاد بخططه الشيطانية الفاشلة."
            )

            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.4,
                step=0.05,
                label="Confidence Threshold (Sensitivity)"
            )

            run_btn = gr.Button("Analyze Text 🚀", variant="primary")
            status_box = gr.Markdown("Ready...")

        with gr.Column(scale=1):
            gr.Markdown("### 📊 Detection Results")

            # gr.Label renders the returned {label: prob} dicts as bar charts.
            out_prop = gr.Label(num_top_classes=8, label="Propaganda Techniques")
            out_emo = gr.Label(num_top_classes=8, label="Underlying Emotions")

    run_btn.click(
        fn=predict_fn,
        inputs=[input_text, threshold_slider],
        outputs=[out_prop, out_emo, status_box]
    )

    gr.Markdown("---")
    gr.Markdown(f"Running on: {DEVICE} | Model: {MODEL_NAME}")
|
|
| |
| |
| |
if __name__ == "__main__":
    # share=True exposes a temporary public tunnel URL; show_error surfaces
    # Python tracebacks in the browser. Disable both for private deployments.
    demo.launch(share=True, show_error=True)