mahmoudmohammad committed on
Commit
24e4add
·
verified ·
1 Parent(s): bc4b957

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +242 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import requests
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ import gradio as gr
8
+ from transformers import AutoTokenizer, AutoModel
9
+ from tqdm import tqdm # Just for download progress bar
10
+
11
# ==========================================
# 1. CONFIGURATION
# ==========================================
# Direct download URL for the fine-tuned multi-task checkpoint (hosted as a
# dataset file on the Hub, fetched manually rather than via from_pretrained).
MODEL_URL = "https://huggingface.co/datasets/mahmoudmohammad/Propaganda_Detection/resolve/main/paper_arch_asl_uw_marbertv2_raw-data.bin"
# Local filename the checkpoint is saved under / loaded from.
MODEL_FILENAME = "paper_arch_asl_uw_marbertv2_raw-data.bin"
# Base encoder: Arabic BERT variant used as the shared backbone.
MODEL_NAME = "UBC-NLP/MARBERTv2"
# Tokenizer truncation/padding length.
MAX_LEN = 256
# Width of the learned task-identity embedding concatenated to the pooled output.
TASK_EMBED_DIM = 128
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Classes (Hardcoded as per your dataset) ---
# NOTE: order must match the label order used at training time — the heads'
# output indices are interpreted positionally against these lists.
PROP_CLASSES = [
    'Appeal to authority', 'Appeal to fear/prejudice', 'Appeal to time',
    'Bandwagon', 'Black-and-white Fallacy/Dictatorship',
    'Causal Oversimplification', 'Doubt', 'Exaggeration/Minimisation',
    'Flag-waving', 'Glittering generalities (Virtue)', 'Loaded Language',
    "Misrepresentation of Someone's Position (Straw Man)",
    'Name calling/Labeling', 'Obfuscation, Intentional vagueness, Confusion',
    'Presenting Irrelevant Data (Red Herring)', 'Repetition', 'Slogans',
    'Smears', 'Thought-terminating cliché', 'Whataboutism'
]

# Same positional contract as PROP_CLASSES, for the emotion head.
EMO_CLASSES = [
    'anger', 'annoyance', 'anticipation', 'anxiety', 'confusion', 'denial',
    'disgust', 'empathy', 'fear', 'gratitude', 'humor', 'joy', 'love',
    'neutral', 'optimism', 'pessimism', 'sadness', 'surprise',
    'sympathy', 'trust'
]
39
+
40
+ # ==========================================
41
+ # 2. HELPER: DOWNLOADER
42
+ # ==========================================
43
def download_model_if_missing():
    """Download the fine-tuned checkpoint from the Hub if not present locally.

    Streams the file to a temporary ``.part`` name and renames it into place
    only on success. This fixes a latent bug in the original: an interrupted
    download left a partial file under MODEL_FILENAME, which then passed the
    existence check on the next run and was loaded as if complete.

    Raises:
        requests.RequestException: if the HTTP request fails.
    """
    if os.path.exists(MODEL_FILENAME):
        print(f"✅ Model file '{MODEL_FILENAME}' already exists.")
        return

    print("📥 Model file not found. Downloading from Hugging Face...")
    print(f"   URL: {MODEL_URL}")

    tmp_path = MODEL_FILENAME + ".part"
    try:
        # Timeout so a stalled connection cannot hang the Space at startup.
        response = requests.get(MODEL_URL, stream=True, timeout=(10, 60))
        response.raise_for_status()  # Check for error

        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024  # 1 Kilobyte

        with open(tmp_path, "wb") as file, tqdm(
            desc=MODEL_FILENAME,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(block_size):
                bar.update(file.write(data))

        # Atomic rename: only a fully written file ever gets the real name.
        os.replace(tmp_path, MODEL_FILENAME)
        print("\n✅ Download complete.")

    except Exception as e:
        # Remove the partial file so the next run retries a clean download.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        print(f"\n❌ Failed to download model: {e}")
        raise
72
+
73
+ # ==========================================
74
+ # 3. PREPROCESSING
75
+ # ==========================================
76
def preprocess_text(text):
    """Normalize raw (Arabic) text before tokenization.

    Replaces URLs and @-mentions with placeholder tokens, strips stray Latin
    letters, and applies standard Arabic orthographic normalization
    (alef variants → ا, teh marbuta → ه, alef maqsura → ي, diacritic removal).

    Returns "" for non-string input.

    Fix: the original stripped Latin letters AFTER inserting the '[URL]' /
    '[USER]' placeholders, reducing them to '[]'. Letter-free sentinels are
    used here so the placeholders survive intact.
    NOTE(review): inference preprocessing should match the training pipeline —
    confirm the training data actually contained intact placeholder tokens.
    """
    if not isinstance(text, str):
        return ""
    # Mark URLs/mentions with letter-free sentinels so the Latin-letter strip
    # below cannot mangle the placeholder tokens.
    text = re.sub(r'http\S+|www\S+', '\x00', text)
    text = re.sub(r'@\w+', '\x01', text)
    # Remove stray Latin letters (the task is Arabic-only).
    text = re.sub(r'[a-zA-Z]', '', text)
    # Restore the placeholder tokens.
    text = text.replace('\x00', '[URL]').replace('\x01', '[USER]')
    # Arabic normalization.
    text = re.sub("[إأآ]", "ا", text)
    text = re.sub("ة", "ه", text)
    text = re.sub("ى", "ي", text)
    # Strip tashkeel / diacritic marks.
    text = re.sub(r'[\u0617-\u061A\u064B-\u0652]', '', text)
    return text.strip()
86
+
87
+ # ==========================================
88
+ # 4. MODEL ARCHITECTURE
89
+ # ==========================================
90
class AlHenakiMTLModel(nn.Module):
    """Multi-task classifier: shared MARBERTv2 encoder + task embedding,
    with separate linear heads for propaganda techniques and emotions.

    Task routing is decided from ``task_ids[0]``, so every example in a
    batch must belong to the same task.
    """

    def __init__(self, n_propaganda, n_emotion):
        """Args:
            n_propaganda: number of propaganda-technique output classes.
            n_emotion: number of emotion output classes.
        """
        super(AlHenakiMTLModel, self).__init__()
        self.arabert = AutoModel.from_pretrained(MODEL_NAME)
        # Read the width from the encoder config instead of hardcoding 768,
        # so a differently-sized backbone still works (768 for MARBERTv2,
        # hence checkpoint-compatible).
        self.hidden_size = self.arabert.config.hidden_size
        self.task_embedding = nn.Embedding(num_embeddings=2, embedding_dim=TASK_EMBED_DIM)
        self.head_input_dim = self.hidden_size + TASK_EMBED_DIM
        self.prop_head = nn.Linear(self.head_input_dim, n_propaganda)
        self.emo_head = nn.Linear(self.head_input_dim, n_emotion)
        # Weights used during training loss calc, kept here for structure compatibility
        self.log_sigma_prop = nn.Parameter(torch.zeros(1))
        self.log_sigma_emo = nn.Parameter(torch.zeros(1))

    def forward(self, input_ids, attention_mask, task_ids):
        """Encode the batch and apply the head selected by ``task_ids[0]``.

        Args:
            input_ids, attention_mask: standard BERT-style inputs.
            task_ids: LongTensor of shape (batch,); 0 = propaganda, 1 = emotion.

        Returns:
            (logits, log_sigma) for the selected task.

        Raises:
            ValueError: if the task id is neither 0 nor 1.
        """
        outputs = self.arabert(input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        t_embed = self.task_embedding(task_ids)
        # Concatenate pooled [CLS] representation with the task embedding.
        z = torch.cat((pooled_output, t_embed), dim=1)

        current_task = task_ids[0].item()
        if current_task == 0:
            return self.prop_head(z), self.log_sigma_prop
        elif current_task == 1:
            return self.emo_head(z), self.log_sigma_emo
        else:
            raise ValueError("Unknown Task ID")
116
+
117
# ==========================================
# 5. INITIALIZE GLOBALS
# ==========================================
# Module-level startup: runs once at import time (standard Gradio Space
# pattern) — downloads the checkpoint, builds tokenizer + model, loads weights.

# 1. Download
download_model_if_missing()

# 2. Load Components
print("⏳ Loading Tokenizer & Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AlHenakiMTLModel(len(PROP_CLASSES), len(EMO_CLASSES))

# 3. Load Weights
try:
    # NOTE(review): torch.load unpickles arbitrary code from the file — fine
    # for this trusted checkpoint, but consider weights_only=True if the
    # installed torch version supports it.
    state_dict = torch.load(MODEL_FILENAME, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    print("✅ Model loaded successfully on", DEVICE)
except Exception as e:
    # NOTE(review): a load failure is only printed and the app keeps running
    # with randomly initialized weights, so predictions would be meaningless.
    # Consider re-raising here so the Space fails fast — confirm intent.
    print(f"❌ Critical Error Loading Model: {e}")
137
+
138
+ # ==========================================
139
+ # 6. INFERENCE LOGIC
140
+ # ==========================================
141
def predict_fn(text, threshold):
    """Run both task heads on a single input text.

    Args:
        text: raw user-supplied Arabic text.
        threshold: minimum sigmoid score for a class to appear in the output.

    Returns:
        (propaganda_scores, emotion_scores, status_message) — each scores
        value is a {class_name: probability} dict suitable for gr.Label.
    """
    clean_text = preprocess_text(text)

    # Nothing left after cleaning → prompt the user instead of running the model.
    if not clean_text.strip():
        return {}, {}, "Please enter Arabic text."

    encoded = tokenizer(
        clean_text,
        return_tensors="pt",
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True
    ).to(DEVICE)
    ids = encoded['input_ids']
    mask = encoded['attention_mask']

    def _run_task(task_id, class_names):
        # One forward pass for the given task; keep only classes whose
        # sigmoid score clears the threshold (native floats for Gradio).
        task_vec = torch.full((ids.shape[0],), task_id, dtype=torch.long, device=DEVICE)
        logits, _ = model(ids, mask, task_vec)
        scores = torch.sigmoid(logits)[0].cpu().numpy()
        return {
            name: float(score)
            for name, score in zip(class_names, scores)
            if score > threshold
        }

    with torch.no_grad():
        prop_results = _run_task(0, PROP_CLASSES)   # Propaganda
        emo_results = _run_task(1, EMO_CLASSES)     # Emotions

    return prop_results, emo_results, f"Processed: {len(clean_text)} chars"
184
+
185
+ # ==========================================
186
+ # 7. MODERN UI (GRADIO)
187
+ # ==========================================
188
# ==========================================
# 7. MODERN UI (GRADIO)
# ==========================================
custom_css = """
body { background-color: #f5f7fa; }
.container { max-width: 900px; margin: auto; padding-top: 20px; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AraProp Detector") as demo:
    # Fix: the header previously said "AraBERT-v02", but the backbone actually
    # loaded is MODEL_NAME = "UBC-NLP/MARBERTv2" — keep the UI label accurate.
    gr.Markdown(
        """
        # 🕵️‍♂️ Multi-Task Arabic Propaganda & Emotion Detector
        ### Based on MARBERTv2 | SOTA Reproduction
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=5,
                placeholder="أدخل النص هنا للتحليل...",
                label="Input Arabic Text",
                value="يا له من عار! هذا السياسي يدمر البلاد بخططه الشيطانية الفاشلة."
            )

            # Classes scoring below this sigmoid value are hidden from output.
            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.4,
                step=0.05,
                label="Confidence Threshold (Sensitivity)"
            )

            run_btn = gr.Button("Analyze Text 🚀", variant="primary")
            status_box = gr.Markdown("Ready...")

        with gr.Column(scale=1):
            gr.Markdown("### 📊 Detection Results")
            # We use 'Label' components which give nice progress bars
            out_prop = gr.Label(num_top_classes=8, label="Propaganda Techniques")
            out_emo = gr.Label(num_top_classes=8, label="Underlying Emotions")

    # Connect components
    run_btn.click(
        fn=predict_fn,
        inputs=[input_text, threshold_slider],
        outputs=[out_prop, out_emo, status_box]
    )

    gr.Markdown("---")
    gr.Markdown(f"Running on: {DEVICE} | Model: {MODEL_NAME}")
236
+
237
# ==========================================
# 8. LAUNCH
# ==========================================
if __name__ == "__main__":
    # share=True creates a public link for RunPod/Colab
    # (harmless on HF Spaces, where the platform serves the app directly;
    # show_error surfaces Python tracebacks in the browser UI).
    demo.launch(share=True, show_error=True)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ numpy
5
+ requests
6
+ tqdm