mahmoudmohammad commited on
Commit
bea32d7
·
verified ·
1 Parent(s): 24e4add

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +240 -241
app.py CHANGED
@@ -1,242 +1,241 @@
1
- import os
2
- import re
3
- import requests
4
- import torch
5
- import torch.nn as nn
6
- import numpy as np
7
- import gradio as gr
8
- from transformers import AutoTokenizer, AutoModel
9
- from tqdm import tqdm # Just for download progress bar
10
-
11
- # ==========================================
12
- # 1. CONFIGURATION
13
- # ==========================================
14
- MODEL_URL = "https://huggingface.co/datasets/mahmoudmohammad/Propaganda_Detection/resolve/main/paper_arch_asl_uw_marbertv2_raw-data.bin"
15
- MODEL_FILENAME = "paper_arch_asl_uw_marbertv2_raw-data.bin"
16
- MODEL_NAME = "UBC-NLP/MARBERTv2"
17
- MAX_LEN = 256
18
- TASK_EMBED_DIM = 128
19
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
-
21
- # --- Classes (Hardcoded as per your dataset) ---
22
- PROP_CLASSES = [
23
- 'Appeal to authority', 'Appeal to fear/prejudice', 'Appeal to time',
24
- 'Bandwagon', 'Black-and-white Fallacy/Dictatorship',
25
- 'Causal Oversimplification', 'Doubt', 'Exaggeration/Minimisation',
26
- 'Flag-waving', 'Glittering generalities (Virtue)', 'Loaded Language',
27
- "Misrepresentation of Someone's Position (Straw Man)",
28
- 'Name calling/Labeling', 'Obfuscation, Intentional vagueness, Confusion',
29
- 'Presenting Irrelevant Data (Red Herring)', 'Repetition', 'Slogans',
30
- 'Smears', 'Thought-terminating cliché', 'Whataboutism'
31
- ]
32
-
33
- EMO_CLASSES = [
34
- 'anger', 'annoyance', 'anticipation', 'anxiety', 'confusion', 'denial',
35
- 'disgust', 'empathy', 'fear', 'gratitude', 'humor', 'joy', 'love',
36
- 'neutral', 'optimism', 'pessimism', 'sadness', 'surprise',
37
- 'sympathy', 'trust'
38
- ]
39
-
40
- # ==========================================
41
- # 2. HELPER: DOWNLOADER
42
- # ==========================================
43
- def download_model_if_missing():
44
- if not os.path.exists(MODEL_FILENAME):
45
- print(f"📥 Model file not found. Downloading from Hugging Face...")
46
- print(f" URL: {MODEL_URL}")
47
-
48
- try:
49
- response = requests.get(MODEL_URL, stream=True)
50
- response.raise_for_status() # Check for error
51
-
52
- total_size = int(response.headers.get('content-length', 0))
53
- block_size = 1024 # 1 Kilobyte
54
-
55
- with open(MODEL_FILENAME, "wb") as file, tqdm(
56
- desc=MODEL_FILENAME,
57
- total=total_size,
58
- unit='iB',
59
- unit_scale=True,
60
- unit_divisor=1024,
61
- ) as bar:
62
- for data in response.iter_content(block_size):
63
- size = file.write(data)
64
- bar.update(size)
65
- print("\n✅ Download complete.")
66
-
67
- except Exception as e:
68
- print(f"\n❌ Failed to download model: {e}")
69
- raise
70
- else:
71
- print(f"✅ Model file '{MODEL_FILENAME}' already exists.")
72
-
73
- # ==========================================
74
- # 3. PREPROCESSING
75
- # ==========================================
76
- def preprocess_text(text):
77
- if not isinstance(text, str): return ""
78
- text = re.sub(r'http\S+|www\S+', '[URL]', text)
79
- text = re.sub(r'@\w+', '[USER]', text)
80
- text = re.sub(r'[a-zA-Z]', '', text)
81
- text = re.sub("[إأآ]", "ا", text)
82
- text = re.sub("ة", "ه", text)
83
- text = re.sub("ى", "ي", text)
84
- text = re.sub(r'[\u0617-\u061A\u064B-\u0652]', '', text)
85
- return text.strip()
86
-
87
- # ==========================================
88
- # 4. MODEL ARCHITECTURE
89
- # ==========================================
90
- class AlHenakiMTLModel(nn.Module):
91
- def __init__(self, n_propaganda, n_emotion):
92
- super(AlHenakiMTLModel, self).__init__()
93
- self.arabert = AutoModel.from_pretrained(MODEL_NAME)
94
- self.hidden_size = 768
95
- self.task_embedding = nn.Embedding(num_embeddings=2, embedding_dim=TASK_EMBED_DIM)
96
- self.head_input_dim = self.hidden_size + TASK_EMBED_DIM
97
- self.prop_head = nn.Linear(self.head_input_dim, n_propaganda)
98
- self.emo_head = nn.Linear(self.head_input_dim, n_emotion)
99
- # Weights used during training loss calc, kept here for structure compatibility
100
- self.log_sigma_prop = nn.Parameter(torch.zeros(1))
101
- self.log_sigma_emo = nn.Parameter(torch.zeros(1))
102
-
103
- def forward(self, input_ids, attention_mask, task_ids):
104
- outputs = self.arabert(input_ids, attention_mask=attention_mask)
105
- pooled_output = outputs.pooler_output
106
- t_embed = self.task_embedding(task_ids)
107
- z = torch.cat((pooled_output, t_embed), dim=1)
108
-
109
- current_task = task_ids[0].item()
110
- if current_task == 0:
111
- return self.prop_head(z), self.log_sigma_prop
112
- elif current_task == 1:
113
- return self.emo_head(z), self.log_sigma_emo
114
- else:
115
- raise ValueError("Unknown Task ID")
116
-
117
- # ==========================================
118
- # 5. INITIALIZE GLOBALS
119
- # ==========================================
120
- # 1. Download
121
- download_model_if_missing()
122
-
123
- # 2. Load Components
124
- print("⏳ Loading Tokenizer & Model...")
125
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
126
- model = AlHenakiMTLModel(len(PROP_CLASSES), len(EMO_CLASSES))
127
-
128
- # 3. Load Weights
129
- try:
130
- state_dict = torch.load(MODEL_FILENAME, map_location=DEVICE)
131
- model.load_state_dict(state_dict)
132
- model.to(DEVICE)
133
- model.eval()
134
- print("✅ Model loaded successfully on", DEVICE)
135
- except Exception as e:
136
- print(f"❌ Critical Error Loading Model: {e}")
137
-
138
- # ==========================================
139
- # 6. INFERENCE LOGIC
140
- # ==========================================
141
- def predict_fn(text, threshold):
142
- clean_text = preprocess_text(text)
143
-
144
- # Empty check
145
- if not clean_text.strip():
146
- return {}, {}, "Please enter Arabic text."
147
-
148
- # Tokenize
149
- inputs = tokenizer(
150
- clean_text,
151
- return_tensors="pt",
152
- max_length=MAX_LEN,
153
- padding="max_length",
154
- truncation=True
155
- ).to(DEVICE)
156
-
157
- input_ids = inputs['input_ids']
158
- attn_mask = inputs['attention_mask']
159
-
160
- with torch.no_grad():
161
- # Propaganda
162
- task_ids_p = torch.tensor([0] * input_ids.shape[0], dtype=torch.long).to(DEVICE)
163
- logits_p, _ = model(input_ids, attn_mask, task_ids_p)
164
- probs_p = torch.sigmoid(logits_p).cpu().numpy()[0]
165
-
166
- # Emotions
167
- task_ids_e = torch.tensor([1] * input_ids.shape[0], dtype=torch.long).to(DEVICE)
168
- logits_e, _ = model(input_ids, attn_mask, task_ids_e)
169
- probs_e = torch.sigmoid(logits_e).cpu().numpy()[0]
170
-
171
- # Format for Gradio Label Output ({Label: Score})
172
- # Filter by threshold AND convert numpy float to native float
173
- prop_results = {
174
- PROP_CLASSES[i]: float(probs_p[i])
175
- for i in range(len(probs_p)) if probs_p[i] > threshold
176
- }
177
-
178
- emo_results = {
179
- EMO_CLASSES[i]: float(probs_e[i])
180
- for i in range(len(probs_e)) if probs_e[i] > threshold
181
- }
182
-
183
- return prop_results, emo_results, f"Processed: {len(clean_text)} chars"
184
-
185
- # ==========================================
186
- # 7. MODERN UI (GRADIO)
187
- # ==========================================
188
- custom_css = """
189
- body { background-color: #f5f7fa; }
190
- .container { max-width: 900px; margin: auto; padding-top: 20px; }
191
- """
192
-
193
- with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AraProp Detector") as demo:
194
- gr.Markdown(
195
- """
196
- # 🕵️‍♂️ Multi-Task Arabic Propaganda & Emotion Detector
197
- ### Based on AraBERT-v02 | SOTA Reproduction
198
- """
199
- )
200
-
201
- with gr.Row():
202
- with gr.Column(scale=1):
203
- input_text = gr.Textbox(
204
- lines=5,
205
- placeholder="أدخل النص هنا للتحليل...",
206
- label="Input Arabic Text",
207
- value="يا له من عار! هذا السياسي يدمر البلاد بخططه الشيطانية الفاشلة."
208
- )
209
-
210
- threshold_slider = gr.Slider(
211
- minimum=0.0,
212
- maximum=1.0,
213
- value=0.4,
214
- step=0.05,
215
- label="Confidence Threshold (Sensitivity)"
216
- )
217
-
218
- run_btn = gr.Button("Analyze Text 🚀", variant="primary")
219
- status_box = gr.Markdown("Ready...")
220
-
221
- with gr.Column(scale=1):
222
- gr.Markdown("### 📊 Detection Results")
223
- # We use 'Label' components which give nice progress bars
224
- out_prop = gr.Label(num_top_classes=8, label="Propaganda Techniques")
225
- out_emo = gr.Label(num_top_classes=8, label="Underlying Emotions")
226
-
227
- # Connect components
228
- run_btn.click(
229
- fn=predict_fn,
230
- inputs=[input_text, threshold_slider],
231
- outputs=[out_prop, out_emo, status_box]
232
- )
233
-
234
- gr.Markdown("---")
235
- gr.Markdown(f"Running on: {DEVICE} | Model: {MODEL_NAME}")
236
-
237
- # ==========================================
238
- # 8. LAUNCH
239
- # ==========================================
240
- if __name__ == "__main__":
241
- # share=True creates a public link for RunPod/Colab
242
  demo.launch(share=True, show_error=True)
 
1
+ import os
2
+ import re
3
+ import requests
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ import gradio as gr
8
+ from transformers import AutoTokenizer, AutoModel
9
+ from tqdm import tqdm # Just for download progress bar
10
+
11
# ==========================================
# 1. CONFIGURATION
# ==========================================
# Remote location of the fine-tuned multi-task checkpoint (hosted in a HF dataset repo).
MODEL_URL = "https://huggingface.co/datasets/mahmoudmohammad/Propaganda_Detection/resolve/main/paper_arch_asl_uw_marbertv2_raw-data.bin"
# Local path the checkpoint is downloaded to and loaded from.
MODEL_FILENAME = "paper_arch_asl_uw_marbertv2_raw-data.bin"
# Backbone encoder used for both the tokenizer and the feature extractor.
MODEL_NAME = "UBC-NLP/MARBERTv2"
MAX_LEN = 256  # tokenizer truncation/padding length
TASK_EMBED_DIM = 128  # width of the learned task-id embedding concatenated to the pooled output
# Prefer GPU when available; all tensors/model weights are moved here.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
# --- Classes (Hardcoded as per your dataset) ---
# 20 propaganda techniques predicted by the multi-label propaganda head.
# NOTE(review): head output index i is mapped directly to PROP_CLASSES[i],
# so this ordering must match the label order used at training time —
# confirm against the training label encoder.
PROP_CLASSES = [
    'Appeal to authority', 'Appeal to fear/prejudice', 'Appeal to time',
    'Bandwagon', 'Black-and-white Fallacy/Dictatorship',
    'Causal Oversimplification', 'Doubt', 'Exaggeration/Minimisation',
    'Flag-waving', 'Glittering generalities (Virtue)', 'Loaded Language',
    "Misrepresentation of Someone's Position (Straw Man)",
    'Name calling/Labeling', 'Obfuscation, Intentional vagueness, Confusion',
    'Presenting Irrelevant Data (Red Herring)', 'Repetition', 'Slogans',
    'Smears', 'Thought-terminating cliché', 'Whataboutism'
]

# 20 emotion labels predicted by the multi-label emotion head
# (same index-ordering caveat as for PROP_CLASSES above).
EMO_CLASSES = [
    'anger', 'annoyance', 'anticipation', 'anxiety', 'confusion', 'denial',
    'disgust', 'empathy', 'fear', 'gratitude', 'humor', 'joy', 'love',
    'neutral', 'optimism', 'pessimism', 'sadness', 'surprise',
    'sympathy', 'trust'
]
39
+
40
# ==========================================
# 2. HELPER: DOWNLOADER
# ==========================================
def download_model_if_missing():
    """Download the model checkpoint to MODEL_FILENAME if it is not present.

    Streams the file in chunks with a tqdm progress bar. The download is
    written to a temporary '<name>.part' file and atomically renamed into
    place only on success, so an interrupted download never leaves a partial
    file at MODEL_FILENAME that a later run would mistake for a complete
    checkpoint.

    Raises:
        requests.RequestException: if the HTTP request fails or times out.
        OSError: if the file cannot be written.
    """
    if os.path.exists(MODEL_FILENAME):
        print(f"✅ Model file '{MODEL_FILENAME}' already exists.")
        return

    print("📥 Model file not found. Downloading from Hugging Face...")
    print(f"   URL: {MODEL_URL}")

    tmp_path = MODEL_FILENAME + ".part"
    try:
        # stream=True avoids loading the whole checkpoint into memory;
        # the (connect, read) timeout keeps the app from hanging forever
        # on a dead connection.
        response = requests.get(MODEL_URL, stream=True, timeout=(10, 60))
        response.raise_for_status()  # Check for HTTP error

        total_size = int(response.headers.get('content-length', 0))
        block_size = 1024 * 1024  # 1 MiB chunks: far fewer iterations than 1 KiB

        with open(tmp_path, "wb") as file, tqdm(
            desc=MODEL_FILENAME,
            total=total_size,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for data in response.iter_content(block_size):
                bar.update(file.write(data))

        # Atomic rename: MODEL_FILENAME only ever exists fully written.
        os.replace(tmp_path, MODEL_FILENAME)
        print("\n✅ Download complete.")

    except Exception as e:
        # Remove the partial file so the next run retries the download
        # instead of treating leftovers as a valid checkpoint.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        print(f"\n❌ Failed to download model: {e}")
        raise
72
+
73
# ==========================================
# 3. PREPROCESSING
# ==========================================
# Arabic orthographic normalization done in one pass: unify alef variants,
# taa marbuta, and alef maqsura.
_CHAR_NORM = str.maketrans({'إ': 'ا', 'أ': 'ا', 'آ': 'ا', 'ة': 'ه', 'ى': 'ي'})
# Arabic diacritic / tashkeel marks (U+0617–U+061A, U+064B–U+0652) deleted outright.
_DIACRITICS = str.maketrans(
    {c: None for c in map(chr, list(range(0x0617, 0x061B)) + list(range(0x064B, 0x0653)))}
)


def preprocess_text(text):
    """Normalize raw (Arabic) input text before tokenization.

    Replaces URLs and @-mentions with placeholder tokens, strips Latin
    letters and Arabic diacritics, and unifies common letter variants.
    Returns "" for non-string input.

    NOTE(review): Latin letters are stripped *after* the '[URL]'/'[USER]'
    placeholders are inserted, so those placeholders collapse to '[]'.
    That matches the original implementation (presumably mirroring the
    training-time preprocessing) and is deliberately left unchanged.
    """
    if not isinstance(text, str):
        return ""
    text = re.sub(r'http\S+|www\S+', '[URL]', text)
    text = re.sub(r'@\w+', '[USER]', text)
    text = re.sub(r'[a-zA-Z]', '', text)
    text = text.translate(_CHAR_NORM)
    text = text.translate(_DIACRITICS)
    return text.strip()
86
+
87
# ==========================================
# 4. MODEL ARCHITECTURE
# ==========================================
class AlHenakiMTLModel(nn.Module):
    """Multi-task model: shared BERT encoder + task embedding + two heads.

    A learned embedding of the task id (0 = propaganda, 1 = emotion) is
    concatenated to the encoder's pooled [CLS] output before the
    task-specific linear classification head.
    """

    def __init__(self, n_propaganda, n_emotion):
        super(AlHenakiMTLModel, self).__init__()
        self.arabert = AutoModel.from_pretrained(MODEL_NAME)
        # Read the width from the backbone config instead of hardcoding 768,
        # so swapping MODEL_NAME for a wider/narrower encoder still works.
        # (MARBERTv2 reports 768, so existing checkpoints load unchanged.)
        self.hidden_size = self.arabert.config.hidden_size
        self.task_embedding = nn.Embedding(num_embeddings=2, embedding_dim=TASK_EMBED_DIM)
        self.head_input_dim = self.hidden_size + TASK_EMBED_DIM
        self.prop_head = nn.Linear(self.head_input_dim, n_propaganda)
        self.emo_head = nn.Linear(self.head_input_dim, n_emotion)
        # Homoscedastic-uncertainty weights used during training loss calc;
        # kept so the checkpoint's state_dict keys still match at load time.
        self.log_sigma_prop = nn.Parameter(torch.zeros(1))
        self.log_sigma_emo = nn.Parameter(torch.zeros(1))

    def forward(self, input_ids, attention_mask, task_ids):
        """Return (logits, log_sigma) for the task selected by task_ids.

        All rows in the batch must share the same task id — only
        task_ids[0] is consulted to pick the head.

        Raises:
            ValueError: if task_ids[0] is neither 0 nor 1.
        """
        outputs = self.arabert(input_ids, attention_mask=attention_mask)
        # NOTE(review): pooler_output is None for encoders built without a
        # pooling layer; BERT-style models such as MARBERTv2 do provide it.
        pooled_output = outputs.pooler_output
        t_embed = self.task_embedding(task_ids)
        z = torch.cat((pooled_output, t_embed), dim=1)

        current_task = task_ids[0].item()
        if current_task == 0:
            return self.prop_head(z), self.log_sigma_prop
        elif current_task == 1:
            return self.emo_head(z), self.log_sigma_emo
        else:
            raise ValueError("Unknown Task ID")
116
+
117
# ==========================================
# 5. INITIALIZE GLOBALS
# ==========================================
# 1. Ensure the fine-tuned checkpoint is on disk.
download_model_if_missing()

# 2. Load the tokenizer and build the (randomly initialized) architecture.
print("⏳ Loading Tokenizer & Model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AlHenakiMTLModel(len(PROP_CLASSES), len(EMO_CLASSES))

# 3. Load the fine-tuned weights onto DEVICE and switch to inference mode.
try:
    state_dict = torch.load(MODEL_FILENAME, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    print("✅ Model loaded successfully on", DEVICE)
except Exception as e:
    print(f"❌ Critical Error Loading Model: {e}")
    # Fail fast: the original code swallowed this error and kept running,
    # which would serve predictions from randomly initialized heads while
    # looking like a healthy app.
    raise
137
+
138
# ==========================================
# 6. INFERENCE LOGIC
# ==========================================
def _score_task(input_ids, attn_mask, task_id):
    """Run one forward pass for the given task id (0=prop, 1=emo).

    Returns the per-class sigmoid probabilities of the first (only) batch
    row as a 1-D numpy array.
    """
    task_ids = torch.full((input_ids.shape[0],), task_id, dtype=torch.long, device=DEVICE)
    logits, _ = model(input_ids, attn_mask, task_ids)
    return torch.sigmoid(logits).cpu().numpy()[0]


def _above_threshold(probs, classes, threshold):
    """Map class names to native-float scores, keeping only scores > threshold.

    Native floats are required because gr.Label rejects numpy scalars.
    """
    return {
        classes[i]: float(p)
        for i, p in enumerate(probs) if p > threshold
    }


def predict_fn(text, threshold):
    """Gradio click handler.

    Args:
        text: raw user input (Arabic expected).
        threshold: minimum probability for a label to be shown.

    Returns:
        (propaganda {label: score}, emotion {label: score}, status string).
    """
    clean_text = preprocess_text(text)

    # Empty check — nothing left after normalization.
    if not clean_text.strip():
        return {}, {}, "Please enter Arabic text."

    inputs = tokenizer(
        clean_text,
        return_tensors="pt",
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True
    ).to(DEVICE)

    input_ids = inputs['input_ids']
    attn_mask = inputs['attention_mask']

    # Two separate passes are required: the task embedding makes the
    # encoder input task-dependent, so the heads cannot share one pass.
    with torch.no_grad():
        probs_p = _score_task(input_ids, attn_mask, 0)  # propaganda head
        probs_e = _score_task(input_ids, attn_mask, 1)  # emotion head

    prop_results = _above_threshold(probs_p, PROP_CLASSES, threshold)
    emo_results = _above_threshold(probs_e, EMO_CLASSES, threshold)

    return prop_results, emo_results, f"Processed: {len(clean_text)} chars"
184
+
185
# ==========================================
# 7. MODERN UI (GRADIO)
# ==========================================
custom_css = """
.container { max-width: 900px; margin: auto; padding-top: 20px; }
"""

with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="AraProp Detector", js="() => document.body.classList.add('dark')") as demo:
    # Fixed: the header previously said "AraBERT-v02" while the app actually
    # loads MODEL_NAME (UBC-NLP/MARBERTv2) — interpolate it so the header and
    # the footer can never drift apart again.
    gr.Markdown(
        f"""
        # 🕵️‍♂️ Multi-Task Arabic Propaganda & Emotion Detector
        ### Based on {MODEL_NAME} | SOTA Reproduction
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                lines=5,
                placeholder="أدخل النص هنا للتحليل...",
                label="Input Arabic Text",
                value="يا له من عار! هذا السياسي يدمر البلاد بخططه الشيطانية الفاشلة."
            )

            threshold_slider = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.4,
                step=0.05,
                label="Confidence Threshold (Sensitivity)"
            )

            run_btn = gr.Button("Analyze Text 🚀", variant="primary")
            status_box = gr.Markdown("Ready...")

        with gr.Column(scale=1):
            gr.Markdown("### 📊 Detection Results")
            # We use 'Label' components which give nice progress bars
            out_prop = gr.Label(num_top_classes=8, label="Propaganda Techniques")
            out_emo = gr.Label(num_top_classes=8, label="Underlying Emotions")

    # Connect components
    run_btn.click(
        fn=predict_fn,
        inputs=[input_text, threshold_slider],
        outputs=[out_prop, out_emo, status_box]
    )

    gr.Markdown("---")
    gr.Markdown(f"Running on: {DEVICE} | Model: {MODEL_NAME}")
235
+
236
# ==========================================
# 8. LAUNCH
# ==========================================
if __name__ == "__main__":
    # share=True creates a public link for RunPod/Colab
    # show_error=True surfaces Python tracebacks in the browser UI,
    # which is useful for a demo deployment.
    demo.launch(share=True, show_error=True)