jl committed on
Commit
b1587d0
·
1 Parent(s): 69200a8

update: add new models and fit them into the app

Browse files
notebooks/Altered_SHIELD_Model.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
notebooks/Reddit_Base_SHIELD_Model.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
notebooks/combined-baseline (1).ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/full-proposed-model.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/app.py CHANGED
@@ -66,8 +66,8 @@ st.markdown('<div class="sub-header">Comparing Base vs Enhanced models with expl
66
  # Load both models with spinner
67
  with st.spinner('🔄 Loading models... This may take a moment on first run.'):
68
  try:
69
- # base_model, base_tokenizer_hatebert, base_tokenizer_rationale, base_config, base_device = load_cached_model("base")
70
- # enhanced_model, enhanced_tokenizer_hatebert, enhanced_tokenizer_rationale, enhanced_config, enhanced_device = load_cached_model("altered")
71
  st.success('✅ Base Shield and Enhanced Shield models loaded successfully!')
72
  except Exception as e:
73
  st.error(f"❌ Error loading models: {str(e)}")
@@ -155,23 +155,32 @@ classify_button = st.button("🔍 Analyze Text", type="primary", use_container_w
155
  if classify_button:
156
  if user_input and user_input.strip():
157
  with st.spinner('🔄 Analyzing text...'):
158
- # Get prediction
159
- # result = predict_hatespeech(
160
- # text=user_input,
161
- # rationale=optional_rationale if optional_rationale else None,
162
- # model=model,
163
- # tokenizer_hatebert=tokenizer_hatebert,
164
- # tokenizer_rationale=tokenizer_rationale,
165
- # config=config,
166
- # device=device
167
- # )
168
  # Run both models
169
- base_start = time.time()
170
- base_model_result = predict_text_mock(user_input)
171
- base_end = time.time()
172
  enhanced_start = time.time()
173
- enhanced_model_result = predict_text_mock(user_input)
 
 
 
 
 
 
 
 
 
174
  enhanced_end = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  # Extract results for both models
177
  base_prediction = base_model_result['prediction']
@@ -359,9 +368,28 @@ if classify_button:
359
  # Run both models on the file
360
  # base_result = predict_hatespeech_from_file(...) # Base model
361
  # enhanced_result = predict_hatespeech_from_file(...) # Enhanced model
362
- base_result = predict_hatespeech_from_file_mock()
363
- enhanced_result = predict_hatespeech_from_file_mock()
364
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  st.success("✅ File analysis complete for both models!")
366
  st.divider()
367
  st.header("📊 Analysis Results - Model Comparison")
 
66
  # Load both models with spinner
67
  with st.spinner('🔄 Loading models... This may take a moment on first run.'):
68
  try:
69
+ base_model, base_tokenizer_hatebert, base_tokenizer_rationale, base_config, base_device = load_cached_model("base")
70
+ enhanced_model, enhanced_tokenizer_hatebert, enhanced_tokenizer_rationale, enhanced_config, enhanced_device = load_cached_model("altered")
71
  st.success('✅ Base Shield and Enhanced Shield models loaded successfully!')
72
  except Exception as e:
73
  st.error(f"❌ Error loading models: {str(e)}")
 
155
  if classify_button:
156
  if user_input and user_input.strip():
157
  with st.spinner('🔄 Analyzing text...'):
 
 
 
 
 
 
 
 
 
 
158
  # Run both models
 
 
 
159
  enhanced_start = time.time()
160
+ enhanced_model_result = predict_hatespeech(
161
+ text=user_input,
162
+ rationale=optional_rationale if optional_rationale else None,
163
+ model=enhanced_model,
164
+ tokenizer_hatebert=enhanced_tokenizer_hatebert,
165
+ tokenizer_rationale=enhanced_tokenizer_rationale,
166
+ config=enhanced_config,
167
+ device=enhanced_device,
168
+ model_type="altered"
169
+ )
170
  enhanced_end = time.time()
171
+
172
+ base_start = time.time()
173
+ base_model_result = predict_hatespeech(
174
+ text=user_input,
175
+ rationale=optional_rationale if optional_rationale else None,
176
+ model=base_model,
177
+ tokenizer_hatebert=base_tokenizer_hatebert,
178
+ tokenizer_rationale=base_tokenizer_rationale,
179
+ config=base_config,
180
+ device=base_device,
181
+ model_type="base"
182
+ )
183
+ base_end = time.time()
184
 
185
  # Extract results for both models
186
  base_prediction = base_model_result['prediction']
 
368
  # Run both models on the file
369
  # base_result = predict_hatespeech_from_file(...) # Base model
370
  # enhanced_result = predict_hatespeech_from_file(...) # Enhanced model
371
+ enhanced_result = predict_hatespeech_from_file(
372
+ text_list=file_content['text'].tolist(),
373
+ rationale_list=file_content['CF_Rationales'].tolist(),
374
+ true_label=file_content['label'].tolist(),
375
+ model=enhanced_model,
376
+ tokenizer_hatebert=enhanced_tokenizer_hatebert,
377
+ tokenizer_rationale=enhanced_tokenizer_rationale,
378
+ config=enhanced_config,
379
+ device=enhanced_device,
380
+ model_type="altered"
381
+ )
382
+ base_result = predict_hatespeech_from_file(
383
+ text_list=file_content['text'].tolist(),
384
+ rationale_list=file_content['CF_Rationales'].tolist(),
385
+ true_label=file_content['label'].tolist(),
386
+ model=base_model,
387
+ tokenizer_hatebert=base_tokenizer_hatebert,
388
+ tokenizer_rationale=base_tokenizer_rationale,
389
+ config=base_config,
390
+ device=base_device,
391
+ model_type="base"
392
+ )
393
  st.success("✅ File analysis complete for both models!")
394
  st.divider()
395
  st.header("📊 Analysis Results - Model Comparison")
src/hatespeech_model.py CHANGED
@@ -1,5 +1,7 @@
1
  from huggingface_hub import hf_hub_download
2
  import torch
 
 
3
  import torch.nn as nn
4
  import json
5
  from transformers import AutoModel, AutoTokenizer
@@ -10,61 +12,62 @@ import os
10
 
11
  # Model Architecture Classes
12
  class TemporalCNN(nn.Module):
13
- def __init__(self, input_dim=768, num_filters=128, kernel_sizes=(2,3,4,5,6,7), dropout=0.3):
14
  super().__init__()
 
 
15
  self.convs = nn.ModuleList([
16
- nn.Conv1d(input_dim, num_filters, k) for k in kernel_sizes
 
17
  ])
18
  self.dropout = nn.Dropout(dropout)
19
- # Output size is num_filters * num_kernels * 2 (max + mean pooling)
20
- self.output_size = num_filters * len(kernel_sizes) * 2
21
-
22
- def forward(self, x, mask=None):
23
- x = x.transpose(1, 2) # (B, H, L)
24
- conv_outs = []
25
- for conv in self.convs:
26
- c = torch.relu(conv(x)) # (B, num_filters, L')
27
- # Both max and mean pooling
28
- max_pool = torch.max(c, dim=2)[0] # (B, num_filters)
29
- mean_pool = torch.mean(c, dim=2) # (B, num_filters)
30
- conv_outs.append(max_pool)
31
- conv_outs.append(mean_pool)
32
- out = torch.cat(conv_outs, dim=1) # (B, num_filters * len(kernel_sizes) * 2)
33
- out = self.dropout(out)
34
- return out
 
 
 
 
35
 
36
  class MultiScaleAttentionCNN(nn.Module):
37
- def __init__(self, hidden_size=768, num_filters=128, kernel_sizes=(2,3,4,5,6,7), dropout=0.3):
38
- super().__init__()
39
- # Convolution layers
40
- self.convs = nn.ModuleList([
41
- nn.Conv1d(hidden_size, num_filters, k) for k in kernel_sizes
42
- ])
43
- # Attention layers - output 1 value per filter for attention weighting
44
- self.attn = nn.ModuleList([
45
- nn.Linear(num_filters, 1) for _ in kernel_sizes
46
- ])
47
- self.dropout = nn.Dropout(dropout)
48
- self.output_size = num_filters * len(kernel_sizes)
49
-
50
- def forward(self, x, mask=None):
51
- x = x.transpose(1, 2) # (B, H, L)
52
- conv_outs = []
53
- for conv, attn in zip(self.convs, self.attn):
54
- c = torch.relu(conv(x)) # (B, num_filters, L')
55
- c_t = c.transpose(1, 2) # (B, L', num_filters)
56
- # Apply attention to get weights
57
- w = attn(c_t) # (B, L', 1)
58
- w = torch.softmax(w, dim=1) # attention weights
59
- # Weighted sum pooling
60
- pooled = (c_t * w).sum(dim=1) # (B, num_filters)
61
- conv_outs.append(pooled)
62
- out = torch.cat(conv_outs, dim=1) # (B, num_filters * len(kernel_sizes))
63
- out = self.dropout(out)
64
- return out
65
 
66
  class ProjectionMLP(nn.Module):
67
- def __init__(self, input_size, hidden_size, num_labels, dropout=0.3):
68
  super().__init__()
69
  self.layers = nn.Sequential(
70
  nn.Linear(input_size, hidden_size),
@@ -75,109 +78,95 @@ class ProjectionMLP(nn.Module):
75
  def forward(self, x):
76
  return self.layers(x)
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  class BaseShield(nn.Module):
79
  """
80
  Simple base model that concatenates HateBERT and rationale BERT CLS embeddings
81
  """
82
- def __init__(self, hatebert_model, additional_model, projection_mlp, hidden_size=768,
83
  freeze_additional_model=True):
84
  super().__init__()
85
  self.hatebert_model = hatebert_model
86
  self.additional_model = additional_model
87
  self.projection_mlp = projection_mlp
88
- self.hidden_size = hidden_size
89
 
90
  if freeze_additional_model:
91
  for param in self.additional_model.parameters():
92
  param.requires_grad = False
93
 
94
- def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask,
95
- return_attentions=False):
96
- # Main text through HateBERT - get CLS token only
97
- hatebert_out = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask,
98
- output_attentions=return_attentions, return_dict=True)
99
- hatebert_cls = hatebert_out.last_hidden_state[:, 0, :] # (B, 768)
100
-
101
- # Rationale text through frozen BERT - get CLS token only
102
- with torch.no_grad():
103
- add_out = self.additional_model(input_ids=additional_input_ids,
104
- attention_mask=additional_attention_mask,
105
- return_dict=True)
106
- rationale_cls = add_out.last_hidden_state[:, 0, :] # (B, 768)
107
-
108
- # Concatenate CLS embeddings: (B, 1536)
109
- concat_emb = torch.cat((hatebert_cls, rationale_cls), dim=1)
110
-
111
- # Classification
112
- logits = self.projection_mlp(concat_emb)
113
-
114
- # Return dummy rationale_probs and selector_logits for compatibility with app
115
- batch_size = input_ids.size(0)
116
- seq_len = input_ids.size(1)
117
- dummy_rationale_probs = torch.zeros(batch_size, seq_len, device=input_ids.device)
118
- dummy_selector_logits = torch.zeros(batch_size, seq_len, device=input_ids.device)
119
-
120
- attns = hatebert_out.attentions if (return_attentions and hasattr(hatebert_out, "attentions")) else None
121
- return logits, dummy_rationale_probs, dummy_selector_logits, attns
122
 
 
 
123
 
124
- class ConcatModelWithRationale(nn.Module):
125
- def __init__(self, hatebert_model, additional_model, projection_mlp, hidden_size=768,
126
- gumbel_temp=0.5, freeze_additional_model=True, cnn_num_filters=128,
127
- cnn_kernel_sizes=(2,3,4), cnn_dropout=0.3):
128
  super().__init__()
129
  self.hatebert_model = hatebert_model
130
  self.additional_model = additional_model
 
 
 
131
  self.projection_mlp = projection_mlp
132
- self.gumbel_temp = gumbel_temp
133
- self.hidden_size = hidden_size
134
-
135
  if freeze_additional_model:
136
- for param in self.additional_model.parameters():
137
- param.requires_grad = False
138
-
139
- self.selector = nn.Linear(hidden_size, 1)
140
- self.temporal_cnn = TemporalCNN(input_dim=hidden_size, num_filters=cnn_num_filters,
141
- kernel_sizes=cnn_kernel_sizes, dropout=cnn_dropout)
142
- self.temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2
143
- self.msa_cnn = MultiScaleAttentionCNN(hidden_size=hidden_size, num_filters=cnn_num_filters,
144
- kernel_sizes=cnn_kernel_sizes, dropout=cnn_dropout)
145
- self.msa_out_dim = self.msa_cnn.output_size
146
-
147
- def gumbel_sigmoid_sample(self, logits):
148
- noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-9) + 1e-9)
149
- y = logits + noise
150
- return torch.sigmoid(y / self.gumbel_temp)
151
-
152
- def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask,
153
- return_attentions=False):
154
- hatebert_out = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask,
155
- output_attentions=return_attentions, return_dict=True)
156
- hatebert_emb = hatebert_out.last_hidden_state
157
- cls_emb = hatebert_emb[:, 0, :]
158
-
159
- with torch.no_grad():
160
- add_out = self.additional_model(input_ids=additional_input_ids,
161
- attention_mask=additional_attention_mask,
162
- return_dict=True)
163
- rationale_emb = add_out.last_hidden_state
164
-
165
- selector_logits = self.selector(hatebert_emb).squeeze(-1)
166
- rationale_probs = self.gumbel_sigmoid_sample(selector_logits)
167
- rationale_probs = rationale_probs * attention_mask.float().to(rationale_probs.device)
168
 
169
- masked_hidden = hatebert_emb * rationale_probs.unsqueeze(-1)
170
- denom = rationale_probs.sum(1).unsqueeze(-1).clamp_min(1e-6)
171
- pooled_rationale = masked_hidden.sum(1) / denom
172
 
173
- temporal_features = self.temporal_cnn(hatebert_emb, attention_mask)
174
- rationale_features = self.msa_cnn(rationale_emb, additional_attention_mask)
175
 
176
- concat_emb = torch.cat((cls_emb, temporal_features, rationale_features, pooled_rationale), dim=1)
177
- logits = self.projection_mlp(concat_emb)
 
178
 
179
- attns = hatebert_out.attentions if (return_attentions and hasattr(hatebert_out, "attentions")) else None
180
- return logits, rationale_probs, selector_logits, attns
 
 
181
 
182
  def load_model_from_hf(model_type="altered"):
183
  """
@@ -187,8 +176,9 @@ def load_model_from_hf(model_type="altered"):
187
  model_type: Either "altered" or "base" to choose which model to load
188
  """
189
 
 
190
  repo_id = "seffyehl/BetterShield"
191
- repo_type = "e5912f6e8c34a10629cfd5a7971ac71ac76d0e9d"
192
 
193
  # Choose model and config files based on model_type
194
  if model_type.lower() == "altered":
@@ -203,14 +193,14 @@ def load_model_from_hf(model_type="altered"):
203
  # Download files
204
  model_path = hf_hub_download(
205
  repo_id=repo_id,
206
- revision=repo_type,
207
  filename=model_filename
208
  )
209
 
210
  config_path = hf_hub_download(
211
  repo_id=repo_id,
212
  filename=config_filename,
213
- revision=repo_type
214
  )
215
 
216
  # Load config
@@ -246,48 +236,37 @@ def load_model_from_hf(model_type="altered"):
246
  # The saved model uses 512, not what's in projection_config
247
  adapter_dim = 512 # hardcoded to match saved weights
248
  projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim,
249
- num_labels=2, dropout=0.0)
250
 
251
  model = BaseShield(
252
  hatebert_model=hatebert_model,
253
  additional_model=rationale_model,
254
  projection_mlp=projection_mlp,
255
- hidden_size=H,
256
- freeze_additional_model=True
257
- )
258
  else:
259
- # Altered Shield: Complex model with CNN and attention
260
- cnn_num_filters = model_config.get('cnn_num_filters', 128)
261
- # Use extended kernel sizes to match saved model
262
- cnn_kernel_sizes = (2, 3, 4, 5, 6, 7)
263
- adapter_dim = model_config.get('adapter_dim', 128)
264
- cnn_dropout = model_config.get('cnn_dropout', 0.3)
265
-
266
- # Calculate dimensions
267
- # TemporalCNN: num_filters * len(kernel_sizes) * 2 (max + mean pooling)
268
- temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2
269
- # MultiScaleAttentionCNN: num_filters * len(kernel_sizes)
270
- msa_out_dim = cnn_num_filters * len(cnn_kernel_sizes)
271
- # Total: CLS (768) + TemporalCNN + MSA + pooled_rationale (768)
272
- proj_input_dim = H + temporal_out_dim + msa_out_dim + H
273
- projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim,
274
- num_labels=2, dropout=0.0)
275
-
276
- model = ConcatModelWithRationale(
277
- hatebert_model=hatebert_model,
278
- additional_model=rationale_model,
279
- projection_mlp=projection_mlp,
280
- hidden_size=H,
281
- freeze_additional_model=True,
282
- cnn_num_filters=cnn_num_filters,
283
- cnn_kernel_sizes=cnn_kernel_sizes,
284
- cnn_dropout=cnn_dropout
285
- )
286
-
287
- model.load_state_dict(checkpoint['model_state_dict'])
288
  model.eval()
289
-
290
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
291
  model = model.to(device)
292
 
293
  # Create a unified config dict with max_length at top level for compatibility
@@ -298,7 +277,7 @@ def load_model_from_hf(model_type="altered"):
298
  return model, tokenizer_hatebert, tokenizer_rationale, unified_config, device
299
 
300
  def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale,
301
- device='cpu', max_length=128):
302
  """
303
  Predict hate speech for a given text and rationale
304
 
@@ -310,6 +289,7 @@ def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale
310
  tokenizer_rationale: Rationale model tokenizer
311
  device: 'cpu' or 'cuda'
312
  max_length: Maximum sequence length
 
313
 
314
  Returns:
315
  prediction: 0 or 1
@@ -342,6 +322,28 @@ def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale
342
  add_attention_mask = inputs_rationale['attention_mask'].to(device)
343
 
344
  # Inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  with torch.no_grad():
346
  logits, rationale_probs, selector_logits, _ = model(
347
  input_ids,
@@ -363,7 +365,7 @@ def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale
363
  'tokens': tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])
364
  }
365
 
366
- def predict_hatespeech_from_file(text_list, rationale_list, true_label, model, tokenizer_hatebert, tokenizer_rationale, config, device):
367
  """
368
  Predict hate speech for text read from a file
369
 
@@ -400,7 +402,8 @@ def predict_hatespeech_from_file(text_list, rationale_list, true_label, model, t
400
  tokenizer_hatebert=tokenizer_hatebert,
401
  tokenizer_rationale=tokenizer_rationale,
402
  device=device,
403
- max_length=config.get('max_length', 128)
 
404
  )
405
  predictions.append(result['prediction'])
406
  # Log resource usage every 10th sample and at end to reduce overhead
@@ -436,7 +439,7 @@ def predict_hatespeech_from_file(text_list, rationale_list, true_label, model, t
436
  }
437
 
438
 
439
- def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device):
440
  """
441
  Predict hate speech for given text
442
 
@@ -460,7 +463,8 @@ def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rat
460
  tokenizer_hatebert=tokenizer_hatebert,
461
  tokenizer_rationale=tokenizer_rationale,
462
  device=device,
463
- max_length=config.get('max_length', 128)
 
464
  )
465
 
466
  return result
 
1
  from huggingface_hub import hf_hub_download
2
  import torch
3
+ from torch.cuda import device
4
+ from torch.nn import functional as F
5
  import torch.nn as nn
6
  import json
7
  from transformers import AutoModel, AutoTokenizer
 
12
 
13
  # Model Architecture Classes
14
  class TemporalCNN(nn.Module):
15
+ def __init__(self, hidden_size=768, num_filters=128, kernel_sizes=(2, 3, 4), dropout=0.1, dilation_base=2):
16
  super().__init__()
17
+ self.kernel_sizes = kernel_sizes
18
+ self.dilation_base = dilation_base
19
  self.convs = nn.ModuleList([
20
+ nn.Conv1d(hidden_size, num_filters, k, dilation=dilation_base ** i, padding=0)
21
+ for i, k in enumerate(kernel_sizes)
22
  ])
23
  self.dropout = nn.Dropout(dropout)
24
+ self.out_dim = num_filters * len(kernel_sizes)
25
+
26
+ def _causal_padding(self, x, kernel_size, dilation):
27
+ padding = (kernel_size - 1) * dilation
28
+ return F.pad(x, (padding, 0))
29
+
30
+ def forward(self, x, attention_mask):
31
+ mask = attention_mask.unsqueeze(-1)
32
+ x = x * mask
33
+ x = x.transpose(1, 2)
34
+ feats = []
35
+ for i, conv in enumerate(self.convs):
36
+ kernel_size = self.kernel_sizes[i]
37
+ dilation = self.dilation_base ** i
38
+ x_padded = self._causal_padding(x, kernel_size, dilation)
39
+ c = F.relu(conv(x_padded))
40
+ p = F.max_pool1d(c, kernel_size=c.size(2)).squeeze(2)
41
+ feats.append(p)
42
+ out = torch.cat(feats, dim=1)
43
+ return self.dropout(out)
44
 
45
  class MultiScaleAttentionCNN(nn.Module):
46
+ def __init__(self, hidden_size=768, num_filters=128, kernel_sizes=(2, 3, 4), dropout=0.3):
47
+ super().__init__()
48
+ self.convs = nn.ModuleList([
49
+ nn.Conv1d(hidden_size, num_filters, k) for k in kernel_sizes
50
+ ])
51
+ self.attention_fc = nn.Linear(num_filters, 1)
52
+ self.dropout = nn.Dropout(dropout)
53
+ self.out_dim = num_filters * len(kernel_sizes)
54
+
55
+ def forward(self, x, mask):
56
+ x = x.transpose(1, 2)
57
+ feats = []
58
+ for conv in self.convs:
59
+ h = F.relu(conv(x))
60
+ h = h.transpose(1, 2)
61
+ attn = self.attention_fc(h).squeeze(-1)
62
+ attn = attn.masked_fill(mask[:, :attn.size(1)] == 0, -1e9)
63
+ alpha = F.softmax(attn, dim=1)
64
+ pooled = torch.sum(h * alpha.unsqueeze(-1), dim=1)
65
+ feats.append(pooled)
66
+ out = torch.cat(feats, dim=1)
67
+ return self.dropout(out)
 
 
 
 
 
 
68
 
69
  class ProjectionMLP(nn.Module):
70
+ def __init__(self, input_size, hidden_size, num_labels):
71
  super().__init__()
72
  self.layers = nn.Sequential(
73
  nn.Linear(input_size, hidden_size),
 
78
  def forward(self, x):
79
  return self.layers(x)
80
 
81
+ class GumbelTokenSelector(nn.Module):
82
+ def __init__(self, hidden_size, tau=1.0):
83
+ super().__init__()
84
+ self.tau = tau
85
+ self.proj = nn.Linear(hidden_size * 2, 1)
86
+
87
+ def forward(self, token_embeddings, cls_embedding, training=True):
88
+ B, L, H = token_embeddings.size()
89
+ cls_exp = cls_embedding.unsqueeze(1).expand(-1, L, -1)
90
+ x = torch.cat([token_embeddings, cls_exp], dim=-1)
91
+ logits = self.proj(x).squeeze(-1)
92
+
93
+ if training:
94
+ probs = F.gumbel_softmax(
95
+ torch.stack([logits, torch.zeros_like(logits)], dim=-1),
96
+ tau=self.tau,
97
+ hard=False
98
+ )[..., 0]
99
+ else:
100
+ probs = torch.sigmoid(logits)
101
+ return probs, logits
102
+
103
  class BaseShield(nn.Module):
104
  """
105
  Simple base model that concatenates HateBERT and rationale BERT CLS embeddings
106
  """
107
+ def __init__(self, hatebert_model, additional_model, projection_mlp, device='cpu',
108
  freeze_additional_model=True):
109
  super().__init__()
110
  self.hatebert_model = hatebert_model
111
  self.additional_model = additional_model
112
  self.projection_mlp = projection_mlp
113
+ self.device = device
114
 
115
  if freeze_additional_model:
116
  for param in self.additional_model.parameters():
117
  param.requires_grad = False
118
 
119
+ def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
120
+ hatebert_outputs = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask)
121
+ hatebert_embeddings = hatebert_outputs.last_hidden_state[:, 0, :]
122
+ hatebert_embeddings = torch.nn.LayerNorm(hatebert_embeddings.size()[1:]).to(self.device)(hatebert_embeddings.to(self.device)).to(self.device)
123
+
124
+ additional_outputs = self.additional_model(input_ids=additional_input_ids, attention_mask=additional_attention_mask)
125
+ additional_embeddings = additional_outputs.last_hidden_state[:, 0, :]
126
+ additional_embeddings = torch.nn.LayerNorm(additional_embeddings.size()[1:]).to(self.device)(additional_embeddings.to(self.device)).to(self.device)
127
+
128
+ concatenated_embeddings = torch.cat((hatebert_embeddings, additional_embeddings), dim=1).to(self.device)
129
+ projected_embeddings = self.projection_mlp(concatenated_embeddings).to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ # Return 4 values to match ConcatModel interface (rationale_probs, selector_logits, attentions are None)
132
+ return projected_embeddings
133
 
134
+ class ConcatModel(nn.Module):
135
+ def __init__(self, hatebert_model, additional_model, temporal_cnn, msa_cnn, selector, projection_mlp, freeze_additional_model=True, freeze_hatebert=True):
 
 
136
  super().__init__()
137
  self.hatebert_model = hatebert_model
138
  self.additional_model = additional_model
139
+ self.temporal_cnn = temporal_cnn
140
+ self.msa_cnn = msa_cnn
141
+ self.selector = selector
142
  self.projection_mlp = projection_mlp
143
+
 
 
144
  if freeze_additional_model:
145
+ for p in self.additional_model.parameters():
146
+ p.requires_grad = False
147
+ if freeze_hatebert:
148
+ for p in self.hatebert_model.parameters():
149
+ p.requires_grad = False
150
+
151
+ def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
152
+ hate_outputs = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask)
153
+ seq_emb = hate_outputs.last_hidden_state
154
+ cls_emb = seq_emb[:, 0, :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
+ token_probs, token_logits = self.selector(seq_emb, cls_emb, self.training)
157
+ temporal_feat = self.temporal_cnn(seq_emb, attention_mask)
 
158
 
159
+ weights = token_probs.unsqueeze(-1)
160
+ H_r = (seq_emb * weights).sum(dim=1) / (weights.sum(dim=1) + 1e-6)
161
 
162
+ with torch.no_grad():
163
+ add_outputs = self.additional_model(input_ids=additional_input_ids, attention_mask=additional_attention_mask)
164
+ add_seq = add_outputs.last_hidden_state
165
 
166
+ msa_feat = self.msa_cnn(add_seq, additional_attention_mask)
167
+ concat = torch.cat([cls_emb, temporal_feat, msa_feat, H_r], dim=1)
168
+ logits = self.projection_mlp(concat)
169
+ return logits, token_probs, token_logits, hate_outputs.attentions if hasattr(hate_outputs, "attentions") else None
170
 
171
  def load_model_from_hf(model_type="altered"):
172
  """
 
176
  model_type: Either "altered" or "base" to choose which model to load
177
  """
178
 
179
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
180
  repo_id = "seffyehl/BetterShield"
181
+ # repo_type = "e5912f6e8c34a10629cfd5a7971ac71ac76d0e9d"
182
 
183
  # Choose model and config files based on model_type
184
  if model_type.lower() == "altered":
 
193
  # Download files
194
  model_path = hf_hub_download(
195
  repo_id=repo_id,
196
+ # revision=repo_type,
197
  filename=model_filename
198
  )
199
 
200
  config_path = hf_hub_download(
201
  repo_id=repo_id,
202
  filename=config_filename,
203
+ # revision=repo_type
204
  )
205
 
206
  # Load config
 
236
  # The saved model uses 512, not what's in projection_config
237
  adapter_dim = 512 # hardcoded to match saved weights
238
  projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim,
239
+ num_labels=2)
240
 
241
  model = BaseShield(
242
  hatebert_model=hatebert_model,
243
  additional_model=rationale_model,
244
  projection_mlp=projection_mlp,
245
+ freeze_additional_model=True,
246
+ device=device
247
+ ).to(device)
248
  else:
249
+ temporal_cnn = TemporalCNN(hidden_size=768, num_filters=128, kernel_sizes=(2, 3, 4)).to(device)
250
+ msa_cnn = MultiScaleAttentionCNN(hidden_size=768, num_filters=128, kernel_sizes=(2, 3, 4)).to(device)
251
+ selector = GumbelTokenSelector(hidden_size=768, tau=1.0).to(device)
252
+ projection_mlp = ProjectionMLP(input_size=temporal_cnn.out_dim + msa_cnn.out_dim + 768 * 2, hidden_size=512, num_labels=2).to(device)
253
+ model = ConcatModel(
254
+ hatebert_model=hatebert_model,
255
+ additional_model=rationale_model,
256
+ temporal_cnn=temporal_cnn,
257
+ msa_cnn=msa_cnn,
258
+ selector=selector,
259
+ projection_mlp=projection_mlp,
260
+ freeze_additional_model=True,
261
+ freeze_hatebert=True).to(device)
262
+
263
+ if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
264
+ model.load_state_dict(checkpoint['model_state_dict'])
265
+ print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
266
+ print(f"Dataset: {checkpoint.get('dataset', 'unknown')}, Seed: {checkpoint.get('seed', 'unknown')}")
267
+ else:
268
+ model.load_state_dict(checkpoint)
 
 
 
 
 
 
 
 
 
269
  model.eval()
 
 
270
  model = model.to(device)
271
 
272
  # Create a unified config dict with max_length at top level for compatibility
 
277
  return model, tokenizer_hatebert, tokenizer_rationale, unified_config, device
278
 
279
  def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale,
280
+ device='cpu', max_length=128, model_type="altered"):
281
  """
282
  Predict hate speech for a given text and rationale
283
 
 
289
  tokenizer_rationale: Rationale model tokenizer
290
  device: 'cpu' or 'cuda'
291
  max_length: Maximum sequence length
292
+ model_type: Either "altered" or "base" to determine how to process inputs
293
 
294
  Returns:
295
  prediction: 0 or 1
 
322
  add_attention_mask = inputs_rationale['attention_mask'].to(device)
323
 
324
  # Inference
325
+ if model_type.lower() == "base":
326
+ with torch.no_grad():
327
+ logits = model(
328
+ input_ids,
329
+ attention_mask,
330
+ add_input_ids,
331
+ add_attention_mask
332
+ )
333
+
334
+ # Get probabilities
335
+ probs = torch.softmax(logits, dim=1)
336
+ prediction = logits.argmax(dim=1).item()
337
+ confidence = probs[0, prediction].item()
338
+
339
+ return {
340
+ 'prediction': prediction,
341
+ 'confidence': confidence,
342
+ 'probabilities': probs[0].cpu().numpy(),
343
+ 'rationale_scores': None, # Base model does not produce token-level rationale scores
344
+ 'tokens': tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])
345
+ }
346
+
347
  with torch.no_grad():
348
  logits, rationale_probs, selector_logits, _ = model(
349
  input_ids,
 
365
  'tokens': tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])
366
  }
367
 
368
+ def predict_hatespeech_from_file(text_list, rationale_list, true_label, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
369
  """
370
  Predict hate speech for text read from a file
371
 
 
402
  tokenizer_hatebert=tokenizer_hatebert,
403
  tokenizer_rationale=tokenizer_rationale,
404
  device=device,
405
+ max_length=config.get('max_length', 128),
406
+ model_type=model_type
407
  )
408
  predictions.append(result['prediction'])
409
  # Log resource usage every 10th sample and at end to reduce overhead
 
439
  }
440
 
441
 
442
+ def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
443
  """
444
  Predict hate speech for given text
445
 
 
463
  tokenizer_hatebert=tokenizer_hatebert,
464
  tokenizer_rationale=tokenizer_rationale,
465
  device=device,
466
+ max_length=config.get('max_length', 128),
467
+ model_type=model_type
468
  )
469
 
470
  return result