niel-2ru commited on
Commit
dcb7c19
·
1 Parent(s): 0cadc48

fix - use v4 models

Browse files
Files changed (2) hide show
  1. src/app.py +35 -17
  2. src/hatespeech_model.py +357 -316
src/app.py CHANGED
@@ -21,8 +21,14 @@ st.set_page_config(
21
  # Cached model loading function
22
  @st.cache_resource
23
  def load_cached_model(model_type="altered"):
24
- """Load and cache the model"""
25
- return load_model_from_hf(model_type=model_type)
 
 
 
 
 
 
26
 
27
  # Custom CSS
28
  st.markdown("""
@@ -68,8 +74,21 @@ st.markdown('<div class="sub-header">Comparing Base vs Enhanced models with expl
68
  # Load both models with spinner
69
  with st.spinner('🔄 Loading models... This may take a moment on first run.'):
70
  try:
71
- base_model, base_tokenizer_hatebert, base_tokenizer_rationale, base_config, base_device = load_cached_model("base")
72
- enhanced_model, enhanced_tokenizer_hatebert, enhanced_tokenizer_rationale, enhanced_config, enhanced_device = load_cached_model("altered")
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  st.success('✅ Base Shield and Enhanced Shield models loaded successfully!')
74
  except Exception as e:
75
  st.error(f"❌ Error loading models: {str(e)}")
@@ -320,30 +339,31 @@ if classify_button:
320
 
321
  # Filter out special tokens and create visualization
322
  token_importance = []
323
- html_output = "<div style='font-size: 16px; line-height: 2.2; padding: 15px; background-color: #f8f9fa; border-radius: 10px;'>"
324
 
325
  for token, score in zip(enhanced_tokens, enhanced_rationale_scores):
326
  if token not in ['[CLS]', '[SEP]', '[PAD]']:
327
  # Clean token
328
  display_token = token.replace('##', '')
329
  token_importance.append({'Token': display_token, 'Importance': score})
330
-
331
  # Color intensity based on score
332
  alpha = min(score * 1.5, 1.0) # Scale up visibility
333
- if enhanced_prediction == 1: # Hate speech
334
- color = f"rgba(239, 83, 80, {alpha:.2f})"
335
- else: # Not hate speech
336
- color = f"rgba(102, 187, 106, {alpha:.2f})"
 
 
 
 
 
337
 
338
  html_output += f"<span style='background-color: {color}; padding: 3px 6px; margin: 1px; border-radius: 4px; display: inline-block;'>{display_token}</span> "
339
 
340
  html_output += "</div>"
341
  st.markdown(html_output, unsafe_allow_html=True)
342
-
343
- if enhanced_prediction == 1:
344
- st.caption("🔴 Darker red = Higher importance for hate speech detection")
345
- else:
346
- st.caption("🟢 Darker green = Higher importance for non-hate speech classification")
347
 
348
  # Top important tokens
349
  st.markdown("**📋 Top Important Tokens**")
@@ -459,7 +479,6 @@ if classify_button:
459
  with enhanced_file_col:
460
  st.subheader("🟢 Enhanced Shield Results")
461
 
462
- # Performance Metrics
463
  st.markdown("**📈 Classification Metrics**")
464
  enh_fm1, enh_fm2 = st.columns(2)
465
  with enh_fm1:
@@ -485,7 +504,6 @@ if classify_button:
485
  fig_enhanced_cm.update_layout(height=300)
486
  st.plotly_chart(fig_enhanced_cm, use_container_width=True)
487
 
488
- # Resource Usage
489
  st.markdown("**⚙️ Resource Usage**")
490
  enh_cpu_col, enh_mem_col = st.columns(2)
491
  with enh_cpu_col:
 
21
# Cached model loading function
@st.cache_resource
def load_cached_model(model_type="altered"):
    """Fetch a model bundle from the Hub once per session and cache it.

    Returns a dict with the model, both tokenizers, the config, and the
    device, so Streamlit's cache stores a single hashable-by-id object.
    """
    components = load_model_from_hf(model_type=model_type)
    keys = ("model", "tokenizer_hatebert", "tokenizer_rationale", "config", "device")
    return dict(zip(keys, components))
32
 
33
  # Custom CSS
34
  st.markdown("""
 
74
  # Load both models with spinner
75
  with st.spinner('🔄 Loading models... This may take a moment on first run.'):
76
  try:
77
+ base_data = load_cached_model("base")
78
+ enhanced_data = load_cached_model("altered")
79
+
80
+ base_model = base_data["model"]
81
+ base_tokenizer_hatebert = base_data["tokenizer_hatebert"]
82
+ base_tokenizer_rationale = base_data["tokenizer_rationale"]
83
+ base_config = base_data["config"]
84
+ base_device = base_data["device"]
85
+
86
+ enhanced_model = enhanced_data["model"]
87
+ enhanced_tokenizer_hatebert = enhanced_data["tokenizer_hatebert"]
88
+ enhanced_tokenizer_rationale = enhanced_data["tokenizer_rationale"]
89
+ enhanced_config = enhanced_data["config"]
90
+ enhanced_device = enhanced_data["device"]
91
+
92
  st.success('✅ Base Shield and Enhanced Shield models loaded successfully!')
93
  except Exception as e:
94
  st.error(f"❌ Error loading models: {str(e)}")
 
339
 
340
  # Filter out special tokens and create visualization
341
  token_importance = []
342
+ html_output = "<div style='font-size: 16px; line-height: 2.2; padding: 15px; background-color: #0E1117; border-radius: 10px;'>"
343
 
344
  for token, score in zip(enhanced_tokens, enhanced_rationale_scores):
345
  if token not in ['[CLS]', '[SEP]', '[PAD]']:
346
  # Clean token
347
  display_token = token.replace('##', '')
348
  token_importance.append({'Token': display_token, 'Importance': score})
349
+ thresholds = [0.25, 0.5, 0.75]
350
  # Color intensity based on score
351
  alpha = min(score * 1.5, 1.0) # Scale up visibility
352
+ color = f"rgba(239, 83, 80, {alpha:.2f})"
353
+ # if score > 0.5:
354
+ # color = f"rgba(239, 83, 80, {alpha:.2f})"
355
+ # else:
356
+ # color = f"rgba(242, 155, 5, {alpha:.2f})"
357
+ # if enhanced_prediction == 1: # Hate speech
358
+ # color = f"rgba(239, 83, 80, {alpha:.2f})"
359
+ # else: # Not hate speech
360
+ # color = f"rgba(102, 187, 106, {alpha:.2f})"
361
 
362
  html_output += f"<span style='background-color: {color}; padding: 3px 6px; margin: 1px; border-radius: 4px; display: inline-block;'>{display_token}</span> "
363
 
364
  html_output += "</div>"
365
  st.markdown(html_output, unsafe_allow_html=True)
366
+ st.caption("🔴 Darker red = More influence on hate speech detection.")
 
 
 
 
367
 
368
  # Top important tokens
369
  st.markdown("**📋 Top Important Tokens**")
 
479
  with enhanced_file_col:
480
  st.subheader("🟢 Enhanced Shield Results")
481
 
 
482
  st.markdown("**📈 Classification Metrics**")
483
  enh_fm1, enh_fm2 = st.columns(2)
484
  with enh_fm1:
 
504
  fig_enhanced_cm.update_layout(height=300)
505
  st.plotly_chart(fig_enhanced_cm, use_container_width=True)
506
 
 
507
  st.markdown("**⚙️ Resource Usage**")
508
  enh_cpu_col, enh_mem_col = st.columns(2)
509
  with enh_cpu_col:
src/hatespeech_model.py CHANGED
@@ -1,6 +1,5 @@
1
  from huggingface_hub import hf_hub_download
2
  import torch
3
- from torch.cuda import device
4
  from torch.nn import functional as F
5
  import torch.nn as nn
6
  import json
@@ -9,133 +8,183 @@ from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_sc
9
  from time import time
10
  import psutil
11
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Model Architecture Classes
 
 
14
  class TemporalCNN(nn.Module):
15
- """
16
- Temporal CNN applied across the sequence (time) dimension.
17
- Input: sequence_embeddings (B, L, H), attention_mask (B, L)
18
- Output: pooled vector (B, output_dim) where output_dim = num_filters * len(kernel_sizes) * 2
19
- (we concatenate max-pooled and mean-pooled features for each kernel size)
20
- """
21
- def __init__(self, input_dim=768, num_filters=256, kernel_sizes=(2, 3, 4), dropout=0.3):
22
  super().__init__()
23
  self.input_dim = input_dim
24
  self.num_filters = num_filters
25
  self.kernel_sizes = kernel_sizes
26
-
27
- # Convs expect (B, C_in, L) where C_in = input_dim
28
  self.convs = nn.ModuleList([
29
- nn.Conv1d(in_channels=input_dim, out_channels=num_filters, kernel_size=k, padding=k // 2)
30
  for k in kernel_sizes
31
  ])
32
  self.dropout = nn.Dropout(dropout)
33
 
34
  def forward(self, sequence_embeddings, attention_mask=None):
35
- """
36
- sequence_embeddings: (B, L, H)
37
- attention_mask: (B, L) with 1 for valid tokens, 0 for padding
38
- returns: (B, num_filters * len(kernel_sizes) * 2) # max + mean pooled per conv
39
- """
40
- # transpose to (B, H, L)
41
- x = sequence_embeddings.transpose(1, 2).contiguous() # (B, H, L)
42
-
43
  pooled_outputs = []
44
  for conv in self.convs:
45
- conv_out = conv(x) # (B, num_filters, L_out)
46
  conv_out = F.relu(conv_out)
47
  L_out = conv_out.size(2)
48
-
49
  if attention_mask is not None:
50
- # resize mask to match L_out
51
  mask = attention_mask.float()
52
  if mask.size(1) != L_out:
53
  mask = F.interpolate(mask.unsqueeze(1), size=L_out, mode='nearest').squeeze(1)
54
  mask = mask.unsqueeze(1).to(conv_out.device) # (B,1,L_out)
55
-
56
- # max pool with masking
57
  neg_inf = torch.finfo(conv_out.dtype).min / 2
58
- max_masked = torch.where(mask.bool(), conv_out, neg_inf * torch.ones_like(conv_out))
59
  max_pooled = torch.max(max_masked, dim=2)[0] # (B, num_filters)
60
-
61
- # mean pool with masking
62
  sum_masked = (conv_out * mask).sum(dim=2) # (B, num_filters)
63
  denom = mask.sum(dim=2).clamp_min(1e-6) # (B,1)
64
  mean_pooled = sum_masked / denom # (B, num_filters)
65
  else:
66
  max_pooled = torch.max(conv_out, dim=2)[0]
67
  mean_pooled = conv_out.mean(dim=2)
68
-
69
  pooled_outputs.append(max_pooled)
70
  pooled_outputs.append(mean_pooled)
71
-
72
- out = torch.cat(pooled_outputs, dim=1) # (B, num_filters * len(kernel_sizes) * 2)
73
  out = self.dropout(out)
74
  return out
75
 
76
 
77
  class MultiScaleAttentionCNN(nn.Module):
78
- def __init__(self, hidden_size=768, num_filters=64, kernel_sizes=(2, 3, 4, 5, 6, 7), dropout=0.3):
79
  super().__init__()
80
 
81
  self.hidden_size = hidden_size
82
  self.kernel_sizes = kernel_sizes
83
 
84
  self.convs = nn.ModuleList()
85
- self.pads = nn.ModuleList()
86
 
87
  for k in self.kernel_sizes:
88
- pad_left = (k - 1) // 2
89
  pad_right = k - 1 - pad_left
 
90
  self.pads.append(nn.ConstantPad1d((pad_left, pad_right), 0.0))
91
- self.convs.append(nn.Conv1d(hidden_size, num_filters, kernel_size=k, padding=0))
 
 
 
92
 
93
  self.attn = nn.ModuleList([nn.Linear(num_filters, 1) for _ in self.kernel_sizes])
94
  self.output_size = num_filters * len(self.kernel_sizes)
95
  self.dropout = nn.Dropout(dropout)
96
 
97
  def forward(self, hidden_states, mask):
98
- """
99
- hidden_states: (B, L, H)
100
- mask: (B, L)
101
- """
102
- x = hidden_states.transpose(1, 2) # (B, H, L)
103
  attn_mask = mask.unsqueeze(1).float()
104
 
105
  conv_outs = []
106
 
107
  for pad, conv, att in zip(self.pads, self.convs, self.attn):
108
- padded = pad(x) # (B, H, L)
109
- c = conv(padded) # (B, F, L)
110
  c = F.relu(c)
111
  c = c * attn_mask
112
 
113
- c_t = c.transpose(1, 2) # (B, L, F)
114
- w = att(c_t) # (B, L, 1)
115
  w = w.masked_fill(mask.unsqueeze(-1) == 0, -1e9)
116
  w = F.softmax(w, dim=1)
117
 
118
- pooled = (c_t * w).sum(dim=1) # (B, F)
119
  conv_outs.append(pooled)
120
 
121
- out = torch.cat(conv_outs, dim=1) # (B, F * K)
122
  return self.dropout(out)
123
-
124
-
125
- class ProjectionMLP(nn.Module):
126
- def __init__(self, input_size, hidden_size=256, num_labels=2):
127
- super().__init__()
128
- self.layers = nn.Sequential(
129
- nn.Linear(input_size, hidden_size),
130
- nn.ReLU(),
131
- nn.Linear(hidden_size, num_labels)
132
- )
133
-
134
- def forward(self, x):
135
- return self.layers(x)
136
-
137
-
138
  class ConcatModelWithRationale(nn.Module):
 
139
  def __init__(self,
140
  hatebert_model,
141
  additional_model,
@@ -143,86 +192,153 @@ class ConcatModelWithRationale(nn.Module):
143
  hidden_size=768,
144
  gumbel_temp=0.5,
145
  freeze_additional_model=True,
146
- cnn_num_filters=64,
147
- cnn_kernel_sizes=(2, 3, 4, 5, 6, 7),
148
- cnn_dropout=0.0):
 
149
  super().__init__()
 
150
  self.hatebert_model = hatebert_model
151
  self.additional_model = additional_model
152
  self.projection_mlp = projection_mlp
153
  self.gumbel_temp = gumbel_temp
154
  self.hidden_size = hidden_size
155
 
 
 
 
 
 
 
156
  if freeze_additional_model:
157
  for param in self.additional_model.parameters():
158
  param.requires_grad = False
159
 
160
- # selector head (per-token logits)
161
  self.selector = nn.Linear(hidden_size, 1)
162
 
163
- # Temporal CNN over HateBERT embeddings (main text)
164
- self.temporal_cnn = TemporalCNN(input_dim=hidden_size,
165
- num_filters=cnn_num_filters,
166
- kernel_sizes=cnn_kernel_sizes,
167
- dropout=cnn_dropout)
 
 
168
  self.temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2
169
 
170
- # MultiScaleAttentionCNN over rationale embeddings (frozen BERT)
171
- self.msa_cnn = MultiScaleAttentionCNN(hidden_size=hidden_size,
172
- num_filters=cnn_num_filters,
173
- kernel_sizes=cnn_kernel_sizes,
174
- dropout=cnn_dropout)
 
 
175
  self.msa_out_dim = self.msa_cnn.output_size
176
 
 
177
  def gumbel_sigmoid_sample(self, logits):
178
  noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-9) + 1e-9)
179
  y = logits + noise
180
  return torch.sigmoid(y / self.gumbel_temp)
181
 
182
- def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask, return_attentions=False):
183
- # Main text through HateBERT
184
- hatebert_out = self.hatebert_model(input_ids=input_ids,
185
- attention_mask=attention_mask,
186
- output_attentions=return_attentions,
187
- return_dict=True)
188
- hatebert_emb = hatebert_out.last_hidden_state # (B, L, H)
189
- cls_emb = hatebert_emb[:, 0, :] # (B, H)
190
 
191
- # Rationale text through frozen BERT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  with torch.no_grad():
193
- add_out = self.additional_model(input_ids=additional_input_ids,
194
- attention_mask=additional_attention_mask,
195
- return_dict=True)
196
- rationale_emb = add_out.last_hidden_state # (B, L, H)
197
 
198
- # selector logits & Gumbel-Sigmoid sampling on HateBERT
199
- selector_logits = self.selector(hatebert_emb).squeeze(-1) # (B, L)
200
- rationale_probs = self.gumbel_sigmoid_sample(selector_logits) # (B, L)
201
- rationale_probs = rationale_probs * attention_mask.float().to(rationale_probs.device)
 
202
 
203
- # pooled rationale summary
204
- masked_hidden = hatebert_emb * rationale_probs.unsqueeze(-1)
205
- denom = rationale_probs.sum(1).unsqueeze(-1).clamp_min(1e-6)
206
- pooled_rationale = masked_hidden.sum(1) / denom # (B, H)
207
 
208
- # CNN branches
209
- temporal_features = self.temporal_cnn(hatebert_emb, attention_mask) # (B, temporal_out_dim)
210
- rationale_features = self.msa_cnn(rationale_emb, additional_attention_mask) # (B, msa_out_dim)
 
 
 
211
 
212
- # concat CLS + CNN features + pooled rationale
213
- concat_emb = torch.cat((cls_emb, temporal_features, rationale_features, pooled_rationale), dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
  logits = self.projection_mlp(concat_emb)
216
 
217
- attns = hatebert_out.attentions if (return_attentions and hasattr(hatebert_out, "attentions")) else None
 
 
 
218
  return logits, rationale_probs, selector_logits, attns
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
220
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  class BaseShield(nn.Module):
222
- """
223
- Simple base model that concatenates HateBERT and rationale BERT CLS embeddings
224
- and projects to label logits via a small MLP.
225
- """
226
  def __init__(self, hatebert_model, additional_model, projection_mlp, device='cpu', freeze_additional_model=True):
227
  super().__init__()
228
  self.hatebert_model = hatebert_model
@@ -235,249 +351,189 @@ class BaseShield(nn.Module):
235
  param.requires_grad = False
236
 
237
  def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
238
- hatebert_outputs = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
239
  hatebert_embeddings = hatebert_outputs.last_hidden_state[:, 0, :]
240
 
241
- additional_outputs = self.additional_model(input_ids=additional_input_ids, attention_mask=additional_attention_mask, return_dict=True)
242
  additional_embeddings = additional_outputs.last_hidden_state[:, 0, :]
243
 
244
  concatenated_embeddings = torch.cat((hatebert_embeddings, additional_embeddings), dim=1)
245
  logits = self.projection_mlp(concatenated_embeddings)
246
  return logits
247
-
248
- def load_model_from_hf(model_type="altered"):
249
- """
250
- Load model from Hugging Face Hub
251
 
252
- Args:
253
- model_type: Either "altered" or "base" to choose which model to load
254
- """
255
 
256
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
257
  repo_id = "seffyehl/BetterShield"
258
-
259
- # Choose model and config files based on model_type
260
  if model_type.lower() == "altered":
261
  model_filename = "AlteredShield.pth"
262
- config_filename = "alter_config.json"
263
  elif model_type.lower() == "base":
264
- model_filename = "BaselineShield.pth"
265
- config_filename = "base_config.json"
266
  else:
267
- raise ValueError(f"model_type must be 'altered' or 'base', got '{model_type}'")
268
-
269
- # Download files
270
- model_path = hf_hub_download(
271
- repo_id=repo_id,
272
- filename=model_filename
273
- )
274
-
275
- config_path = hf_hub_download(
276
- repo_id=repo_id,
277
- filename=config_filename,
278
- )
279
-
280
- # Load config
281
- with open(config_path, 'r') as f:
282
- config = json.load(f)
283
-
284
- # Load checkpoint
285
- checkpoint = torch.load(model_path, map_location='cpu')
286
-
287
- # Handle nested config structure (base model uses model_config, altered uses flat structure)
288
- if 'model_config' in config:
289
- model_config = config['model_config']
290
- training_config = config.get('training_config', {})
291
  else:
292
- model_config = config
293
- training_config = config
294
-
295
- # Initialize base models
296
- hatebert_model = AutoModel.from_pretrained(model_config['hatebert_model'])
297
- rationale_model = AutoModel.from_pretrained(model_config['rationale_model'])
298
-
299
- tokenizer_hatebert = AutoTokenizer.from_pretrained(model_config['hatebert_model'])
300
- tokenizer_rationale = AutoTokenizer.from_pretrained(model_config['rationale_model'])
301
-
302
- # Rebuild architecture based on model type using training_config values when available
303
  H = hatebert_model.config.hidden_size
304
- max_length = training_config.get('max_length', 128)
305
-
306
- if model_type.lower() == "base":
307
- proj_input_dim = H * 2
308
- projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim, num_labels=num_labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  model = BaseShield(
310
  hatebert_model=hatebert_model,
311
  additional_model=rationale_model,
312
  projection_mlp=projection_mlp,
313
- freeze_additional_model=freeze_rationale,
314
  device=device
315
- ).to(device)
 
316
  else:
317
- # For altered model, let ConcatModelWithRationale initialize its own CNN modules
318
- # The CNN modules are created inside __init__, so we just need to create the model
319
- # and then load the state dict
320
-
321
- # First, create a dummy projection_mlp - we'll replace it after calculating dims
322
- # Actually, we need to calculate dims first to create the correct projection_mlp
323
-
324
- # Calculate dimensions based on inferred parameters
325
- temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2
326
- msa_out_dim = cnn_num_filters * len(cnn_kernel_sizes)
327
- proj_input_dim = H + temporal_out_dim + msa_out_dim + H
328
-
329
- projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim, num_labels=num_labels)
330
-
331
  model = ConcatModelWithRationale(
332
  hatebert_model=hatebert_model,
333
  additional_model=rationale_model,
334
  projection_mlp=projection_mlp,
335
  hidden_size=H,
336
- freeze_additional_model=freeze_rationale,
337
  cnn_num_filters=cnn_num_filters,
338
  cnn_kernel_sizes=cnn_kernel_sizes,
339
  cnn_dropout=cnn_dropout
340
- ).to(device)
341
-
342
- # Load state dict with strict checking and error reporting
343
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
344
- state_dict_to_load = checkpoint['model_state_dict']
345
- else:
346
- state_dict_to_load = checkpoint
347
-
348
- # Check for missing and unexpected keys
349
- model_keys = set(model.state_dict().keys())
350
- checkpoint_keys = set(state_dict_to_load.keys())
351
-
352
- missing_keys = model_keys - checkpoint_keys
353
- unexpected_keys = checkpoint_keys - model_keys
354
-
355
- if missing_keys:
356
- print(f"WARNING: Missing keys in checkpoint: {missing_keys}")
357
- if unexpected_keys:
358
- print(f"WARNING: Unexpected keys in checkpoint: {unexpected_keys}")
359
-
360
- # Load with strict=False to handle any minor mismatches, but log warnings
361
- incompatible_keys = model.load_state_dict(state_dict_to_load, strict=True)
362
-
363
- if incompatible_keys.missing_keys:
364
- print(f"Missing keys after load: {incompatible_keys.missing_keys}")
365
- if incompatible_keys.unexpected_keys:
366
- print(f"Unexpected keys after load: {incompatible_keys.unexpected_keys}")
367
-
368
- if isinstance(checkpoint, dict) and 'epoch' in checkpoint:
369
- print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
370
- print(f"Dataset: {checkpoint.get('dataset', 'unknown')}, Seed: {checkpoint.get('seed', 'unknown')}")
371
-
372
- # CRITICAL: Set to eval mode and ensure no gradient computation
373
- model.eval()
374
-
375
- # Disable dropout explicitly by setting training mode to False for all modules
376
- for module in model.modules():
377
- if isinstance(module, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)):
378
- module.p = 0 # Set dropout probability to 0
379
-
380
- model = model.to(device)
381
-
382
- # Verify model is in eval mode
383
- print(f"Model training mode: {model.training}")
384
- print(f"Dropout layers found: {sum(1 for _ in model.modules() if isinstance(_, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)))}")
385
-
386
- # Create a unified config dict with max_length at top level for compatibility
387
- unified_config = config.copy()
388
- if 'max_length' not in unified_config and 'training_config' in config:
389
- unified_config['max_length'] = training_config.get('max_length', 128)
390
-
391
- return model, tokenizer_hatebert, tokenizer_rationale, unified_config, device
392
-
393
 
394
- def combined_loss(logits, labels, rationale_probs, selector_logits, rationale_mask=None, attns=None, attn_weight=0.0, rationale_weight=1.0):
395
- cls_loss = F.cross_entropy(logits, labels)
396
 
397
- # supervise selector logits with BCE-with-logits against rationale mask (if available)
398
- if rationale_mask is not None:
399
- selector_loss = F.binary_cross_entropy_with_logits(selector_logits, rationale_mask.to(selector_logits.device))
400
- else:
401
- selector_loss = torch.tensor(0.0, device=cls_loss.device)
402
 
403
- # optional attention alignment loss (disabled by default)
404
- attn_loss = torch.tensor(0.0, device=cls_loss.device)
405
- if attns is not None and attn_weight > 0.0:
406
- try:
407
- last_attn = attns[-1] # (B, H, L, L)
408
- attn_mass = last_attn.mean(1).mean(1) # (B, L)
409
- attn_loss = F.mse_loss(attn_mass, rationale_mask.to(attn_mass.device))
410
- except Exception:
411
- attn_loss = torch.tensor(0.0, device=cls_loss.device)
412
 
413
- total_loss = cls_loss + rationale_weight * selector_loss + attn_weight * attn_loss
414
- return total_loss, cls_loss.item(), selector_loss.item(), attn_loss.item()
415
 
 
 
 
 
 
 
 
 
 
 
416
 
417
- def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale,
418
- device='cpu', max_length=128, model_type="altered"):
419
- # Ensure model is in eval mode (defensive programming)
420
  model.eval()
421
-
422
- # Tokenize inputs
423
- inputs_main = tokenizer_hatebert(
424
  text,
425
  max_length=max_length,
426
- padding='max_length',
427
  truncation=True,
428
- return_tensors='pt'
429
  )
430
-
431
- inputs_rationale = tokenizer_rationale(
432
- rationale if rationale else text, # Use text if no rationale provided
433
  max_length=max_length,
434
- padding='max_length',
435
  truncation=True,
436
- return_tensors='pt'
437
  )
438
-
439
- # Move to device
440
- input_ids = inputs_main['input_ids'].to(device)
441
- attention_mask = inputs_main['attention_mask'].to(device)
442
- add_input_ids = inputs_rationale['input_ids'].to(device)
443
- add_attention_mask = inputs_rationale['attention_mask'].to(device)
444
-
445
- # Inference with no gradient computation
446
  with torch.no_grad():
 
447
  if model_type.lower() == "base":
448
  logits = model(
449
- input_ids,
450
- attention_mask,
451
- add_input_ids,
452
  add_attention_mask
453
  )
 
454
  else:
455
- logits, rationale_probs, selector_logits, _ = model(
456
- input_ids,
457
- attention_mask,
458
- add_input_ids,
459
  add_attention_mask
460
  )
461
-
462
- # Get probabilities
463
- probs = torch.softmax(logits, dim=1)
 
 
 
 
 
 
 
 
 
 
 
464
  prediction = logits.argmax(dim=1).item()
465
  confidence = probs[0, prediction].item()
466
-
467
  return {
468
- 'prediction': prediction,
469
- 'confidence': confidence,
470
- 'probabilities': probs[0].cpu().numpy(),
471
- 'tokens': tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])
 
472
  }
473
-
474
- if model_type.lower() != "base":
475
- result['rationale_scores'] = rationale_probs[0].cpu().numpy() if 'rationale_probs' in locals() else None
476
- else:
477
- result['rationale_scores'] = None
478
-
479
- return result
480
-
481
 
482
  def predict_hatespeech_from_file(
483
  text_list,
@@ -500,11 +556,10 @@ def predict_hatespeech_from_file(
500
 
501
  process = psutil.Process(os.getpid())
502
 
503
- # 🔥 GPU synchronization BEFORE timing
504
  if torch.cuda.is_available():
505
  torch.cuda.synchronize()
506
 
507
- # 🔥 Optional warmup (prevents first-batch timing bias)
508
  with torch.no_grad():
509
  _ = predict_text(
510
  text=text_list[0],
@@ -520,10 +575,10 @@ def predict_hatespeech_from_file(
520
  if torch.cuda.is_available():
521
  torch.cuda.synchronize()
522
 
523
- # ⏱ Start timer AFTER warmup
524
  start_time = time()
525
 
526
  for idx, (text, rationale) in enumerate(zip(text_list, rationale_list)):
 
527
  result = predict_text(
528
  text=text,
529
  rationale=rationale,
@@ -538,28 +593,20 @@ def predict_hatespeech_from_file(
538
  predictions.append(result['prediction'])
539
  all_probs.append(result['probabilities'])
540
 
541
- # Reduce monitoring overhead
542
  if idx % 10 == 0 or idx == len(text_list) - 1:
543
  cpu_percent_list.append(process.cpu_percent())
544
  memory_percent_list.append(process.memory_info().rss / 1024 / 1024)
545
 
546
- # 🔥 GPU synchronization BEFORE stopping timer
547
  if torch.cuda.is_available():
548
  torch.cuda.synchronize()
549
 
550
- end_time = time()
551
- runtime = end_time - start_time
552
 
553
  print(f"Inference completed for {type(model).__name__}")
554
  print(f"Total runtime: {runtime:.4f} seconds")
555
 
556
- # ---------------- Metrics ----------------
557
  all_probs = np.array(all_probs)
558
 
559
- print(f"Probability Mean: {all_probs.mean(axis=0)}")
560
- print(f"Probability Std: {all_probs.std(axis=0)}")
561
- print(f"Prediction distribution: {np.bincount(predictions, minlength=2)}")
562
-
563
  f1 = f1_score(true_label, predictions, zero_division=0)
564
  accuracy = accuracy_score(true_label, predictions)
565
  precision = precision_score(true_label, predictions, zero_division=0)
@@ -572,7 +619,7 @@ def predict_hatespeech_from_file(
572
  peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0
573
 
574
  return {
575
- 'model_name': type(model).__name__, # 👈 makes logs clearer
576
  'f1_score': f1,
577
  'accuracy': accuracy,
578
  'precision': precision,
@@ -585,14 +632,10 @@ def predict_hatespeech_from_file(
585
  'runtime': runtime,
586
  'all_probabilities': all_probs.tolist()
587
  }
588
-
589
-
590
  def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
591
- """
592
- Predict hate speech for given text
593
- """
594
- # Get prediction
595
- result = predict_text(
596
  text=text,
597
  rationale=rationale,
598
  model=model,
@@ -601,6 +644,4 @@ def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rat
601
  device=device,
602
  max_length=config.get('max_length', 128),
603
  model_type=model_type
604
- )
605
-
606
- return result
 
1
  from huggingface_hub import hf_hub_download
2
  import torch
 
3
  from torch.nn import functional as F
4
  import torch.nn as nn
5
  import json
 
8
  from time import time
9
  import psutil
10
  import os
11
+ import numpy as np
12
+ import requests
13
+ import json
14
+ from dotenv import load_dotenv
15
+
16
# Load Cloudflare Workers AI credentials from a local .env file, if present.
load_dotenv()

# Base URL of the Cloudflare Workers AI REST endpoint; run_mistral_model
# appends the model name directly to this string.
API_BASE_URL = os.getenv("CLOUDFLARE_API_BASE_URL")
# Bearer-token auth header sent with every inference request.
HEADERS = {"Authorization": f"Bearer {os.getenv('CLOUDFLARE_API_TOKEN')}"}
# Model identifier appended to API_BASE_URL (presumably a Mistral model — see
# get_rationale_from_mistral; confirm against the deployment's env vars).
MODEL_NAME = os.getenv("CLOUDFLARE_MODEL_NAME")
21
+
22
def create_prompt(text):
    """Build the content-moderation prompt sent to the hosted LLM.

    The model is instructed to answer in JSON with [rationales],
    [derogatory language], [list of cuss words] and [hate_classification],
    or exactly "[non-hateful]" when nothing hateful is found.
    """
    return f"""
    You are a content moderation assistant. Identify the list of [rationales] words or phrases from the text that make it hateful,
    list of [derogatory language], and [list of cuss words] and [hate_classification] such as "hateful" or "non-hateful".
    If there are none, respond exactly with [non-hateful] only.
    Output should be in JSON format only. Text: {text}.
    """
29
+
30
def run_mistral_model(model, inputs):
    """POST a chat-style message list to the Cloudflare AI endpoint.

    `model` is appended to API_BASE_URL; raises requests.HTTPError on any
    non-2xx status. Returns the parsed JSON response body.
    """
    body = {"messages": inputs}
    resp = requests.post(f"{API_BASE_URL}{model}", headers=HEADERS, json=body)
    resp.raise_for_status()
    return resp.json()
35
+
36
def flatten_json_string(json_string):
    """Collapse a JSON string onto one compact line.

    Re-serializes with minimal separators so downstream code can treat the
    rationale as a single-line token stream. Non-JSON (or non-string) input
    is returned unchanged.
    """
    try:
        obj = json.loads(json_string)
        return json.dumps(obj, separators=(",", ":"))
    # json.JSONDecodeError is a ValueError subclass; TypeError covers
    # non-string input. A bare except here would also hide real bugs
    # (KeyboardInterrupt, NameError, ...), so catch only what can occur.
    except (TypeError, ValueError):
        return json_string
42
+
43
def get_rationale_from_mistral(text, retries=10):
    """Query the hosted model for a hate-speech rationale, with retries.

    Retries up to `retries` times when the model returns an empty response or
    a refusal ("I cannot..."). Quota / rate-limit errors (RESOURCE_EXHAUSTED
    or HTTP 429) are re-raised so the caller can back off; any other HTTPError
    is logged and retried.

    Returns the flattened, single-line rationale string, or "non-hateful"
    when every attempt fails.
    """
    for attempt in range(retries):
        try:
            inputs = [{"role": "user", "content": create_prompt(text)}]
            output = run_mistral_model(MODEL_NAME, inputs)

            result = output.get("result", {})
            response_text = result.get("response", "").strip()

            if not response_text or response_text.startswith("I cannot"):
                print(f"⚠️ Model returned 'I cannot...' — retrying ({attempt+1}/{retries})")
                continue  # retry
            cleaned_rationale = flatten_json_string(response_text).replace("\n", " ").strip()
            return cleaned_rationale

        except requests.exceptions.HTTPError as e:
            print(f"⚠️ HTTP Error on attempt {attempt+1}: {e}")
            # e.response can be None (e.g. an HTTPError raised without a
            # response object); guard before reading status_code so we don't
            # replace the real error with an AttributeError.
            if "RESOURCE_EXHAUSTED" in str(e) or (
                e.response is not None and e.response.status_code == 429
            ):
                raise

    return "non-hateful"
64
+
65
def preprocess_rationale_mistral(raw_rationale):
    """Normalize a raw model rationale into a canonical lowercase JSON string.

    Strips markdown code fences, extracts the first {...} JSON object, and
    keeps only the three rationale keys. Returns "non-hateful" when all three
    lists are present and empty; falls back to the lowercased input when no
    JSON object can be parsed.
    """
    try:
        x = str(raw_rationale).strip()

        # Drop markdown code fences the model sometimes wraps JSON in.
        if x.startswith("```"):
            x = x.replace("```json", "").replace("```", "").strip()

        # Repair doubled quotes from CSV-style escaping.
        x = x.replace('""', '"')

        # Extract the JSON object span.
        start = x.find("{")
        end = x.rfind("}") + 1
        # rfind returns -1 when "}" is absent, so `end` is 0 here — the old
        # `end == -1` check could never fire and unbalanced input fell through
        # to json.loads on an empty slice.
        if start == -1 or end == 0:
            return x.lower()

        j = json.loads(x[start:end])

        keys = ["rationales", "derogatory_language", "cuss_words"]

        # All three lists present but empty => nothing hateful was found.
        if all(k in j and isinstance(j[k], list) and len(j[k]) == 0 for k in keys):
            return "non-hateful"

        cleaned = {k: j.get(k, []) for k in keys}
        return json.dumps(cleaned).lower()

    except Exception:
        # Last-resort fallback: keep the raw text so the caller still gets
        # something usable rather than crashing the pipeline.
        return str(raw_rationale).lower()
92
+
93
class TemporalCNN(nn.Module):
    """Multi-kernel 1-D CNN over a token-embedding sequence.

    Applies one Conv1d per kernel size and, for each, collects both a masked
    max-pool and a masked mean-pool over the sequence axis, so the output
    width is ``num_filters * len(kernel_sizes) * 2``.
    """

    def __init__(self, input_dim=768, num_filters=32, kernel_sizes=(3,4,5), dropout=0.3):
        super().__init__()
        self.input_dim = input_dim
        self.num_filters = num_filters
        self.kernel_sizes = kernel_sizes
        # padding=k//2 keeps output length close to input length (exact for
        # odd k; one extra position for even k).
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=input_dim, out_channels=num_filters, kernel_size=k, padding=k//2)
            for k in kernel_sizes
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, sequence_embeddings, attention_mask=None):
        # Conv1d wants (B, C, L); assumes sequence_embeddings is
        # (batch, seq_len, input_dim) — TODO confirm with callers.
        x = sequence_embeddings.transpose(1, 2).contiguous()

        pooled_outputs = []
        for conv in self.convs:
            conv_out = conv(x)
            conv_out = F.relu(conv_out)
            L_out = conv_out.size(2)

            if attention_mask is not None:
                mask = attention_mask.float()
                # Even kernel sizes can change L; resize the mask to match.
                if mask.size(1) != L_out:
                    mask = F.interpolate(mask.unsqueeze(1), size=L_out, mode='nearest').squeeze(1)
                mask = mask.unsqueeze(1).to(conv_out.device)  # (B,1,L_out)

                # Masked max: padded positions get a large negative value so
                # they never win the max (min/2 avoids overflow in fp16).
                neg_inf = torch.finfo(conv_out.dtype).min / 2
                max_masked = torch.where(mask.bool(), conv_out, neg_inf*torch.ones_like(conv_out))
                max_pooled = torch.max(max_masked, dim=2)[0]  # (B, num_filters)

                # Masked mean: sum over valid positions / count of valid positions.
                sum_masked = (conv_out * mask).sum(dim=2)  # (B, num_filters)
                denom = mask.sum(dim=2).clamp_min(1e-6)  # (B,1)
                mean_pooled = sum_masked / denom  # (B, num_filters)
            else:
                max_pooled = torch.max(conv_out, dim=2)[0]
                mean_pooled = conv_out.mean(dim=2)

            pooled_outputs.append(max_pooled)
            pooled_outputs.append(mean_pooled)

        # (B, num_filters * len(kernel_sizes) * 2)
        out = torch.cat(pooled_outputs, dim=1)
        out = self.dropout(out)
        return out
137
 
138
 
139
class MultiScaleAttentionCNN(nn.Module):
    """Multi-kernel 1-D CNN with per-scale additive attention pooling.

    For each kernel size: explicit asymmetric padding keeps the output length
    equal to the input length, then a per-position attention weight (softmax
    over valid positions) produces a weighted sum instead of max pooling.
    Output width is ``num_filters * len(kernel_sizes)`` (``self.output_size``).
    """

    def __init__(self, hidden_size=768, num_filters=32, kernel_sizes=(3,4,5), dropout=0.3):
        super().__init__()

        self.hidden_size = hidden_size
        self.kernel_sizes = kernel_sizes

        self.convs = nn.ModuleList()
        self.pads = nn.ModuleList()

        for k in self.kernel_sizes:
            # "Same"-length padding, split asymmetrically for even kernels.
            pad_left = (k - 1) // 2
            pad_right = k - 1 - pad_left

            self.pads.append(nn.ConstantPad1d((pad_left, pad_right), 0.0))

            self.convs.append(
                nn.Conv1d(hidden_size, num_filters, kernel_size=k, padding=0)
            )

        # One scalar attention scorer per scale.
        self.attn = nn.ModuleList([nn.Linear(num_filters, 1) for _ in self.kernel_sizes])
        self.output_size = num_filters * len(self.kernel_sizes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states, mask):
        # (B, L, H) -> (B, H, L) for Conv1d; mask is (B, L) with 1 = valid.
        x = hidden_states.transpose(1, 2)
        attn_mask = mask.unsqueeze(1).float()

        conv_outs = []

        for pad, conv, att in zip(self.pads, self.convs, self.attn):
            padded = pad(x)
            c = conv(padded)
            c = F.relu(c)
            # Zero out padded positions before scoring.
            c = c * attn_mask

            c_t = c.transpose(1, 2)  # (B, L, num_filters)
            w = att(c_t)  # (B, L, 1) attention logits
            # Exclude padded positions from the softmax.
            w = w.masked_fill(mask.unsqueeze(-1) == 0, -1e9)
            w = F.softmax(w, dim=1)

            # Attention-weighted sum over the sequence axis.
            pooled = (c_t * w).sum(dim=1)  # (B, num_filters)
            conv_outs.append(pooled)

        out = torch.cat(conv_outs, dim=1)  # (B, output_size)
        return self.dropout(out)
185
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
class ConcatModelWithRationale(nn.Module):
    """Enhanced "Shield" classifier: hateBERT features + rationale encoder.

    Four feature views are concatenated before the projection MLP:
      * the hateBERT [CLS] embedding,
      * TemporalCNN features over the hateBERT token sequence,
      * MultiScaleAttentionCNN features over the rationale encoder sequence,
      * a soft-rationale-weighted mean of hateBERT token embeddings, where
        per-token weights come from a learned selector (Gumbel-sigmoid noise
        during training, plain sigmoid at eval).
    """

    def __init__(self,
                 hatebert_model,
                 additional_model,
                 projection_mlp,
                 hidden_size=768,
                 gumbel_temp=0.5,
                 freeze_additional_model=True,
                 cnn_num_filters=128,
                 cnn_kernel_sizes=(3,4,5),
                 cnn_dropout=0.3):

        super().__init__()

        self.hatebert_model = hatebert_model
        self.additional_model = additional_model
        self.projection_mlp = projection_mlp
        self.gumbel_temp = gumbel_temp
        self.hidden_size = hidden_size

        # Freeze hateBERT embeddings and its lower 8 encoder layers; only the
        # upper layers remain trainable.
        for param in self.hatebert_model.embeddings.parameters():
            param.requires_grad = False

        for layer in self.hatebert_model.encoder.layer[:8]:
            for param in layer.parameters():
                param.requires_grad = False
        if freeze_additional_model:
            for param in self.additional_model.parameters():
                param.requires_grad = False

        # Per-token rationale selector: one logit per hateBERT token position.
        self.selector = nn.Linear(hidden_size, 1)

        self.temporal_cnn = TemporalCNN(
            input_dim=hidden_size,
            num_filters=cnn_num_filters,
            kernel_sizes=cnn_kernel_sizes,
            dropout=cnn_dropout
        )

        # TemporalCNN emits max+mean pooling per kernel size -> 2x filters.
        self.temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2

        self.msa_cnn = MultiScaleAttentionCNN(
            hidden_size=hidden_size,
            num_filters=cnn_num_filters,
            kernel_sizes=cnn_kernel_sizes,
            dropout=cnn_dropout
        )

        self.msa_out_dim = self.msa_cnn.output_size

    def gumbel_sigmoid_sample(self, logits):
        # Gumbel-style noise keeps token selection stochastic yet
        # differentiable during training; gumbel_temp controls sharpness.
        noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-9) + 1e-9)
        y = logits + noise
        return torch.sigmoid(y / self.gumbel_temp)

    def forward(self,
                input_ids,
                attention_mask,
                additional_input_ids,
                additional_attention_mask,
                return_attentions=False):
        """Return (logits, rationale_probs, selector_logits, attns).

        attns is the hateBERT attention tuple when return_attentions=True,
        otherwise None.
        """
        hatebert_out = self.hatebert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=return_attentions,
            return_dict=True
        )

        hatebert_emb = hatebert_out.last_hidden_state
        cls_emb = hatebert_emb[:, 0, :]

        # The rationale encoder is used as a fixed feature extractor; no
        # gradients flow through it regardless of freeze_additional_model.
        with torch.no_grad():
            add_out = self.additional_model(
                input_ids=additional_input_ids,
                attention_mask=additional_attention_mask,
                return_dict=True
            )

        rationale_emb = add_out.last_hidden_state

        selector_logits = self.selector(hatebert_emb).squeeze(-1)

        # Stochastic soft selection in training; deterministic sigmoid at eval.
        if self.training:
            rationale_probs = self.gumbel_sigmoid_sample(selector_logits)
        else:
            rationale_probs = torch.sigmoid(selector_logits)

        # Zero out padding positions.
        rationale_probs = rationale_probs * attention_mask.float()

        # Selector-weighted mean of hateBERT token embeddings.
        masked_hidden = hatebert_emb * rationale_probs.unsqueeze(-1)
        denom = rationale_probs.sum(dim=1).unsqueeze(-1).clamp_min(1e-6)
        pooled_rationale = masked_hidden.sum(dim=1) / denom

        temporal_features = self.temporal_cnn(
            hatebert_emb,
            attention_mask
        )

        rationale_features = self.msa_cnn(
            rationale_emb,
            additional_attention_mask
        )
        concat_emb = torch.cat(
            (cls_emb,
             temporal_features,
             rationale_features,
             pooled_rationale),
            dim=1
        )

        logits = self.projection_mlp(concat_emb)

        attns = None
        if return_attentions and hasattr(hatebert_out, "attentions"):
            attns = hatebert_out.attentions

        return logits, rationale_probs, selector_logits, attns
307
+
308
class ProjectionMLP(nn.Module):
    """Classification head for the enhanced model.

    Maps the concatenated feature vector through input -> 512 -> hidden_size
    -> num_labels with LayerNorm, ReLU and dropout in between. Layer indices
    (0, 4, 7 for the Linear layers) are relied on by checkpoint loading.
    """

    def __init__(self, input_size, hidden_size=128, num_labels=2):
        super().__init__()

        # Built as a flat list so the Sequential indices (and therefore the
        # state-dict keys layers.0 / layers.4 / layers.7) match the trained
        # checkpoints exactly.
        stack = [
            nn.Linear(input_size, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_size, num_labels),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return raw (unnormalized) logits of shape (..., num_labels)."""
        return self.layers(x)
327
 
328
 
329
class ProjectionMLPBase(nn.Module):
    """Single-hidden-layer head used by BaseShield; always emits 2 logits."""

    def __init__(self, input_size, output_size):
        super().__init__()
        # Linear layers are constructed in the same order as the original so
        # parameter initialization and state-dict keys are unchanged.
        hidden = nn.Linear(input_size, output_size)
        classifier = nn.Linear(output_size, 2)
        self.layers = nn.Sequential(hidden, nn.ReLU(), classifier)

    def forward(self, x):
        """Return raw binary-classification logits of shape (..., 2)."""
        return self.layers(x)
340
+
341
  class BaseShield(nn.Module):
 
 
 
 
342
  def __init__(self, hatebert_model, additional_model, projection_mlp, device='cpu', freeze_additional_model=True):
343
  super().__init__()
344
  self.hatebert_model = hatebert_model
 
351
  param.requires_grad = False
352
 
353
    def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
        """Concatenate both encoders' [CLS] embeddings and classify.

        Each encoder runs on its own ids/mask; only position 0 ([CLS]) of each
        last_hidden_state is kept. Returns raw logits from the projection MLP.
        """
        hatebert_outputs = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask)
        hatebert_embeddings = hatebert_outputs.last_hidden_state[:, 0, :]  # (B, H)

        additional_outputs = self.additional_model(input_ids=additional_input_ids, attention_mask=additional_attention_mask)
        additional_embeddings = additional_outputs.last_hidden_state[:, 0, :]  # (B, H)

        # (B, 2H) -> logits
        concatenated_embeddings = torch.cat((hatebert_embeddings, additional_embeddings), dim=1)
        logits = self.projection_mlp(concatenated_embeddings)
        return logits
 
 
 
 
363
 
364
+
 
 
365
 
366
def load_model_from_hf(model_type="altered"):
    """Download a BetterShield checkpoint from the HF Hub and build its model.

    Args:
        model_type: "altered" (enhanced ConcatModelWithRationale) or "base"
            (BaseShield). Case-insensitive.

    Returns:
        (model, tokenizer_hatebert, tokenizer_rationale, config, device)
        with the model in eval mode on `device`.

    Raises:
        ValueError: for an unknown model_type.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    repo_id = "seffyehl/BetterShield"

    if model_type.lower() == "altered":
        model_filename = "AlteredShield.pth"
    elif model_type.lower() == "base":
        model_filename = "BaseShield.pth"
    else:
        raise ValueError("model_type must be 'base' or 'altered'")

    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

    # NOTE: weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted repos.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    hatebert_name = "GroNLP/hateBERT"
    rationale_name = "bert-base-uncased"

    hatebert_model = AutoModel.from_pretrained(hatebert_name)
    rationale_model = AutoModel.from_pretrained(rationale_name)

    tokenizer_hatebert = AutoTokenizer.from_pretrained(hatebert_name)
    tokenizer_rationale = AutoTokenizer.from_pretrained(rationale_name)

    H = hatebert_model.config.hidden_size

    # Infer head dimensions from the checkpoint itself. The enhanced head
    # (ProjectionMLP) has Linear layers at Sequential indices 0/4/7; the base
    # head (ProjectionMLPBase) only at 0/2 — indexing layers.4 unconditionally
    # would KeyError on a base checkpoint.
    first_layer_weight = state_dict["projection_mlp.layers.0.weight"]
    input_dim = first_layer_weight.shape[1]

    if "projection_mlp.layers.4.weight" in state_dict:
        hidden_dim = state_dict["projection_mlp.layers.4.weight"].shape[0]
        num_labels = state_dict["projection_mlp.layers.7.weight"].shape[0]
    else:
        # Base head: hidden width is layer 0's output; classifier at index 2.
        hidden_dim = first_layer_weight.shape[0]
        num_labels = state_dict["projection_mlp.layers.2.weight"].shape[0]

    # Presence of temporal-CNN weights marks an enhanced checkpoint.
    temporal_keys = [k for k in state_dict if k.startswith("temporal_cnn.convs")]
    is_altered = len(temporal_keys) > 0

    if not is_altered or model_type.lower() == "base":
        projection_mlp = ProjectionMLPBase(
            input_size=input_dim,
            output_size=hidden_dim
        )
        model = BaseShield(
            hatebert_model=hatebert_model,
            additional_model=rationale_model,
            projection_mlp=projection_mlp,
            freeze_additional_model=True,
            device=device
        )
    else:
        # Recover CNN hyperparameters from the conv weight shapes.
        conv_weights = [
            v for k, v in state_dict.items()
            if k.startswith("temporal_cnn.convs") and k.endswith("weight")
        ]
        cnn_num_filters = conv_weights[0].shape[0]
        cnn_kernel_sizes = tuple(w.shape[2] for w in conv_weights)
        cnn_dropout = 0.3
        projection_mlp = ProjectionMLP(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_labels=num_labels
        )
        model = ConcatModelWithRationale(
            hatebert_model=hatebert_model,
            additional_model=rationale_model,
            projection_mlp=projection_mlp,
            hidden_size=H,
            freeze_additional_model=True,
            cnn_num_filters=cnn_num_filters,
            cnn_kernel_sizes=cnn_kernel_sizes,
            cnn_dropout=cnn_dropout
        )

    model.load_state_dict(state_dict, strict=True)

    # Previously `device` was computed but the model was never moved to it,
    # so CUDA inputs in predict_text would hit a CPU model on GPU machines.
    model.to(device)
    model.eval()

    config = {
        "max_length": 128
    }

    return model, tokenizer_hatebert, tokenizer_rationale, config, device
 
460
 
461
def predict_text(
    text,
    rationale,
    model,
    tokenizer_hatebert,
    tokenizer_rationale,
    device="cpu",
    max_length=128,
    model_type="altered"
):
    """Run one hate-speech prediction for a single text/rationale pair.

    Args:
        text: input text for the hateBERT tokenizer.
        rationale: rationale text for the second tokenizer; falls back to
            `text` when falsy (empty/None).
        model: BaseShield (returns logits only) or the enhanced model
            (returns a 4-tuple), selected by `model_type`.
        device: device the model lives on; inputs are moved there.
        model_type: "base" or anything else for the enhanced path.

    Returns:
        dict with keys prediction (int), confidence (float), probabilities
        (np.ndarray), tokens (list[str]), rationale_scores (np.ndarray for the
        enhanced model, None for base).

    Raises:
        ValueError: when the enhanced model does not return a 4-tuple.
    """
    model.eval()

    main_inputs = tokenizer_hatebert(
        text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    rationale_inputs = tokenizer_rationale(
        rationale if rationale else text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    input_ids = main_inputs["input_ids"].to(device)
    attention_mask = main_inputs["attention_mask"].to(device)

    add_input_ids = rationale_inputs["input_ids"].to(device)
    add_attention_mask = rationale_inputs["attention_mask"].to(device)

    tokens = tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])
    with torch.no_grad():
        if model_type.lower() == "base":
            logits = model(
                input_ids,
                attention_mask,
                add_input_ids,
                add_attention_mask
            )
            rationale_scores = None
        else:
            outputs = model(
                input_ids,
                attention_mask,
                add_input_ids,
                add_attention_mask
            )
            if not (isinstance(outputs, tuple) and len(outputs) == 4):
                raise ValueError(f"Unexpected number of outputs from model: {len(outputs)}")
            logits, rationale_probs, _, _ = outputs
            # Per-token selector probabilities for the first (only) sample.
            # (The original recomputed this line twice; once is enough.)
            rationale_scores = rationale_probs[0].cpu().numpy()

        probs = F.softmax(logits, dim=1)

        # Defensive fallback: degenerate logits -> uniform distribution.
        if torch.isnan(probs).any() or torch.isinf(probs).any():
            probs = torch.ones_like(logits) / logits.size(1)

    prediction = logits.argmax(dim=1).item()
    confidence = probs[0, prediction].item()

    return {
        "prediction": prediction,
        "confidence": confidence,
        "probabilities": probs[0].cpu().numpy(),
        "tokens": tokens,
        "rationale_scores": rationale_scores
    }
 
 
 
 
 
 
 
 
537
 
538
  def predict_hatespeech_from_file(
539
  text_list,
 
556
 
557
  process = psutil.Process(os.getpid())
558
 
 
559
  if torch.cuda.is_available():
560
  torch.cuda.synchronize()
561
 
562
+ # warmup
563
  with torch.no_grad():
564
  _ = predict_text(
565
  text=text_list[0],
 
575
  if torch.cuda.is_available():
576
  torch.cuda.synchronize()
577
 
 
578
  start_time = time()
579
 
580
  for idx, (text, rationale) in enumerate(zip(text_list, rationale_list)):
581
+
582
  result = predict_text(
583
  text=text,
584
  rationale=rationale,
 
593
  predictions.append(result['prediction'])
594
  all_probs.append(result['probabilities'])
595
 
 
596
  if idx % 10 == 0 or idx == len(text_list) - 1:
597
  cpu_percent_list.append(process.cpu_percent())
598
  memory_percent_list.append(process.memory_info().rss / 1024 / 1024)
599
 
 
600
  if torch.cuda.is_available():
601
  torch.cuda.synchronize()
602
 
603
+ runtime = time() - start_time
 
604
 
605
  print(f"Inference completed for {type(model).__name__}")
606
  print(f"Total runtime: {runtime:.4f} seconds")
607
 
 
608
  all_probs = np.array(all_probs)
609
 
 
 
 
 
610
  f1 = f1_score(true_label, predictions, zero_division=0)
611
  accuracy = accuracy_score(true_label, predictions)
612
  precision = precision_score(true_label, predictions, zero_division=0)
 
619
  peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0
620
 
621
  return {
622
+ 'model_name': type(model).__name__,
623
  'f1_score': f1,
624
  'accuracy': accuracy,
625
  'precision': precision,
 
632
  'runtime': runtime,
633
  'all_probabilities': all_probs.tolist()
634
  }
635
+
 
636
def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
    """Thin convenience wrapper around predict_text.

    Pulls max_length from `config` (default 128) and forwards everything else
    unchanged; returns predict_text's result dict.
    """
    return predict_text(
        text=text,
        rationale=rationale,
        model=model,
        tokenizer_hatebert=tokenizer_hatebert,
        tokenizer_rationale=tokenizer_rationale,
        device=device,
        max_length=config.get('max_length', 128),
        model_type=model_type,
    )