niel-2ru commited on
Commit
a6d62f5
·
1 Parent(s): a40e942
notebooks/combined-baseline (1).ipynb DELETED
The diff for this file is too large to render. See raw diff
 
notebooks/proposed-new-model.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, predict_hatespeech_from_file_mock, predict_text_mock
3
  import plotly.graph_objects as go
4
  import plotly.express as px
5
  import pandas as pd
@@ -154,12 +154,22 @@ classify_button = st.button("🔍 Analyze Text", type="primary", use_container_w
154
 
155
  if classify_button:
156
  if user_input and user_input.strip():
157
- with st.spinner('🔄 Analyzing text...'):
158
- # Run both models
159
- enhanced_start = time.time()
 
 
 
 
 
 
 
 
 
 
160
  enhanced_model_result = predict_hatespeech(
161
  text=user_input,
162
- rationale=optional_rationale if optional_rationale else user_input,
163
  model=enhanced_model,
164
  tokenizer_hatebert=enhanced_tokenizer_hatebert,
165
  tokenizer_rationale=enhanced_tokenizer_rationale,
@@ -169,10 +179,11 @@ if classify_button:
169
  )
170
  enhanced_end = time.time()
171
 
 
172
  base_start = time.time()
173
  base_model_result = predict_hatespeech(
174
  text=user_input,
175
- rationale=optional_rationale if optional_rationale else user_input,
176
  model=base_model,
177
  tokenizer_hatebert=base_tokenizer_hatebert,
178
  tokenizer_rationale=base_tokenizer_rationale,
@@ -358,10 +369,6 @@ if classify_button:
358
  }
359
  })
360
  if is_file_uploader_visible and uploaded_file is not None:
361
- st.markdown(f"**Filename:** {uploaded_file.name}")
362
- st.markdown(f"**Size:** {uploaded_file.size / 1024:.2f} KB")
363
- file_rows = len(file_content)
364
- st.metric("Rows in File", file_rows)
365
  st.markdown("**Preview:**")
366
  st.dataframe(file_content.head(3), use_container_width=True)
367
  with st.spinner('🔄 Analyzing file with both models... This may take a while for large files.'):
 
1
  import streamlit as st
2
+ from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral
3
  import plotly.graph_objects as go
4
  import plotly.express as px
5
  import pandas as pd
 
154
 
155
  if classify_button:
156
  if user_input and user_input.strip():
157
+ with st.spinner('🔄 Generating rationale from Mistral AI...'):
158
+ # --- Step 1: Get rationale from Mistral ---
159
+ try:
160
+ raw_rationale = get_rationale_from_mistral(user_input)
161
+ cleaned_rationale = preprocess_rationale_mistral(raw_rationale)
162
+ print(f"Raw rationale from Mistral: {raw_rationale}")
163
+ except Exception as e:
164
+ st.error(f"❌ Error generating/processing rationale: {str(e)}")
165
+ cleaned_rationale = user_input # fallback to raw input
166
+
167
+ with st.spinner('🔄 Analyzing text with models...'):
168
+ # Run enhanced model
169
+ enhanced_start = time.time()
170
  enhanced_model_result = predict_hatespeech(
171
  text=user_input,
172
+ rationale=cleaned_rationale, # use cleaned rationale
173
  model=enhanced_model,
174
  tokenizer_hatebert=enhanced_tokenizer_hatebert,
175
  tokenizer_rationale=enhanced_tokenizer_rationale,
 
179
  )
180
  enhanced_end = time.time()
181
 
182
+ # Run base model
183
  base_start = time.time()
184
  base_model_result = predict_hatespeech(
185
  text=user_input,
186
+ rationale=cleaned_rationale, # use cleaned rationale
187
  model=base_model,
188
  tokenizer_hatebert=base_tokenizer_hatebert,
189
  tokenizer_rationale=base_tokenizer_rationale,
 
369
  }
370
  })
371
  if is_file_uploader_visible and uploaded_file is not None:
 
 
 
 
372
  st.markdown("**Preview:**")
373
  st.dataframe(file_content.head(3), use_container_width=True)
374
  with st.spinner('🔄 Analyzing file with both models... This may take a while for large files.'):
src/hatespeech_model.py CHANGED
@@ -9,164 +9,339 @@ from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_sc
9
  from time import time
10
  import psutil
11
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # Model Architecture Classes
14
  class TemporalCNN(nn.Module):
15
- def __init__(self, hidden_size=768, num_filters=128, kernel_sizes=(2, 3, 4), dropout=0.1, dilation_base=2):
 
 
 
 
 
 
16
  super().__init__()
 
 
17
  self.kernel_sizes = kernel_sizes
18
- self.dilation_base = dilation_base
 
19
  self.convs = nn.ModuleList([
20
- nn.Conv1d(hidden_size, num_filters, k, dilation=dilation_base ** i, padding=0)
21
- for i, k in enumerate(kernel_sizes)
22
  ])
23
  self.dropout = nn.Dropout(dropout)
24
- self.out_dim = num_filters * len(kernel_sizes)
25
-
26
- def _causal_padding(self, x, kernel_size, dilation):
27
- padding = (kernel_size - 1) * dilation
28
- return F.pad(x, (padding, 0))
29
-
30
- def forward(self, x, attention_mask):
31
- mask = attention_mask.unsqueeze(-1)
32
- x = x * mask
33
- x = x.transpose(1, 2)
34
- feats = []
35
- for i, conv in enumerate(self.convs):
36
- kernel_size = self.kernel_sizes[i]
37
- dilation = self.dilation_base ** i
38
- x_padded = self._causal_padding(x, kernel_size, dilation)
39
- c = F.relu(conv(x_padded))
40
- p = F.max_pool1d(c, kernel_size=c.size(2)).squeeze(2)
41
- feats.append(p)
42
- out = torch.cat(feats, dim=1)
43
- return self.dropout(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  class MultiScaleAttentionCNN(nn.Module):
46
- def __init__(self, hidden_size=768, num_filters=128, kernel_sizes=(2, 3, 4), dropout=0.3):
47
- super().__init__()
48
- self.convs = nn.ModuleList([
49
- nn.Conv1d(hidden_size, num_filters, k) for k in kernel_sizes
50
- ])
51
- self.attention_fc = nn.Linear(num_filters, 1)
52
- self.dropout = nn.Dropout(dropout)
53
- self.out_dim = num_filters * len(kernel_sizes)
54
-
55
- def forward(self, x, mask):
56
- x = x.transpose(1, 2)
57
- feats = []
58
- for conv in self.convs:
59
- h = F.relu(conv(x))
60
- h = h.transpose(1, 2)
61
- attn = self.attention_fc(h).squeeze(-1)
62
- attn = attn.masked_fill(mask[:, :attn.size(1)] == 0, -1e9)
63
- alpha = F.softmax(attn, dim=1)
64
- pooled = torch.sum(h * alpha.unsqueeze(-1), dim=1)
65
- feats.append(pooled)
66
- out = torch.cat(feats, dim=1)
67
- return self.dropout(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  class ProjectionMLP(nn.Module):
70
- def __init__(self, input_size, hidden_size, num_labels):
71
  super().__init__()
72
  self.layers = nn.Sequential(
73
  nn.Linear(input_size, hidden_size),
74
  nn.ReLU(),
75
  nn.Linear(hidden_size, num_labels)
76
  )
77
-
78
  def forward(self, x):
79
  return self.layers(x)
80
 
81
- class GumbelTokenSelector(nn.Module):
82
- def __init__(self, hidden_size, tau=1.0):
83
- super().__init__()
84
- self.tau = tau
85
- self.proj = nn.Linear(hidden_size * 2, 1)
86
-
87
- def forward(self, token_embeddings, cls_embedding, training=True):
88
- B, L, H = token_embeddings.size()
89
- cls_exp = cls_embedding.unsqueeze(1).expand(-1, L, -1)
90
- x = torch.cat([token_embeddings, cls_exp], dim=-1)
91
- logits = self.proj(x).squeeze(-1)
92
-
93
- if training:
94
- probs = F.gumbel_softmax(
95
- torch.stack([logits, torch.zeros_like(logits)], dim=-1),
96
- tau=self.tau,
97
- hard=False
98
- )[..., 0]
99
- else:
100
- probs = torch.sigmoid(logits)
101
- return probs, logits
102
 
103
- class BaseShield(nn.Module):
104
- """
105
- Simple base model that concatenates HateBERT and rationale BERT CLS embeddings
106
- """
107
- def __init__(self, hatebert_model, additional_model, projection_mlp, device='cpu',
108
- freeze_additional_model=True):
 
 
 
 
 
109
  super().__init__()
110
  self.hatebert_model = hatebert_model
111
  self.additional_model = additional_model
112
  self.projection_mlp = projection_mlp
113
- self.device = device
114
-
 
115
  if freeze_additional_model:
116
  for param in self.additional_model.parameters():
117
  param.requires_grad = False
118
-
119
- def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
120
- hatebert_outputs = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask)
121
- hatebert_embeddings = hatebert_outputs.last_hidden_state[:, 0, :]
122
- hatebert_embeddings = torch.nn.LayerNorm(hatebert_embeddings.size()[1:]).to(self.device)(hatebert_embeddings.to(self.device)).to(self.device)
123
 
124
- additional_outputs = self.additional_model(input_ids=additional_input_ids, attention_mask=additional_attention_mask)
125
- additional_embeddings = additional_outputs.last_hidden_state[:, 0, :]
126
- additional_embeddings = torch.nn.LayerNorm(additional_embeddings.size()[1:]).to(self.device)(additional_embeddings.to(self.device)).to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- concatenated_embeddings = torch.cat((hatebert_embeddings, additional_embeddings), dim=1).to(self.device)
129
- projected_embeddings = self.projection_mlp(concatenated_embeddings).to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- # Return 4 values to match ConcatModel interface (rationale_probs, selector_logits, attentions are None)
132
- return projected_embeddings
133
 
134
- class ConcatModel(nn.Module):
135
- def __init__(self, hatebert_model, additional_model, temporal_cnn, msa_cnn, selector, projection_mlp, freeze_additional_model=True, freeze_hatebert=True):
 
 
 
 
136
  super().__init__()
137
  self.hatebert_model = hatebert_model
138
  self.additional_model = additional_model
139
- self.temporal_cnn = temporal_cnn
140
- self.msa_cnn = msa_cnn
141
- self.selector = selector
142
  self.projection_mlp = projection_mlp
 
143
 
144
  if freeze_additional_model:
145
- for p in self.additional_model.parameters():
146
- p.requires_grad = False
147
- if freeze_hatebert:
148
- for p in self.hatebert_model.parameters():
149
- p.requires_grad = False
150
 
151
  def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
152
- hate_outputs = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask)
153
- seq_emb = hate_outputs.last_hidden_state
154
- cls_emb = seq_emb[:, 0, :]
155
-
156
- token_probs, token_logits = self.selector(seq_emb, cls_emb, self.training)
157
- temporal_feat = self.temporal_cnn(seq_emb, attention_mask)
158
-
159
- weights = token_probs.unsqueeze(-1)
160
- H_r = (seq_emb * weights).sum(dim=1) / (weights.sum(dim=1) + 1e-6)
161
-
162
- with torch.no_grad():
163
- add_outputs = self.additional_model(input_ids=additional_input_ids, attention_mask=additional_attention_mask)
164
- add_seq = add_outputs.last_hidden_state
165
-
166
- msa_feat = self.msa_cnn(add_seq, additional_attention_mask)
167
- concat = torch.cat([cls_emb, temporal_feat, msa_feat, H_r], dim=1)
168
- logits = self.projection_mlp(concat)
169
- return logits, token_probs, token_logits, hate_outputs.attentions if hasattr(hate_outputs, "attentions") else None
170
 
171
  def load_model_from_hf(model_type="altered"):
172
  """
@@ -178,14 +353,13 @@ def load_model_from_hf(model_type="altered"):
178
 
179
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
180
  repo_id = "seffyehl/BetterShield"
181
- # repo_type = "e5912f6e8c34a10629cfd5a7971ac71ac76d0e9d"
182
 
183
  # Choose model and config files based on model_type
184
  if model_type.lower() == "altered":
185
  model_filename = "AlteredShield.pth"
186
  config_filename = "alter_config.json"
187
  elif model_type.lower() == "base":
188
- model_filename = "BaseShield.pth"
189
  config_filename = "base_config.json"
190
  else:
191
  raise ValueError(f"model_type must be 'altered' or 'base', got '{model_type}'")
@@ -193,22 +367,24 @@ def load_model_from_hf(model_type="altered"):
193
  # Download files
194
  model_path = hf_hub_download(
195
  repo_id=repo_id,
196
- # revision=repo_type,
197
  filename=model_filename
198
  )
199
 
200
  config_path = hf_hub_download(
201
  repo_id=repo_id,
202
  filename=config_filename,
203
- # revision=repo_type
204
  )
205
 
206
  # Load config
207
  with open(config_path, 'r') as f:
208
  config = json.load(f)
209
 
210
- # Load checkpoint
211
- checkpoint = torch.load(model_path, map_location='cpu')
 
 
 
 
212
 
213
  # Handle nested config structure (base model uses model_config, altered uses flat structure)
214
  if 'model_config' in config:
@@ -225,50 +401,144 @@ def load_model_from_hf(model_type="altered"):
225
  tokenizer_hatebert = AutoTokenizer.from_pretrained(model_config['hatebert_model'])
226
  tokenizer_rationale = AutoTokenizer.from_pretrained(model_config['rationale_model'])
227
 
228
- # Rebuild architecture based on model type
229
  H = hatebert_model.config.hidden_size
230
  max_length = training_config.get('max_length', 128)
231
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  if model_type.lower() == "base":
233
- # Base Shield: Simple concatenation model
234
- # Input: 768 (HateBERT CLS) + 768 (Rationale BERT CLS) = 1536
235
- proj_input_dim = H * 2 # 1536
236
- # The saved model uses 512, not what's in projection_config
237
- adapter_dim = 512 # hardcoded to match saved weights
238
- projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim,
239
- num_labels=2)
240
-
241
  model = BaseShield(
242
  hatebert_model=hatebert_model,
243
  additional_model=rationale_model,
244
  projection_mlp=projection_mlp,
245
- freeze_additional_model=True,
246
  device=device
247
  ).to(device)
248
  else:
249
- temporal_cnn = TemporalCNN(hidden_size=768, num_filters=128, kernel_sizes=(2, 3, 4)).to(device)
250
- msa_cnn = MultiScaleAttentionCNN(hidden_size=768, num_filters=128, kernel_sizes=(2, 3, 4)).to(device)
251
- selector = GumbelTokenSelector(hidden_size=768, tau=1.0).to(device)
252
- projection_mlp = ProjectionMLP(input_size=temporal_cnn.out_dim + msa_cnn.out_dim + 768 * 2, hidden_size=512, num_labels=2).to(device)
253
- model = ConcatModel(
254
- hatebert_model=hatebert_model,
255
- additional_model=rationale_model,
256
- temporal_cnn=temporal_cnn,
257
- msa_cnn=msa_cnn,
258
- selector=selector,
259
- projection_mlp=projection_mlp,
260
- freeze_additional_model=True,
261
- freeze_hatebert=True).to(device)
 
 
 
 
 
 
 
 
 
 
 
262
 
 
263
  if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
264
- model.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
266
  print(f"Dataset: {checkpoint.get('dataset', 'unknown')}, Seed: {checkpoint.get('seed', 'unknown')}")
267
- else:
268
- model.load_state_dict(checkpoint)
269
  model.eval()
 
 
 
 
 
 
270
  model = model.to(device)
271
 
 
 
 
 
272
  # Create a unified config dict with max_length at top level for compatibility
273
  unified_config = config.copy()
274
  if 'max_length' not in unified_config and 'training_config' in config:
@@ -276,26 +546,33 @@ def load_model_from_hf(model_type="altered"):
276
 
277
  return model, tokenizer_hatebert, tokenizer_rationale, unified_config, device
278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale,
280
  device='cpu', max_length=128, model_type="altered"):
281
- """
282
- Predict hate speech for a given text and rationale
283
-
284
- Args:
285
- text: Input text to classify
286
- rationale: Rationale/explanation text
287
- model: Loaded model
288
- tokenizer_hatebert: HateBERT tokenizer
289
- tokenizer_rationale: Rationale model tokenizer
290
- device: 'cpu' or 'cuda'
291
- max_length: Maximum sequence length
292
- model_type: Either "altered" or "base" to determine how to process inputs
293
-
294
- Returns:
295
- prediction: 0 or 1
296
- probability: Confidence score
297
- rationale_scores: Token-level rationale scores
298
- """
299
  model.eval()
300
 
301
  # Tokenize inputs
@@ -321,79 +598,99 @@ def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale
321
  add_input_ids = inputs_rationale['input_ids'].to(device)
322
  add_attention_mask = inputs_rationale['attention_mask'].to(device)
323
 
324
- # Inference
325
- if model_type.lower() == "base":
326
- with torch.no_grad():
327
  logits = model(
328
  input_ids,
329
  attention_mask,
330
  add_input_ids,
331
  add_attention_mask
332
  )
333
-
334
- # Get probabilities
335
- probs = torch.softmax(logits, dim=1)
336
- prediction = logits.argmax(dim=1).item()
337
- confidence = probs[0, prediction].item()
 
 
338
 
339
- return {
340
- 'prediction': prediction,
341
- 'confidence': confidence,
342
- 'probabilities': probs[0].cpu().numpy(),
343
- 'rationale_scores': None, # Base model does not produce token-level rationale scores
344
- 'tokens': tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])
345
- }
346
-
347
- with torch.no_grad():
348
- logits, rationale_probs, selector_logits, _ = model(
349
- input_ids,
350
- attention_mask,
351
- add_input_ids,
352
- add_attention_mask
353
- )
354
 
355
- # Get probabilities
356
- probs = torch.softmax(logits, dim=1)
357
  prediction = logits.argmax(dim=1).item()
358
  confidence = probs[0, prediction].item()
 
 
 
359
 
360
- return {
361
  'prediction': prediction,
362
  'confidence': confidence,
363
  'probabilities': probs[0].cpu().numpy(),
364
- 'rationale_scores': rationale_probs[0].cpu().numpy(),
365
  'tokens': tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])
366
  }
367
-
368
- def predict_hatespeech_from_file(text_list, rationale_list, true_label, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
369
- """
370
- Predict hate speech for text read from a file
371
 
372
- Args:
373
- text_list: List of input texts to classify
374
- rationale_list: List of rationale/explanation texts
375
- true_label: True label for evaluation
376
- model: Loaded model
377
- tokenizer_hatebert: HateBERT tokenizer
378
- tokenizer_rationale: Rationale tokenizer
379
- config: Model configuration
380
- device: Device to run on
381
- Returns:
382
- f1_score: F1 score for the predictions
383
- accuracy: Accuracy for the predictions
384
- precision: Precision for the predictions
385
- recall: Recall for the predictions
386
- confusion_matrix: Confusion matrix as a 2D list
387
- cpu_usage: CPU usage during prediction
388
- memory_usage: Memory usage during prediction
389
- runtime: Total runtime for predictions
390
- """
 
 
 
391
  predictions = []
 
392
  cpu_percent_list = []
393
  memory_percent_list = []
394
 
395
  process = psutil.Process(os.getpid())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
  start_time = time()
 
397
  for idx, (text, rationale) in enumerate(zip(text_list, rationale_list)):
398
  result = predict_text(
399
  text=text,
@@ -405,27 +702,45 @@ def predict_hatespeech_from_file(text_list, rationale_list, true_label, model, t
405
  max_length=config.get('max_length', 128),
406
  model_type=model_type
407
  )
 
408
  predictions.append(result['prediction'])
409
- # Log resource usage every 10th sample and at end to reduce overhead
 
 
410
  if idx % 10 == 0 or idx == len(text_list) - 1:
411
  cpu_percent_list.append(process.cpu_percent())
412
  memory_percent_list.append(process.memory_info().rss / 1024 / 1024)
413
 
 
 
 
 
414
  end_time = time()
415
  runtime = end_time - start_time
416
- # Calculate metrics
 
 
 
 
 
 
 
 
 
 
417
  f1 = f1_score(true_label, predictions, zero_division=0)
418
  accuracy = accuracy_score(true_label, predictions)
419
  precision = precision_score(true_label, predictions, zero_division=0)
420
  recall = recall_score(true_label, predictions, zero_division=0)
421
  cm = confusion_matrix(true_label, predictions).tolist()
422
-
423
  avg_cpu = sum(cpu_percent_list) / len(cpu_percent_list) if cpu_percent_list else 0
424
- avg_memory = sum(memory_percent_list) / len(memory_percent_list) if memory_percent_list else 0
425
  peak_memory = max(memory_percent_list) if memory_percent_list else 0
426
  peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0
427
 
428
  return {
 
429
  'f1_score': f1,
430
  'accuracy': accuracy,
431
  'precision': precision,
@@ -435,25 +750,14 @@ def predict_hatespeech_from_file(text_list, rationale_list, true_label, model, t
435
  'memory_usage': avg_memory,
436
  'peak_cpu_usage': peak_cpu,
437
  'peak_memory_usage': peak_memory,
438
- 'runtime': runtime
 
439
  }
440
 
441
 
442
  def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
443
  """
444
  Predict hate speech for given text
445
-
446
- Args:
447
- text: Input text to classify
448
- rationale: Optional rationale text
449
- model: Loaded model
450
- tokenizer_hatebert: HateBERT tokenizer
451
- tokenizer_rationale: Rationale tokenizer
452
- config: Model configuration
453
- device: Device to run on
454
-
455
- Returns:
456
- Dictionary with prediction results
457
  """
458
  # Get prediction
459
  result = predict_text(
@@ -468,88 +772,3 @@ def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rat
468
  )
469
 
470
  return result
471
-
472
- def predict_hatespeech_from_file_mock():
473
- """
474
- Mock function for predict_hatespeech_from_file that returns hardcoded data for testing
475
-
476
- Args:
477
- text_list: List of input texts to classify (not used in mock)
478
- rationale_list: List of rationale/explanation texts (not used in mock)
479
- true_label: True label for evaluation (not used in mock)
480
- model: Loaded model (not used in mock)
481
- tokenizer_hatebert: HateBERT tokenizer (not used in mock)
482
- tokenizer_rationale: Rationale tokenizer (not used in mock)
483
- config: Model configuration (not used in mock)
484
- device: Device to run on (not used in mock)
485
- Returns:
486
- Dictionary with hardcoded metrics for testing
487
- """
488
- # Hardcoded predictions matching the number of samples
489
- predictions = [0, 1, 1, 0, 1, 0, 0, 1, 1, 0]
490
- true_labels = [0, 1, 1, 0, 0, 0, 1, 1, 1, 0]
491
-
492
- # Hardcoded resource usage metrics
493
- cpu_percent_list = [25.3, 28.1, 26.5, 27.2, 26.8, 27.9, 25.5, 28.3, 26.2, 27.1]
494
- memory_percent_list = [145.3, 152.1, 148.5, 151.2, 149.8, 153.2, 146.5, 154.3, 150.2, 152.1]
495
-
496
- f1 = f1_score(true_labels, predictions, zero_division=0)
497
- accuracy = accuracy_score(true_labels, predictions)
498
- precision = precision_score(true_labels, predictions, zero_division=0)
499
- recall = recall_score(true_labels, predictions, zero_division=0)
500
- cm = confusion_matrix(true_labels, predictions).tolist()
501
-
502
- avg_cpu = sum(cpu_percent_list) / len(cpu_percent_list) if cpu_percent_list else 0
503
- avg_memory = sum(memory_percent_list) / len(memory_percent_list) if memory_percent_list else 0
504
- peak_memory = max(memory_percent_list) if memory_percent_list else 0
505
- peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0
506
-
507
- # Hardcoded runtime
508
- runtime = 12.543
509
-
510
- return {
511
- 'f1_score': f1,
512
- 'accuracy': accuracy,
513
- 'precision': precision,
514
- 'recall': recall,
515
- 'confusion_matrix': cm,
516
- 'cpu_usage': avg_cpu,
517
- 'memory_usage': avg_memory,
518
- 'peak_cpu_usage': peak_cpu,
519
- 'peak_memory_usage': peak_memory,
520
- 'runtime': runtime,
521
- 'predictions': predictions # Added for visibility
522
- }
523
-
524
- def predict_text_mock(text, max_length=128):
525
- import numpy as np
526
-
527
- # Simple whitespace tokenization for mock output
528
- raw_tokens = (text or "").split()
529
- mock_tokens = raw_tokens[:max_length]
530
-
531
- # Build a simple attention mask (1 for tokens)
532
- attention_mask = [1] * len(mock_tokens)
533
-
534
- # Generate random rationale scores matching token count
535
- mock_rationale_scores = np.random.rand(len(mock_tokens)).astype(np.float32)
536
-
537
- # Randomized probabilities [class_0, class_1]
538
- # Class 0 = not hate speech, Class 1 = hate speech
539
- mock_probabilities = np.random.rand(2).astype(np.float32)
540
- mock_probabilities = mock_probabilities / mock_probabilities.sum()
541
-
542
- # Prediction (argmax of probabilities)
543
- mock_prediction = int(np.argmax(mock_probabilities)) # Class 1: hate speech
544
-
545
- # Confidence score
546
- mock_confidence = float(np.max(mock_probabilities))
547
-
548
- return {
549
- 'prediction': mock_prediction,
550
- 'confidence': mock_confidence,
551
- 'probabilities': mock_probabilities,
552
- 'rationale_scores': mock_rationale_scores,
553
- 'tokens': mock_tokens,
554
- 'attention_mask': attention_mask
555
- }
 
9
  from time import time
10
  import psutil
11
  import os
12
+ import numpy as np
13
+ import requests
14
+ import json
15
+
16
# Cloudflare Workers AI endpoint used for Mistral rationale generation.
API_BASE_URL = "https://api.cloudflare.com/client/v4/accounts/8fcfcf97aa4c166eee626b79a67f902d/ai/run/"
# SECURITY: an API token was previously hardcoded on this line and committed
# to the repository — that credential is leaked and must be revoked. The token
# is now read from the environment so it never lives in source control.
HEADERS = {"Authorization": f"Bearer {os.environ.get('CF_API_TOKEN', '')}"}
MODEL_NAME = "@cf/mistralai/mistral-small-3.1-24b-instruct"
19
+
20
def create_prompt(text):
    """Build the content-moderation prompt sent to the Mistral model.

    The prompt asks for rationale words/phrases, derogatory language, cuss
    words and a hateful/non-hateful classification, as JSON only.
    """
    prompt = f"""
    You are a content moderation assistant. Identify the list of [rationales] words or phrases from the text that make it hateful,
    list of [derogatory language], and [list of cuss words] and [hate_classification] such as "hateful" or "non-hateful".
    If there are none, respond exactly with [non-hateful] only.
    Output should be in JSON format only. Text: {text}.
    """
    return prompt
27
+
28
def run_mistral_model(model, inputs, timeout=30):
    """POST a chat-completion request to the Cloudflare Workers AI endpoint.

    Args:
        model: Model identifier appended to ``API_BASE_URL`` (e.g. ``MODEL_NAME``).
        inputs: List of chat messages, ``{"role": ..., "content": ...}`` dicts.
        timeout: Request timeout in seconds. New defaulted parameter — the
            original call had no timeout and could hang the Streamlit app
            indefinitely on a stalled connection.

    Returns:
        The parsed JSON response body as a dict.

    Raises:
        requests.exceptions.HTTPError: On non-2xx responses (via raise_for_status).
        requests.exceptions.Timeout: If the request exceeds ``timeout`` seconds.
    """
    payload = {"messages": inputs}
    response = requests.post(
        f"{API_BASE_URL}{model}", headers=HEADERS, json=payload, timeout=timeout
    )
    response.raise_for_status()
    return response.json()
33
+
34
def flatten_json_string(json_string):
    """Collapse a JSON document onto one line with no separator whitespace.

    Args:
        json_string: Candidate JSON text (typically model output).

    Returns:
        Compact single-line JSON when the input parses, otherwise the input
        unchanged — callers pass raw model output through without validating.
    """
    try:
        obj = json.loads(json_string)
    except (json.JSONDecodeError, TypeError):
        # Bug fix: was a bare ``except:`` which also swallowed
        # KeyboardInterrupt/SystemExit. Only parse failures (JSONDecodeError)
        # and non-string inputs (TypeError) mean "pass through unchanged".
        return json_string
    return json.dumps(obj, separators=(",", ":"))
40
+
41
def get_rationale_from_mistral(text, retries=10):
    """
    Sends text to Mistral AI and returns a cleaned JSON rationale string.

    Retries when the model returns empty output or a refusal ("I cannot...").
    Falls back to the literal string "non-hateful" when every attempt fails.

    Args:
        text: Raw user text to extract hate-speech rationales from.
        retries: Maximum number of attempts against the API.

    Returns:
        A one-line JSON rationale string, or "non-hateful" on failure.

    Raises:
        requests.exceptions.HTTPError: Re-raised when the API reports
            rate limiting (HTTP 429) or resource exhaustion — retrying
            would only make that worse.
    """
    for attempt in range(retries):
        try:
            inputs = [{"role": "user", "content": create_prompt(text)}]
            output = run_mistral_model(MODEL_NAME, inputs)

            result = output.get("result", {})
            response_text = result.get("response", "").strip()

            if not response_text or response_text.startswith("I cannot"):
                print(f"⚠️ Model returned 'I cannot...' — retrying ({attempt+1}/{retries})")
                continue  # retry

            # Flatten JSON response and clean
            cleaned_rationale = flatten_json_string(response_text).replace("\n", " ").strip()
            return cleaned_rationale

        except requests.exceptions.HTTPError as e:
            print(f"⚠️ HTTP Error on attempt {attempt+1}: {e}")
            # Bug fix: ``e.response`` can be None (connection-level errors),
            # which previously raised AttributeError instead of retrying.
            # Guard before dereferencing status_code.
            if "RESOURCE_EXHAUSTED" in str(e) or (
                e.response is not None and e.response.status_code == 429
            ):
                raise

    # Fallback if all retries fail
    return "non-hateful"
70
+
71
def preprocess_rationale_mistral(raw_rationale):
    """
    Cleans and standardizes rationale text from Mistral AI.

    - Removes ```json fences
    - Fixes doubled double-quotes
    - Extracts the embedded JSON object
    - Returns 'non-hateful' if all rationale lists are empty
    - Otherwise returns a clean, one-line, lowercased JSON of rationales

    Falls back to the lowercased raw text on any parsing failure.
    """
    try:
        x = str(raw_rationale).strip()

        # Remove ```json fences
        if x.startswith("```"):
            x = x.replace("```json", "").replace("```", "").strip()

        # Fix doubled double-quotes
        x = x.replace('""', '"')

        # Locate the JSON object. Bug fix: the original computed
        # ``end = x.rfind("}") + 1`` and then compared it to -1 — a check
        # that could never fire, because rfind's -1 sentinel became 0.
        # Test rfind's result directly before building the slice.
        start = x.find("{")
        end = x.rfind("}")
        if start == -1 or end == -1:
            return x.lower()  # no JSON object present — fallback

        j = json.loads(x[start:end + 1])

        keys = ["rationales", "derogatory_language", "cuss_words"]

        # If all lists exist and are empty → non-hateful
        if all(k in j and isinstance(j[k], list) and len(j[k]) == 0 for k in keys):
            return "non-hateful"

        # Otherwise, return clean JSON of the relevant keys only
        cleaned = {k: j.get(k, []) for k in keys}
        return json.dumps(cleaned).lower()

    except Exception:
        # Best-effort: any parse error degrades to the raw text, lowercased.
        return str(raw_rationale).lower()
110
 
111
  # Model Architecture Classes
112
class TemporalCNN(nn.Module):
    """
    Temporal CNN applied across the sequence (time) dimension.
    Input: sequence_embeddings (B, L, H), attention_mask (B, L)
    Output: pooled vector (B, output_dim) where output_dim = num_filters * len(kernel_sizes) * 2
    (we concatenate max-pooled and mean-pooled features for each kernel size)
    """
    def __init__(self, input_dim=768, num_filters=256, kernel_sizes=(2, 3, 4), dropout=0.3):
        super().__init__()
        self.input_dim = input_dim
        self.num_filters = num_filters
        self.kernel_sizes = kernel_sizes

        # Convs expect (B, C_in, L) where C_in = input_dim.
        # NOTE(review): padding=k//2 makes L_out differ from L for even kernel
        # sizes (e.g. k=2 gives L_out = L+1); forward() compensates by
        # interpolating the mask to L_out.
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=input_dim, out_channels=num_filters, kernel_size=k, padding=k // 2)
            for k in kernel_sizes
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, sequence_embeddings, attention_mask=None):
        """
        sequence_embeddings: (B, L, H)
        attention_mask: (B, L) with 1 for valid tokens, 0 for padding
        returns: (B, num_filters * len(kernel_sizes) * 2)  # max + mean pooled per conv
        """
        # transpose to (B, H, L) — Conv1d convolves over the last dim
        x = sequence_embeddings.transpose(1, 2).contiguous()  # (B, H, L)

        pooled_outputs = []
        for conv in self.convs:
            conv_out = conv(x)  # (B, num_filters, L_out)
            conv_out = F.relu(conv_out)
            L_out = conv_out.size(2)

            if attention_mask is not None:
                # resize mask to match L_out (conv padding can change length)
                mask = attention_mask.float()
                if mask.size(1) != L_out:
                    mask = F.interpolate(mask.unsqueeze(1), size=L_out, mode='nearest').squeeze(1)
                mask = mask.unsqueeze(1).to(conv_out.device)  # (B,1,L_out)

                # max pool with masking: padded positions are set to a large
                # negative value (finfo.min/2 avoids overflow) so max ignores them
                neg_inf = torch.finfo(conv_out.dtype).min / 2
                max_masked = torch.where(mask.bool(), conv_out, neg_inf * torch.ones_like(conv_out))
                max_pooled = torch.max(max_masked, dim=2)[0]  # (B, num_filters)

                # mean pool with masking: sum over valid positions / count of
                # valid positions (clamped to avoid division by zero)
                sum_masked = (conv_out * mask).sum(dim=2)  # (B, num_filters)
                denom = mask.sum(dim=2).clamp_min(1e-6)  # (B,1), broadcasts over filters
                mean_pooled = sum_masked / denom  # (B, num_filters)
            else:
                # no mask: plain max/mean over the full length
                max_pooled = torch.max(conv_out, dim=2)[0]
                mean_pooled = conv_out.mean(dim=2)

            pooled_outputs.append(max_pooled)
            pooled_outputs.append(mean_pooled)

        out = torch.cat(pooled_outputs, dim=1)  # (B, num_filters * len(kernel_sizes) * 2)
        out = self.dropout(out)
        return out
173
+
174
 
175
class MultiScaleAttentionCNN(nn.Module):
    """
    Multi-scale CNN with learned attention pooling per kernel size.

    Input: hidden_states (B, L, H), mask (B, L).
    Output: (B, num_filters * len(kernel_sizes)) — one attention-weighted
    sum of conv features per scale, concatenated across scales.
    """
    def __init__(self, hidden_size=768, num_filters=64, kernel_sizes=(2, 3, 4, 5, 6, 7), dropout=0.3):
        super().__init__()

        self.hidden_size = hidden_size
        self.kernel_sizes = kernel_sizes

        self.convs = nn.ModuleList()
        self.pads = nn.ModuleList()

        # Asymmetric constant padding keeps the conv output length equal to L
        # for both odd and even kernel sizes (unlike Conv1d's symmetric padding).
        for k in self.kernel_sizes:
            pad_left = (k - 1) // 2
            pad_right = k - 1 - pad_left
            self.pads.append(nn.ConstantPad1d((pad_left, pad_right), 0.0))
            self.convs.append(nn.Conv1d(hidden_size, num_filters, kernel_size=k, padding=0))

        # One attention scorer (filters -> scalar) per scale.
        self.attn = nn.ModuleList([nn.Linear(num_filters, 1) for _ in self.kernel_sizes])
        self.output_size = num_filters * len(self.kernel_sizes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states, mask):
        """
        hidden_states: (B, L, H)
        mask: (B, L) with 1 for valid tokens, 0 for padding
        returns: (B, num_filters * len(kernel_sizes))
        """
        x = hidden_states.transpose(1, 2)  # (B, H, L)
        attn_mask = mask.unsqueeze(1).float()

        conv_outs = []

        for pad, conv, att in zip(self.pads, self.convs, self.attn):
            padded = pad(x)  # (B, H, L)
            c = conv(padded)  # (B, F, L)
            c = F.relu(c)
            c = c * attn_mask  # zero out features at padded positions

            c_t = c.transpose(1, 2)  # (B, L, F)
            w = att(c_t)  # (B, L, 1) per-position attention logits
            # exclude padded positions from the softmax with a large negative fill
            w = w.masked_fill(mask.unsqueeze(-1) == 0, -1e9)
            w = F.softmax(w, dim=1)

            pooled = (c_t * w).sum(dim=1)  # (B, F) attention-weighted sum
            conv_outs.append(pooled)

        out = torch.cat(conv_outs, dim=1)  # (B, F * K)
        return self.dropout(out)
221
+
222
 
223
class ProjectionMLP(nn.Module):
    """Two-layer MLP head projecting a concatenated feature vector to label logits."""

    def __init__(self, input_size, hidden_size=256, num_labels=2):
        super().__init__()
        # Kept as a single nn.Sequential named `layers` so saved checkpoints
        # (keys like `projection_mlp.layers.0.weight`) remain loadable.
        stack = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_labels),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Map features (B, input_size) to logits (B, num_labels)."""
        logits = self.layers(x)
        return logits
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
class ConcatModelWithRationale(nn.Module):
    """Hate-speech classifier combining a HateBERT encoder, a frozen rationale
    encoder, a Gumbel-Sigmoid token selector, and two CNN pooling branches.

    The final feature vector is [CLS, TemporalCNN(main), MSA-CNN(rationale),
    pooled-rationale-summary] fed into `projection_mlp`.
    """

    def __init__(self,
                 hatebert_model,
                 additional_model,
                 projection_mlp,
                 hidden_size=768,
                 gumbel_temp=0.5,
                 freeze_additional_model=True,
                 cnn_num_filters=64,
                 cnn_kernel_sizes=(2, 3, 4, 5, 6, 7),
                 cnn_dropout=0.0):
        super().__init__()
        self.hatebert_model = hatebert_model
        self.additional_model = additional_model
        self.projection_mlp = projection_mlp
        self.gumbel_temp = gumbel_temp
        self.hidden_size = hidden_size

        # Freeze the rationale encoder so only HateBERT + heads train.
        if freeze_additional_model:
            for param in self.additional_model.parameters():
                param.requires_grad = False

        # selector head (per-token logits)
        self.selector = nn.Linear(hidden_size, 1)

        # Temporal CNN over HateBERT embeddings (main text)
        self.temporal_cnn = TemporalCNN(input_dim=hidden_size,
                                        num_filters=cnn_num_filters,
                                        kernel_sizes=cnn_kernel_sizes,
                                        dropout=cnn_dropout)
        # TemporalCNN concatenates max- and mean-pooled features, hence * 2.
        self.temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2

        # MultiScaleAttentionCNN over rationale embeddings (frozen BERT)
        self.msa_cnn = MultiScaleAttentionCNN(hidden_size=hidden_size,
                                              num_filters=cnn_num_filters,
                                              kernel_sizes=cnn_kernel_sizes,
                                              dropout=cnn_dropout)
        self.msa_out_dim = self.msa_cnn.output_size

    def gumbel_sigmoid_sample(self, logits):
        """Relaxed Bernoulli sample per token via Gumbel-Sigmoid.

        NOTE(review): noise is injected unconditionally, so the selector is
        stochastic even under model.eval(); confirm whether inference should
        instead use sigmoid(logits) deterministically.
        """
        noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-9) + 1e-9)
        y = logits + noise
        return torch.sigmoid(y / self.gumbel_temp)

    def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask, return_attentions=False):
        """Returns (logits, rationale_probs (B, L), selector_logits (B, L), attns-or-None)."""
        # Main text through HateBERT
        hatebert_out = self.hatebert_model(input_ids=input_ids,
                                           attention_mask=attention_mask,
                                           output_attentions=return_attentions,
                                           return_dict=True)
        hatebert_emb = hatebert_out.last_hidden_state  # (B, L, H)
        cls_emb = hatebert_emb[:, 0, :]  # (B, H)

        # Rationale text through frozen BERT (no grad: encoder is frozen)
        with torch.no_grad():
            add_out = self.additional_model(input_ids=additional_input_ids,
                                            attention_mask=additional_attention_mask,
                                            return_dict=True)
            rationale_emb = add_out.last_hidden_state  # (B, L, H)

        # selector logits & Gumbel-Sigmoid sampling on HateBERT
        selector_logits = self.selector(hatebert_emb).squeeze(-1)  # (B, L)
        rationale_probs = self.gumbel_sigmoid_sample(selector_logits)  # (B, L)
        # Zero out probabilities at padding positions.
        rationale_probs = rationale_probs * attention_mask.float().to(rationale_probs.device)

        # pooled rationale summary: probability-weighted mean of token embeddings
        masked_hidden = hatebert_emb * rationale_probs.unsqueeze(-1)
        denom = rationale_probs.sum(1).unsqueeze(-1).clamp_min(1e-6)  # avoid div-by-zero
        pooled_rationale = masked_hidden.sum(1) / denom  # (B, H)

        # CNN branches
        temporal_features = self.temporal_cnn(hatebert_emb, attention_mask)  # (B, temporal_out_dim)
        rationale_features = self.msa_cnn(rationale_emb, additional_attention_mask)  # (B, msa_out_dim)

        # concat CLS + CNN features + pooled rationale
        concat_emb = torch.cat((cls_emb, temporal_features, rationale_features, pooled_rationale), dim=1)

        logits = self.projection_mlp(concat_emb)

        attns = hatebert_out.attentions if (return_attentions and hasattr(hatebert_out, "attentions")) else None
        return logits, rationale_probs, selector_logits, attns
317
 
 
 
318
 
319
+ class BaseShield(nn.Module):
320
+ """
321
+ Simple base model that concatenates HateBERT and rationale BERT CLS embeddings
322
+ and projects to label logits via a small MLP.
323
+ """
324
+ def __init__(self, hatebert_model, additional_model, projection_mlp, device='cpu', freeze_additional_model=True):
325
  super().__init__()
326
  self.hatebert_model = hatebert_model
327
  self.additional_model = additional_model
 
 
 
328
  self.projection_mlp = projection_mlp
329
+ self.device = device
330
 
331
  if freeze_additional_model:
332
+ for param in self.additional_model.parameters():
333
+ param.requires_grad = False
 
 
 
334
 
335
  def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
336
+ hatebert_outputs = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
337
+ hatebert_embeddings = hatebert_outputs.last_hidden_state[:, 0, :]
338
+
339
+ additional_outputs = self.additional_model(input_ids=additional_input_ids, attention_mask=additional_attention_mask, return_dict=True)
340
+ additional_embeddings = additional_outputs.last_hidden_state[:, 0, :]
341
+
342
+ concatenated_embeddings = torch.cat((hatebert_embeddings, additional_embeddings), dim=1)
343
+ logits = self.projection_mlp(concatenated_embeddings)
344
+ return logits
 
 
 
 
 
 
 
 
 
345
 
346
  def load_model_from_hf(model_type="altered"):
347
  """
 
353
 
354
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
355
  repo_id = "seffyehl/BetterShield"
 
356
 
357
  # Choose model and config files based on model_type
358
  if model_type.lower() == "altered":
359
  model_filename = "AlteredShield.pth"
360
  config_filename = "alter_config.json"
361
  elif model_type.lower() == "base":
362
+ model_filename = "BaselineShield.pth"
363
  config_filename = "base_config.json"
364
  else:
365
  raise ValueError(f"model_type must be 'altered' or 'base', got '{model_type}'")
 
367
  # Download files
368
  model_path = hf_hub_download(
369
  repo_id=repo_id,
 
370
  filename=model_filename
371
  )
372
 
373
  config_path = hf_hub_download(
374
  repo_id=repo_id,
375
  filename=config_filename,
 
376
  )
377
 
378
  # Load config
379
  with open(config_path, 'r') as f:
380
  config = json.load(f)
381
 
382
+ # Load checkpoint with proper handling for numpy dtypes (PyTorch 2.6+ compatibility)
383
+ try:
384
+ checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
385
+ except TypeError:
386
+ # Fallback for older PyTorch versions
387
+ checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
388
 
389
  # Handle nested config structure (base model uses model_config, altered uses flat structure)
390
  if 'model_config' in config:
 
401
  tokenizer_hatebert = AutoTokenizer.from_pretrained(model_config['hatebert_model'])
402
  tokenizer_rationale = AutoTokenizer.from_pretrained(model_config['rationale_model'])
403
 
404
+ # Rebuild architecture based on model type using training_config values when available
405
  H = hatebert_model.config.hidden_size
406
  max_length = training_config.get('max_length', 128)
407
+
408
+ # common params from training config (use None to allow inference from checkpoint)
409
+ adapter_dim = training_config.get('adapter_dim', training_config.get('adapter_size', None))
410
+ cnn_num_filters = training_config.get('cnn_num_filters', None)
411
+ cnn_kernel_sizes = training_config.get('cnn_kernel_sizes', None)
412
+ cnn_dropout = training_config.get('cnn_dropout', 0.3)
413
+ freeze_rationale = training_config.get('freeze_additional_model', True)
414
+ num_labels = training_config.get('num_labels', 2)
415
+
416
+ # Infer architecture params from checkpoint state_dict when possible to match saved weights
417
+ state_dict = None
418
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
419
+ state_dict = checkpoint['model_state_dict']
420
+ elif isinstance(checkpoint, dict):
421
+ # sometimes checkpoint is a raw state_dict saved as dict
422
+ state_dict = checkpoint
423
+
424
+ if state_dict is not None:
425
+ # infer temporal convs count and filters if present
426
+ temporal_keys = [k for k in state_dict.keys() if k.startswith('temporal_cnn.convs.') and k.endswith('.weight')]
427
+ if temporal_keys:
428
+ try:
429
+ sample = state_dict[temporal_keys[0]]
430
+ inferred_num_filters = sample.shape[0]
431
+ inferred_kernel_count = len(temporal_keys)
432
+ if cnn_num_filters is None:
433
+ cnn_num_filters = int(inferred_num_filters)
434
+ if cnn_kernel_sizes is None:
435
+ cnn_kernel_sizes = training_config.get('cnn_kernel_sizes', (2,3,4,5,6,7))
436
+ except Exception:
437
+ pass
438
+
439
+ # infer projection dims/adapt size
440
+ proj_w_key = None
441
+ for key in ('projection_mlp.layers.0.weight', 'projection_mlp.0.weight', 'projection_mlp.layers.0.weight_orig'):
442
+ if key in state_dict:
443
+ proj_w_key = key
444
+ break
445
+ if proj_w_key is not None:
446
+ try:
447
+ proj_w = state_dict[proj_w_key]
448
+ inferred_adapter_dim = proj_w.shape[0]
449
+ if adapter_dim is None:
450
+ adapter_dim = int(inferred_adapter_dim)
451
+ except Exception:
452
+ pass
453
+
454
+ # sensible defaults when neither config nor checkpoint provided values
455
+ if cnn_num_filters is None:
456
+ cnn_num_filters = 64 # Changed from 128 to match typical training configs
457
+ if cnn_kernel_sizes is None:
458
+ cnn_kernel_sizes = (2, 3, 4, 5, 6, 7)
459
+ if adapter_dim is None:
460
+ adapter_dim = 128
461
+
462
  if model_type.lower() == "base":
463
+ proj_input_dim = H * 2
464
+ projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim, num_labels=num_labels)
 
 
 
 
 
 
465
  model = BaseShield(
466
  hatebert_model=hatebert_model,
467
  additional_model=rationale_model,
468
  projection_mlp=projection_mlp,
469
+ freeze_additional_model=freeze_rationale,
470
  device=device
471
  ).to(device)
472
  else:
473
+ # For altered model, let ConcatModelWithRationale initialize its own CNN modules
474
+ # The CNN modules are created inside __init__, so we just need to create the model
475
+ # and then load the state dict
476
+
477
+ # First, create a dummy projection_mlp - we'll replace it after calculating dims
478
+ # Actually, we need to calculate dims first to create the correct projection_mlp
479
+
480
+ # Calculate dimensions based on inferred parameters
481
+ temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2
482
+ msa_out_dim = cnn_num_filters * len(cnn_kernel_sizes)
483
+ proj_input_dim = H + temporal_out_dim + msa_out_dim + H
484
+
485
+ projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim, num_labels=num_labels)
486
+
487
+ model = ConcatModelWithRationale(
488
+ hatebert_model=hatebert_model,
489
+ additional_model=rationale_model,
490
+ projection_mlp=projection_mlp,
491
+ hidden_size=H,
492
+ freeze_additional_model=freeze_rationale,
493
+ cnn_num_filters=cnn_num_filters,
494
+ cnn_kernel_sizes=cnn_kernel_sizes,
495
+ cnn_dropout=cnn_dropout
496
+ ).to(device)
497
 
498
+ # Load state dict with strict checking and error reporting
499
  if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
500
+ state_dict_to_load = checkpoint['model_state_dict']
501
+ else:
502
+ state_dict_to_load = checkpoint
503
+
504
+ # Check for missing and unexpected keys
505
+ model_keys = set(model.state_dict().keys())
506
+ checkpoint_keys = set(state_dict_to_load.keys())
507
+
508
+ missing_keys = model_keys - checkpoint_keys
509
+ unexpected_keys = checkpoint_keys - model_keys
510
+
511
+ if missing_keys:
512
+ print(f"WARNING: Missing keys in checkpoint: {missing_keys}")
513
+ if unexpected_keys:
514
+ print(f"WARNING: Unexpected keys in checkpoint: {unexpected_keys}")
515
+
516
+ # Load with strict=False to handle any minor mismatches, but log warnings
517
+ incompatible_keys = model.load_state_dict(state_dict_to_load, strict=True)
518
+
519
+ if incompatible_keys.missing_keys:
520
+ print(f"Missing keys after load: {incompatible_keys.missing_keys}")
521
+ if incompatible_keys.unexpected_keys:
522
+ print(f"Unexpected keys after load: {incompatible_keys.unexpected_keys}")
523
+
524
+ if isinstance(checkpoint, dict) and 'epoch' in checkpoint:
525
  print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
526
  print(f"Dataset: {checkpoint.get('dataset', 'unknown')}, Seed: {checkpoint.get('seed', 'unknown')}")
527
+
528
+ # CRITICAL: Set to eval mode and ensure no gradient computation
529
  model.eval()
530
+
531
+ # Disable dropout explicitly by setting training mode to False for all modules
532
+ for module in model.modules():
533
+ if isinstance(module, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)):
534
+ module.p = 0 # Set dropout probability to 0
535
+
536
  model = model.to(device)
537
 
538
+ # Verify model is in eval mode
539
+ print(f"Model training mode: {model.training}")
540
+ print(f"Dropout layers found: {sum(1 for _ in model.modules() if isinstance(_, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)))}")
541
+
542
  # Create a unified config dict with max_length at top level for compatibility
543
  unified_config = config.copy()
544
  if 'max_length' not in unified_config and 'training_config' in config:
 
546
 
547
  return model, tokenizer_hatebert, tokenizer_rationale, unified_config, device
548
 
549
+
550
def combined_loss(logits, labels, rationale_probs, selector_logits, rationale_mask=None, attns=None, attn_weight=0.0, rationale_weight=1.0):
    """Classification loss plus optional rationale-supervision terms.

    Args:
        logits: (B, num_labels) class logits.
        labels: (B,) gold class indices.
        rationale_probs: unused here; kept in the signature for interface
            compatibility with callers that pass the sampled token probabilities.
        selector_logits: (B, L) per-token selector logits.
        rationale_mask: optional (B, L) gold rationale mask in [0, 1].
        attns: optional sequence of attention tensors; the last entry is
            assumed (B, H, L, L) — only used when attn_weight > 0.
        attn_weight: weight of the attention-alignment MSE term (off by default).
        rationale_weight: weight of the selector BCE term.

    Returns:
        (total_loss tensor, cls_loss float, selector_loss float, attn_loss float)
    """
    cls_loss = F.cross_entropy(logits, labels)

    # supervise selector logits with BCE-with-logits against rationale mask (if available)
    if rationale_mask is not None:
        selector_loss = F.binary_cross_entropy_with_logits(selector_logits, rationale_mask.to(selector_logits.device))
    else:
        selector_loss = torch.tensor(0.0, device=cls_loss.device)

    # optional attention alignment loss (disabled by default).
    # Fix: the alignment target is rationale_mask, so require it explicitly —
    # previously a None mask raised inside the try and was silently swallowed
    # by the except clause.
    attn_loss = torch.tensor(0.0, device=cls_loss.device)
    if attns is not None and rationale_mask is not None and attn_weight > 0.0:
        try:
            last_attn = attns[-1]  # (B, H, L, L)
            attn_mass = last_attn.mean(1).mean(1)  # avg over heads, then queries -> (B, L)
            attn_loss = F.mse_loss(attn_mass, rationale_mask.to(attn_mass.device))
        except Exception:
            # deliberate best-effort: never let an attention-shape mismatch break training
            attn_loss = torch.tensor(0.0, device=cls_loss.device)

    total_loss = cls_loss + rationale_weight * selector_loss + attn_weight * attn_loss
    return total_loss, cls_loss.item(), selector_loss.item(), attn_loss.item()
571
+
572
+
573
  def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale,
574
  device='cpu', max_length=128, model_type="altered"):
575
+ # Ensure model is in eval mode (defensive programming)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
  model.eval()
577
 
578
  # Tokenize inputs
 
598
  add_input_ids = inputs_rationale['input_ids'].to(device)
599
  add_attention_mask = inputs_rationale['attention_mask'].to(device)
600
 
601
+ # Inference with no gradient computation
602
+ with torch.no_grad():
603
+ if model_type.lower() == "base":
604
  logits = model(
605
  input_ids,
606
  attention_mask,
607
  add_input_ids,
608
  add_attention_mask
609
  )
610
+ else:
611
+ logits, rationale_probs, selector_logits, _ = model(
612
+ input_ids,
613
+ attention_mask,
614
+ add_input_ids,
615
+ add_attention_mask
616
+ )
617
 
618
+ temperature = 1 # Adjust this if needed (e.g., 2.0 for less confidence)
619
+ scaled_logits = logits / temperature
620
+
621
+ # Get probabilities with numerical stability
622
+ probs = F.softmax(scaled_logits, dim=1)
623
+
624
+ if torch.isnan(probs).any() or torch.isinf(probs).any():
625
+ print(f"WARNING: NaN or Inf in probabilities. Logits: {logits}")
626
+ # Fallback to uniform distribution
627
+ probs = torch.ones_like(logits) / logits.size(1)
 
 
 
 
 
628
 
 
 
629
  prediction = logits.argmax(dim=1).item()
630
  confidence = probs[0, prediction].item()
631
+
632
+ # Debug: Print logits and probs for first few predictions
633
+ print(f"Debug - Logits: {logits[0].cpu().numpy()}, Probs: {probs[0].cpu().numpy()}")
634
 
635
+ result = {
636
  'prediction': prediction,
637
  'confidence': confidence,
638
  'probabilities': probs[0].cpu().numpy(),
 
639
  'tokens': tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])
640
  }
 
 
 
 
641
 
642
+ if model_type.lower() != "base":
643
+ result['rationale_scores'] = rationale_probs[0].cpu().numpy() if 'rationale_probs' in locals() else None
644
+ else:
645
+ result['rationale_scores'] = None
646
+
647
+ return result
648
+
649
+
650
+ def predict_hatespeech_from_file(
651
+ text_list,
652
+ rationale_list,
653
+ true_label,
654
+ model,
655
+ tokenizer_hatebert,
656
+ tokenizer_rationale,
657
+ config,
658
+ device,
659
+ model_type="altered"
660
+ ):
661
+
662
+ print(f"\nStarting inference for model: {type(model).__name__}")
663
+
664
  predictions = []
665
+ all_probs = []
666
  cpu_percent_list = []
667
  memory_percent_list = []
668
 
669
  process = psutil.Process(os.getpid())
670
+
671
+ # 🔥 GPU synchronization BEFORE timing
672
+ if torch.cuda.is_available():
673
+ torch.cuda.synchronize()
674
+
675
+ # 🔥 Optional warmup (prevents first-batch timing bias)
676
+ with torch.no_grad():
677
+ _ = predict_text(
678
+ text=text_list[0],
679
+ rationale=rationale_list[0],
680
+ model=model,
681
+ tokenizer_hatebert=tokenizer_hatebert,
682
+ tokenizer_rationale=tokenizer_rationale,
683
+ device=device,
684
+ max_length=config.get('max_length', 128),
685
+ model_type=model_type
686
+ )
687
+
688
+ if torch.cuda.is_available():
689
+ torch.cuda.synchronize()
690
+
691
+ # ⏱ Start timer AFTER warmup
692
  start_time = time()
693
+
694
  for idx, (text, rationale) in enumerate(zip(text_list, rationale_list)):
695
  result = predict_text(
696
  text=text,
 
702
  max_length=config.get('max_length', 128),
703
  model_type=model_type
704
  )
705
+
706
  predictions.append(result['prediction'])
707
+ all_probs.append(result['probabilities'])
708
+
709
+ # Reduce monitoring overhead
710
  if idx % 10 == 0 or idx == len(text_list) - 1:
711
  cpu_percent_list.append(process.cpu_percent())
712
  memory_percent_list.append(process.memory_info().rss / 1024 / 1024)
713
 
714
+ # 🔥 GPU synchronization BEFORE stopping timer
715
+ if torch.cuda.is_available():
716
+ torch.cuda.synchronize()
717
+
718
  end_time = time()
719
  runtime = end_time - start_time
720
+
721
+ print(f"Inference completed for {type(model).__name__}")
722
+ print(f"Total runtime: {runtime:.4f} seconds")
723
+
724
+ # ---------------- Metrics ----------------
725
+ all_probs = np.array(all_probs)
726
+
727
+ print(f"Probability Mean: {all_probs.mean(axis=0)}")
728
+ print(f"Probability Std: {all_probs.std(axis=0)}")
729
+ print(f"Prediction distribution: {np.bincount(predictions, minlength=2)}")
730
+
731
  f1 = f1_score(true_label, predictions, zero_division=0)
732
  accuracy = accuracy_score(true_label, predictions)
733
  precision = precision_score(true_label, predictions, zero_division=0)
734
  recall = recall_score(true_label, predictions, zero_division=0)
735
  cm = confusion_matrix(true_label, predictions).tolist()
736
+
737
  avg_cpu = sum(cpu_percent_list) / len(cpu_percent_list) if cpu_percent_list else 0
738
+ avg_memory = sum(memory_percent_list) / len(memory_percent_list) if memory_percent_list else 0
739
  peak_memory = max(memory_percent_list) if memory_percent_list else 0
740
  peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0
741
 
742
  return {
743
+ 'model_name': type(model).__name__, # 👈 makes logs clearer
744
  'f1_score': f1,
745
  'accuracy': accuracy,
746
  'precision': precision,
 
750
  'memory_usage': avg_memory,
751
  'peak_cpu_usage': peak_cpu,
752
  'peak_memory_usage': peak_memory,
753
+ 'runtime': runtime,
754
+ 'all_probabilities': all_probs.tolist()
755
  }
756
 
757
 
758
  def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
759
  """
760
  Predict hate speech for given text
 
 
 
 
 
 
 
 
 
 
 
 
761
  """
762
  # Get prediction
763
  result = predict_text(
 
772
  )
773
 
774
  return result