# app.py — Gradio demo: multi-task Arabic propaganda & emotion detection.
import os
import re
import requests
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm # Just for download progress bar
# ==========================================
# 1. CONFIGURATION
# ==========================================
# Remote checkpoint with the fine-tuned multi-task weights (HF dataset repo).
MODEL_URL = "https://huggingface.co/datasets/mahmoudmohammad/Propaganda_Detection/resolve/main/paper_arch_asl_uw_marbertv2_raw-data.bin"
# Local cache name for the downloaded checkpoint.
MODEL_FILENAME = "paper_arch_asl_uw_marbertv2_raw-data.bin"
# Backbone encoder; must match the architecture the checkpoint was trained on.
MODEL_NAME = "UBC-NLP/MARBERTv2"
MAX_LEN = 256          # tokenizer padding/truncation length
TASK_EMBED_DIM = 128   # size of the learned task-id embedding
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Classes (Hardcoded as per your dataset) ---
# Propaganda-technique labels; order must match the trained prop_head outputs.
PROP_CLASSES = [
    'Appeal to authority', 'Appeal to fear/prejudice', 'Appeal to time',
    'Bandwagon', 'Black-and-white Fallacy/Dictatorship',
    'Causal Oversimplification', 'Doubt', 'Exaggeration/Minimisation',
    'Flag-waving', 'Glittering generalities (Virtue)', 'Loaded Language',
    "Misrepresentation of Someone's Position (Straw Man)",
    'Name calling/Labeling', 'Obfuscation, Intentional vagueness, Confusion',
    'Presenting Irrelevant Data (Red Herring)', 'Repetition', 'Slogans',
    'Smears', 'Thought-terminating cliché', 'Whataboutism'
]
# Emotion labels; order must match the trained emo_head outputs.
EMO_CLASSES = [
    'anger', 'annoyance', 'anticipation', 'anxiety', 'confusion', 'denial',
    'disgust', 'empathy', 'fear', 'gratitude', 'humor', 'joy', 'love',
    'neutral', 'optimism', 'pessimism', 'sadness', 'surprise',
    'sympathy', 'trust'
]
# ==========================================
# 2. HELPER: DOWNLOADER
# ==========================================
def download_model_if_missing():
    """Ensure the fine-tuned checkpoint exists locally, downloading it if needed.

    Streams the file from MODEL_URL into a temporary ``.part`` file and only
    renames it to MODEL_FILENAME once the transfer finishes, so an
    interrupted download is never mistaken for a valid checkpoint on the
    next run (the original wrote directly to MODEL_FILENAME, which the
    existence check would then wrongly accept).

    Raises:
        requests.RequestException / OSError: propagated after logging when
            the download or write fails.
    """
    if os.path.exists(MODEL_FILENAME):
        print(f"✅ Model file '{MODEL_FILENAME}' already exists.")
        return

    print("📥 Model file not found. Downloading from Hugging Face...")
    print(f"   URL: {MODEL_URL}")
    tmp_path = MODEL_FILENAME + ".part"
    try:
        # Stream in chunks so the multi-hundred-MB file never sits in RAM;
        # timeout=(connect, read) keeps a dead connection from hanging forever.
        with requests.get(MODEL_URL, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()  # Check for HTTP error
            total_size = int(response.headers.get('content-length', 0))
            block_size = 1024  # 1 KiB
            with open(tmp_path, "wb") as file, tqdm(
                desc=MODEL_FILENAME,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(block_size):
                    bar.update(file.write(data))
        # Atomic publish: only a fully-written file gets the final name.
        os.replace(tmp_path, MODEL_FILENAME)
        print("\n✅ Download complete.")
    except Exception as e:
        print(f"\n❌ Failed to download model: {e}")
        raise
# ==========================================
# 3. PREPROCESSING
# ==========================================
def preprocess_text(text):
    """Normalize raw input text for the MARBERTv2 tokenizer.

    Masks URLs and @-mentions with placeholder tokens, removes Latin
    letters, unifies common Arabic orthographic variants (alef forms,
    taa marbuta, alef maqsura) and strips diacritics, then trims
    surrounding whitespace. Non-string input yields "".

    NOTE(review): the Latin-letter pass also guts the '[URL]'/'[USER]'
    placeholders inserted just above (they become '[]') — confirm this
    matches the preprocessing used at training time before changing it.
    """
    if not isinstance(text, str):
        return ""
    substitutions = (
        (r'http\S+|www\S+', '[URL]'),               # mask links
        (r'@\w+', '[USER]'),                        # mask user mentions
        (r'[a-zA-Z]', ''),                          # drop Latin letters
        ("[إأآ]", "ا"),                             # unify alef variants
        ("ة", "ه"),                                 # taa marbuta -> haa
        ("ى", "ي"),                                 # alef maqsura -> yaa
        (r'[\u0617-\u061A\u064B-\u0652]', ''),      # strip diacritics
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()
# ==========================================
# 4. MODEL ARCHITECTURE
# ==========================================
class AlHenakiMTLModel(nn.Module):
    """Multi-task classifier: shared MARBERTv2 encoder + task-conditioned heads.

    The pooled sentence vector is concatenated with a learned task
    embedding (task id 0 = propaganda, 1 = emotion) and routed to the
    matching linear head. Attribute names must stay as-is — they are the
    keys of the trained checkpoint's state_dict.
    """
    def __init__(self, n_propaganda, n_emotion):
        super(AlHenakiMTLModel, self).__init__()
        # Shared pretrained encoder.
        self.arabert = AutoModel.from_pretrained(MODEL_NAME)
        self.hidden_size = 768  # base-model hidden size — TODO confirm against config
        # One embedding row per task id (0: propaganda, 1: emotion).
        self.task_embedding = nn.Embedding(num_embeddings=2, embedding_dim=TASK_EMBED_DIM)
        self.head_input_dim = self.hidden_size + TASK_EMBED_DIM
        self.prop_head = nn.Linear(self.head_input_dim, n_propaganda)
        self.emo_head = nn.Linear(self.head_input_dim, n_emotion)
        # Weights used during training loss calc, kept here for structure compatibility
        # with the saved state_dict; unused at inference.
        self.log_sigma_prop = nn.Parameter(torch.zeros(1))
        self.log_sigma_emo = nn.Parameter(torch.zeros(1))
    def forward(self, input_ids, attention_mask, task_ids):
        """Return (logits, log_sigma) for the task selected by task_ids.

        Assumes every element of task_ids carries the SAME task id —
        only task_ids[0] is consulted when choosing the head.
        """
        outputs = self.arabert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output  # pooled [CLS] representation
        t_embed = self.task_embedding(task_ids)
        # Condition the classifier input on the task identity.
        z = torch.cat((pooled_output, t_embed), dim=1)
        current_task = task_ids[0].item()
        if current_task == 0:
            return self.prop_head(z), self.log_sigma_prop
        elif current_task == 1:
            return self.emo_head(z), self.log_sigma_emo
        else:
            raise ValueError("Unknown Task ID")
# ==========================================
# 5. INITIALIZE GLOBALS
# ==========================================
# 1. Download
# Fetch the checkpoint once; later runs reuse the cached local file.
download_model_if_missing()
# 2. Load Components
print("⏳ Loading Tokenizer & Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AlHenakiMTLModel(len(PROP_CLASSES), len(EMO_CLASSES))
# 3. Load Weights
try:
    state_dict = torch.load(MODEL_FILENAME, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()  # inference mode (disables dropout)
    print("✅ Model loaded successfully on", DEVICE)
except Exception as e:
    # NOTE(review): on failure the app keeps serving with untrained weights —
    # consider re-raising here so the process fails fast instead.
    print(f"❌ Critical Error Loading Model: {e}")
# ==========================================
# 6. INFERENCE LOGIC
# ==========================================
def predict_fn(text, threshold):
    """Run both MTL heads on *text* and format results for Gradio.

    Returns a 3-tuple (propaganda_scores, emotion_scores, status_message),
    where each score dict maps label -> probability and keeps only labels
    whose sigmoid probability exceeds *threshold*. Empty/non-Arabic input
    yields two empty dicts and a prompt message.
    """
    cleaned = preprocess_text(text)
    if not cleaned.strip():
        return {}, {}, "Please enter Arabic text."

    encoded = tokenizer(
        cleaned,
        return_tensors="pt",
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True
    ).to(DEVICE)
    ids = encoded['input_ids']
    mask = encoded['attention_mask']
    batch = ids.shape[0]

    def _head_probs(task_id):
        # One forward pass through the shared encoder, routed to one head.
        task_ids = torch.tensor([task_id] * batch, dtype=torch.long).to(DEVICE)
        logits, _ = model(ids, mask, task_ids)
        return torch.sigmoid(logits).cpu().numpy()[0]

    with torch.no_grad():
        probs_prop = _head_probs(0)  # propaganda task
        probs_emo = _head_probs(1)   # emotion task

    # Gradio Label wants {label: native-float score}; drop sub-threshold labels.
    prop_results = {
        label: float(score)
        for label, score in zip(PROP_CLASSES, probs_prop)
        if score > threshold
    }
    emo_results = {
        label: float(score)
        for label, score in zip(EMO_CLASSES, probs_emo)
        if score > threshold
    }
    return prop_results, emo_results, f"Processed: {len(cleaned)} chars"
# ==========================================
# 7. MODERN UI (GRADIO)
# ==========================================
# Extra CSS layered on top of the Soft theme.
custom_css = """
.container { max-width: 900px; margin: auto; padding-top: 20px; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AraProp Detector", js="() => document.body.classList.add('dark')") as demo:
    gr.Markdown(
        """
        # 🕵️‍♂️ Multi-Task Arabic Propaganda & Emotion Detector
        ### Based on MARBERTv2 | SOTA Reproduction
        """
    )
    # Fixed: header previously claimed "AraBERT-v02" while MODEL_NAME loads
    # UBC-NLP/MARBERTv2 — the footer below already reported the real backbone.
    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=5,
                placeholder="أدخل النص هنا للتحليل...",
                label="Input Arabic Text",
                value="يا له من عار! هذا السياسي يدمر البلاد بخططه الشيطانية الفاشلة."
            )
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.4,
                step=0.05,
                label="Confidence Threshold (Sensitivity)"
            )
            run_btn = gr.Button("Analyze Text 🚀", variant="primary")
            status_box = gr.Markdown("Ready...")
        with gr.Column(scale=1):
            gr.Markdown("### 📊 Detection Results")
            # 'Label' components render {label: score} dicts as progress bars.
            out_prop = gr.Label(num_top_classes=8, label="Propaganda Techniques")
            out_emo = gr.Label(num_top_classes=8, label="Underlying Emotions")
    # Wire the button to the inference function.
    run_btn.click(
        fn=predict_fn,
        inputs=[input_text, threshold_slider],
        outputs=[out_prop, out_emo, status_box]
    )
    gr.Markdown("---")
    gr.Markdown(f"Running on: {DEVICE} | Model: {MODEL_NAME}")
# ==========================================
# 8. LAUNCH
# ==========================================
if __name__ == "__main__":
    # share=True creates a public link for RunPod/Colab;
    # show_error surfaces Python tracebacks in the web UI.
    demo.launch(share=True, show_error=True)