Spaces:
Sleeping
Sleeping
fix - model v4 models used
Browse files- src/app.py +35 -17
- src/hatespeech_model.py +357 -316
src/app.py
CHANGED
|
@@ -21,8 +21,14 @@ st.set_page_config(
|
|
| 21 |
# Cached model loading function
|
| 22 |
@st.cache_resource
|
| 23 |
def load_cached_model(model_type="altered"):
|
| 24 |
-
|
| 25 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
# Custom CSS
|
| 28 |
st.markdown("""
|
|
@@ -68,8 +74,21 @@ st.markdown('<div class="sub-header">Comparing Base vs Enhanced models with expl
|
|
| 68 |
# Load both models with spinner
|
| 69 |
with st.spinner('🔄 Loading models... This may take a moment on first run.'):
|
| 70 |
try:
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
st.success('✅ Base Shield and Enhanced Shield models loaded successfully!')
|
| 74 |
except Exception as e:
|
| 75 |
st.error(f"❌ Error loading models: {str(e)}")
|
|
@@ -320,30 +339,31 @@ if classify_button:
|
|
| 320 |
|
| 321 |
# Filter out special tokens and create visualization
|
| 322 |
token_importance = []
|
| 323 |
-
html_output = "<div style='font-size: 16px; line-height: 2.2; padding: 15px; background-color: #
|
| 324 |
|
| 325 |
for token, score in zip(enhanced_tokens, enhanced_rationale_scores):
|
| 326 |
if token not in ['[CLS]', '[SEP]', '[PAD]']:
|
| 327 |
# Clean token
|
| 328 |
display_token = token.replace('##', '')
|
| 329 |
token_importance.append({'Token': display_token, 'Importance': score})
|
| 330 |
-
|
| 331 |
# Color intensity based on score
|
| 332 |
alpha = min(score * 1.5, 1.0) # Scale up visibility
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
|
| 338 |
html_output += f"<span style='background-color: {color}; padding: 3px 6px; margin: 1px; border-radius: 4px; display: inline-block;'>{display_token}</span> "
|
| 339 |
|
| 340 |
html_output += "</div>"
|
| 341 |
st.markdown(html_output, unsafe_allow_html=True)
|
| 342 |
-
|
| 343 |
-
if enhanced_prediction == 1:
|
| 344 |
-
st.caption("🔴 Darker red = Higher importance for hate speech detection")
|
| 345 |
-
else:
|
| 346 |
-
st.caption("🟢 Darker green = Higher importance for non-hate speech classification")
|
| 347 |
|
| 348 |
# Top important tokens
|
| 349 |
st.markdown("**📋 Top Important Tokens**")
|
|
@@ -459,7 +479,6 @@ if classify_button:
|
|
| 459 |
with enhanced_file_col:
|
| 460 |
st.subheader("🟢 Enhanced Shield Results")
|
| 461 |
|
| 462 |
-
# Performance Metrics
|
| 463 |
st.markdown("**📈 Classification Metrics**")
|
| 464 |
enh_fm1, enh_fm2 = st.columns(2)
|
| 465 |
with enh_fm1:
|
|
@@ -485,7 +504,6 @@ if classify_button:
|
|
| 485 |
fig_enhanced_cm.update_layout(height=300)
|
| 486 |
st.plotly_chart(fig_enhanced_cm, use_container_width=True)
|
| 487 |
|
| 488 |
-
# Resource Usage
|
| 489 |
st.markdown("**⚙️ Resource Usage**")
|
| 490 |
enh_cpu_col, enh_mem_col = st.columns(2)
|
| 491 |
with enh_cpu_col:
|
|
|
|
| 21 |
# Cached model loading function
|
| 22 |
@st.cache_resource
|
| 23 |
def load_cached_model(model_type="altered"):
|
| 24 |
+
model, tokenizer_hatebert, tokenizer_rationale, config, device = load_model_from_hf(model_type=model_type)
|
| 25 |
+
return {
|
| 26 |
+
"model": model,
|
| 27 |
+
"tokenizer_hatebert": tokenizer_hatebert,
|
| 28 |
+
"tokenizer_rationale": tokenizer_rationale,
|
| 29 |
+
"config": config,
|
| 30 |
+
"device": device
|
| 31 |
+
}
|
| 32 |
|
| 33 |
# Custom CSS
|
| 34 |
st.markdown("""
|
|
|
|
| 74 |
# Load both models with spinner
|
| 75 |
with st.spinner('🔄 Loading models... This may take a moment on first run.'):
|
| 76 |
try:
|
| 77 |
+
base_data = load_cached_model("base")
|
| 78 |
+
enhanced_data = load_cached_model("altered")
|
| 79 |
+
|
| 80 |
+
base_model = base_data["model"]
|
| 81 |
+
base_tokenizer_hatebert = base_data["tokenizer_hatebert"]
|
| 82 |
+
base_tokenizer_rationale = base_data["tokenizer_rationale"]
|
| 83 |
+
base_config = base_data["config"]
|
| 84 |
+
base_device = base_data["device"]
|
| 85 |
+
|
| 86 |
+
enhanced_model = enhanced_data["model"]
|
| 87 |
+
enhanced_tokenizer_hatebert = enhanced_data["tokenizer_hatebert"]
|
| 88 |
+
enhanced_tokenizer_rationale = enhanced_data["tokenizer_rationale"]
|
| 89 |
+
enhanced_config = enhanced_data["config"]
|
| 90 |
+
enhanced_device = enhanced_data["device"]
|
| 91 |
+
|
| 92 |
st.success('✅ Base Shield and Enhanced Shield models loaded successfully!')
|
| 93 |
except Exception as e:
|
| 94 |
st.error(f"❌ Error loading models: {str(e)}")
|
|
|
|
| 339 |
|
| 340 |
# Filter out special tokens and create visualization
|
| 341 |
token_importance = []
|
| 342 |
+
html_output = "<div style='font-size: 16px; line-height: 2.2; padding: 15px; background-color: #0E1117; border-radius: 10px;'>"
|
| 343 |
|
| 344 |
for token, score in zip(enhanced_tokens, enhanced_rationale_scores):
|
| 345 |
if token not in ['[CLS]', '[SEP]', '[PAD]']:
|
| 346 |
# Clean token
|
| 347 |
display_token = token.replace('##', '')
|
| 348 |
token_importance.append({'Token': display_token, 'Importance': score})
|
| 349 |
+
thresholds = [0.25, 0.5, 0.75]
|
| 350 |
# Color intensity based on score
|
| 351 |
alpha = min(score * 1.5, 1.0) # Scale up visibility
|
| 352 |
+
color = f"rgba(239, 83, 80, {alpha:.2f})"
|
| 353 |
+
# if score > 0.5:
|
| 354 |
+
# color = f"rgba(239, 83, 80, {alpha:.2f})"
|
| 355 |
+
# else:
|
| 356 |
+
# color = f"rgba(242, 155, 5, {alpha:.2f})"
|
| 357 |
+
# if enhanced_prediction == 1: # Hate speech
|
| 358 |
+
# color = f"rgba(239, 83, 80, {alpha:.2f})"
|
| 359 |
+
# else: # Not hate speech
|
| 360 |
+
# color = f"rgba(102, 187, 106, {alpha:.2f})"
|
| 361 |
|
| 362 |
html_output += f"<span style='background-color: {color}; padding: 3px 6px; margin: 1px; border-radius: 4px; display: inline-block;'>{display_token}</span> "
|
| 363 |
|
| 364 |
html_output += "</div>"
|
| 365 |
st.markdown(html_output, unsafe_allow_html=True)
|
| 366 |
+
st.caption("🔴 Darker red = More influence on hate speech detection.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
| 368 |
# Top important tokens
|
| 369 |
st.markdown("**📋 Top Important Tokens**")
|
|
|
|
| 479 |
with enhanced_file_col:
|
| 480 |
st.subheader("🟢 Enhanced Shield Results")
|
| 481 |
|
|
|
|
| 482 |
st.markdown("**📈 Classification Metrics**")
|
| 483 |
enh_fm1, enh_fm2 = st.columns(2)
|
| 484 |
with enh_fm1:
|
|
|
|
| 504 |
fig_enhanced_cm.update_layout(height=300)
|
| 505 |
st.plotly_chart(fig_enhanced_cm, use_container_width=True)
|
| 506 |
|
|
|
|
| 507 |
st.markdown("**⚙️ Resource Usage**")
|
| 508 |
enh_cpu_col, enh_mem_col = st.columns(2)
|
| 509 |
with enh_cpu_col:
|
src/hatespeech_model.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from huggingface_hub import hf_hub_download
|
| 2 |
import torch
|
| 3 |
-
from torch.cuda import device
|
| 4 |
from torch.nn import functional as F
|
| 5 |
import torch.nn as nn
|
| 6 |
import json
|
|
@@ -9,133 +8,183 @@ from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_sc
|
|
| 9 |
from time import time
|
| 10 |
import psutil
|
| 11 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
| 14 |
class TemporalCNN(nn.Module):
|
| 15 |
-
|
| 16 |
-
Temporal CNN applied across the sequence (time) dimension.
|
| 17 |
-
Input: sequence_embeddings (B, L, H), attention_mask (B, L)
|
| 18 |
-
Output: pooled vector (B, output_dim) where output_dim = num_filters * len(kernel_sizes) * 2
|
| 19 |
-
(we concatenate max-pooled and mean-pooled features for each kernel size)
|
| 20 |
-
"""
|
| 21 |
-
def __init__(self, input_dim=768, num_filters=256, kernel_sizes=(2, 3, 4), dropout=0.3):
|
| 22 |
super().__init__()
|
| 23 |
self.input_dim = input_dim
|
| 24 |
self.num_filters = num_filters
|
| 25 |
self.kernel_sizes = kernel_sizes
|
| 26 |
-
|
| 27 |
-
# Convs expect (B, C_in, L) where C_in = input_dim
|
| 28 |
self.convs = nn.ModuleList([
|
| 29 |
-
nn.Conv1d(in_channels=input_dim, out_channels=num_filters, kernel_size=k, padding=k
|
| 30 |
for k in kernel_sizes
|
| 31 |
])
|
| 32 |
self.dropout = nn.Dropout(dropout)
|
| 33 |
|
| 34 |
def forward(self, sequence_embeddings, attention_mask=None):
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
attention_mask: (B, L) with 1 for valid tokens, 0 for padding
|
| 38 |
-
returns: (B, num_filters * len(kernel_sizes) * 2) # max + mean pooled per conv
|
| 39 |
-
"""
|
| 40 |
-
# transpose to (B, H, L)
|
| 41 |
-
x = sequence_embeddings.transpose(1, 2).contiguous() # (B, H, L)
|
| 42 |
-
|
| 43 |
pooled_outputs = []
|
| 44 |
for conv in self.convs:
|
| 45 |
-
conv_out = conv(x)
|
| 46 |
conv_out = F.relu(conv_out)
|
| 47 |
L_out = conv_out.size(2)
|
| 48 |
-
|
| 49 |
if attention_mask is not None:
|
| 50 |
-
# resize mask to match L_out
|
| 51 |
mask = attention_mask.float()
|
| 52 |
if mask.size(1) != L_out:
|
| 53 |
mask = F.interpolate(mask.unsqueeze(1), size=L_out, mode='nearest').squeeze(1)
|
| 54 |
mask = mask.unsqueeze(1).to(conv_out.device) # (B,1,L_out)
|
| 55 |
-
|
| 56 |
-
# max pool with masking
|
| 57 |
neg_inf = torch.finfo(conv_out.dtype).min / 2
|
| 58 |
-
max_masked = torch.where(mask.bool(), conv_out, neg_inf
|
| 59 |
max_pooled = torch.max(max_masked, dim=2)[0] # (B, num_filters)
|
| 60 |
-
|
| 61 |
-
# mean pool with masking
|
| 62 |
sum_masked = (conv_out * mask).sum(dim=2) # (B, num_filters)
|
| 63 |
denom = mask.sum(dim=2).clamp_min(1e-6) # (B,1)
|
| 64 |
mean_pooled = sum_masked / denom # (B, num_filters)
|
| 65 |
else:
|
| 66 |
max_pooled = torch.max(conv_out, dim=2)[0]
|
| 67 |
mean_pooled = conv_out.mean(dim=2)
|
| 68 |
-
|
| 69 |
pooled_outputs.append(max_pooled)
|
| 70 |
pooled_outputs.append(mean_pooled)
|
| 71 |
-
|
| 72 |
-
out = torch.cat(pooled_outputs, dim=1)
|
| 73 |
out = self.dropout(out)
|
| 74 |
return out
|
| 75 |
|
| 76 |
|
| 77 |
class MultiScaleAttentionCNN(nn.Module):
|
| 78 |
-
def __init__(self, hidden_size=768, num_filters=
|
| 79 |
super().__init__()
|
| 80 |
|
| 81 |
self.hidden_size = hidden_size
|
| 82 |
self.kernel_sizes = kernel_sizes
|
| 83 |
|
| 84 |
self.convs = nn.ModuleList()
|
| 85 |
-
self.pads
|
| 86 |
|
| 87 |
for k in self.kernel_sizes:
|
| 88 |
-
pad_left
|
| 89 |
pad_right = k - 1 - pad_left
|
|
|
|
| 90 |
self.pads.append(nn.ConstantPad1d((pad_left, pad_right), 0.0))
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
self.attn = nn.ModuleList([nn.Linear(num_filters, 1) for _ in self.kernel_sizes])
|
| 94 |
self.output_size = num_filters * len(self.kernel_sizes)
|
| 95 |
self.dropout = nn.Dropout(dropout)
|
| 96 |
|
| 97 |
def forward(self, hidden_states, mask):
|
| 98 |
-
|
| 99 |
-
hidden_states: (B, L, H)
|
| 100 |
-
mask: (B, L)
|
| 101 |
-
"""
|
| 102 |
-
x = hidden_states.transpose(1, 2) # (B, H, L)
|
| 103 |
attn_mask = mask.unsqueeze(1).float()
|
| 104 |
|
| 105 |
conv_outs = []
|
| 106 |
|
| 107 |
for pad, conv, att in zip(self.pads, self.convs, self.attn):
|
| 108 |
-
padded = pad(x)
|
| 109 |
-
c = conv(padded)
|
| 110 |
c = F.relu(c)
|
| 111 |
c = c * attn_mask
|
| 112 |
|
| 113 |
-
c_t = c.transpose(1, 2)
|
| 114 |
-
w = att(c_t)
|
| 115 |
w = w.masked_fill(mask.unsqueeze(-1) == 0, -1e9)
|
| 116 |
w = F.softmax(w, dim=1)
|
| 117 |
|
| 118 |
-
pooled = (c_t * w).sum(dim=1)
|
| 119 |
conv_outs.append(pooled)
|
| 120 |
|
| 121 |
-
out = torch.cat(conv_outs, dim=1)
|
| 122 |
return self.dropout(out)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
class ProjectionMLP(nn.Module):
|
| 126 |
-
def __init__(self, input_size, hidden_size=256, num_labels=2):
|
| 127 |
-
super().__init__()
|
| 128 |
-
self.layers = nn.Sequential(
|
| 129 |
-
nn.Linear(input_size, hidden_size),
|
| 130 |
-
nn.ReLU(),
|
| 131 |
-
nn.Linear(hidden_size, num_labels)
|
| 132 |
-
)
|
| 133 |
-
|
| 134 |
-
def forward(self, x):
|
| 135 |
-
return self.layers(x)
|
| 136 |
-
|
| 137 |
-
|
| 138 |
class ConcatModelWithRationale(nn.Module):
|
|
|
|
| 139 |
def __init__(self,
|
| 140 |
hatebert_model,
|
| 141 |
additional_model,
|
|
@@ -143,86 +192,153 @@ class ConcatModelWithRationale(nn.Module):
|
|
| 143 |
hidden_size=768,
|
| 144 |
gumbel_temp=0.5,
|
| 145 |
freeze_additional_model=True,
|
| 146 |
-
cnn_num_filters=
|
| 147 |
-
cnn_kernel_sizes=(
|
| 148 |
-
cnn_dropout=0.
|
|
|
|
| 149 |
super().__init__()
|
|
|
|
| 150 |
self.hatebert_model = hatebert_model
|
| 151 |
self.additional_model = additional_model
|
| 152 |
self.projection_mlp = projection_mlp
|
| 153 |
self.gumbel_temp = gumbel_temp
|
| 154 |
self.hidden_size = hidden_size
|
| 155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
if freeze_additional_model:
|
| 157 |
for param in self.additional_model.parameters():
|
| 158 |
param.requires_grad = False
|
| 159 |
|
| 160 |
-
# selector head (per-token logits)
|
| 161 |
self.selector = nn.Linear(hidden_size, 1)
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
|
|
|
|
|
|
| 168 |
self.temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2
|
| 169 |
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
| 175 |
self.msa_out_dim = self.msa_cnn.output_size
|
| 176 |
|
|
|
|
| 177 |
def gumbel_sigmoid_sample(self, logits):
|
| 178 |
noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-9) + 1e-9)
|
| 179 |
y = logits + noise
|
| 180 |
return torch.sigmoid(y / self.gumbel_temp)
|
| 181 |
|
| 182 |
-
def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask, return_attentions=False):
|
| 183 |
-
# Main text through HateBERT
|
| 184 |
-
hatebert_out = self.hatebert_model(input_ids=input_ids,
|
| 185 |
-
attention_mask=attention_mask,
|
| 186 |
-
output_attentions=return_attentions,
|
| 187 |
-
return_dict=True)
|
| 188 |
-
hatebert_emb = hatebert_out.last_hidden_state # (B, L, H)
|
| 189 |
-
cls_emb = hatebert_emb[:, 0, :] # (B, H)
|
| 190 |
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
with torch.no_grad():
|
| 193 |
-
add_out = self.additional_model(input_ids=additional_input_ids,
|
| 194 |
-
attention_mask=additional_attention_mask,
|
| 195 |
-
return_dict=True)
|
| 196 |
-
rationale_emb = add_out.last_hidden_state # (B, L, H)
|
| 197 |
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
pooled_rationale = masked_hidden.sum(1) / denom # (B, H)
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
logits = self.projection_mlp(concat_emb)
|
| 216 |
|
| 217 |
-
attns =
|
|
|
|
|
|
|
|
|
|
| 218 |
return logits, rationale_probs, selector_logits, attns
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
class BaseShield(nn.Module):
|
| 222 |
-
"""
|
| 223 |
-
Simple base model that concatenates HateBERT and rationale BERT CLS embeddings
|
| 224 |
-
and projects to label logits via a small MLP.
|
| 225 |
-
"""
|
| 226 |
def __init__(self, hatebert_model, additional_model, projection_mlp, device='cpu', freeze_additional_model=True):
|
| 227 |
super().__init__()
|
| 228 |
self.hatebert_model = hatebert_model
|
|
@@ -235,249 +351,189 @@ class BaseShield(nn.Module):
|
|
| 235 |
param.requires_grad = False
|
| 236 |
|
| 237 |
def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
|
| 238 |
-
hatebert_outputs = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask
|
| 239 |
hatebert_embeddings = hatebert_outputs.last_hidden_state[:, 0, :]
|
| 240 |
|
| 241 |
-
additional_outputs = self.additional_model(input_ids=additional_input_ids, attention_mask=additional_attention_mask
|
| 242 |
additional_embeddings = additional_outputs.last_hidden_state[:, 0, :]
|
| 243 |
|
| 244 |
concatenated_embeddings = torch.cat((hatebert_embeddings, additional_embeddings), dim=1)
|
| 245 |
logits = self.projection_mlp(concatenated_embeddings)
|
| 246 |
return logits
|
| 247 |
-
|
| 248 |
-
def load_model_from_hf(model_type="altered"):
|
| 249 |
-
"""
|
| 250 |
-
Load model from Hugging Face Hub
|
| 251 |
|
| 252 |
-
|
| 253 |
-
model_type: Either "altered" or "base" to choose which model to load
|
| 254 |
-
"""
|
| 255 |
|
| 256 |
-
|
|
|
|
|
|
|
| 257 |
repo_id = "seffyehl/BetterShield"
|
| 258 |
-
|
| 259 |
-
# Choose model and config files based on model_type
|
| 260 |
if model_type.lower() == "altered":
|
| 261 |
model_filename = "AlteredShield.pth"
|
| 262 |
-
config_filename = "alter_config.json"
|
| 263 |
elif model_type.lower() == "base":
|
| 264 |
-
model_filename = "
|
| 265 |
-
config_filename = "base_config.json"
|
| 266 |
else:
|
| 267 |
-
raise ValueError(
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
config_path = hf_hub_download(
|
| 276 |
-
repo_id=repo_id,
|
| 277 |
-
filename=config_filename,
|
| 278 |
-
)
|
| 279 |
-
|
| 280 |
-
# Load config
|
| 281 |
-
with open(config_path, 'r') as f:
|
| 282 |
-
config = json.load(f)
|
| 283 |
-
|
| 284 |
-
# Load checkpoint
|
| 285 |
-
checkpoint = torch.load(model_path, map_location='cpu')
|
| 286 |
-
|
| 287 |
-
# Handle nested config structure (base model uses model_config, altered uses flat structure)
|
| 288 |
-
if 'model_config' in config:
|
| 289 |
-
model_config = config['model_config']
|
| 290 |
-
training_config = config.get('training_config', {})
|
| 291 |
else:
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
H = hatebert_model.config.hidden_size
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
model = BaseShield(
|
| 310 |
hatebert_model=hatebert_model,
|
| 311 |
additional_model=rationale_model,
|
| 312 |
projection_mlp=projection_mlp,
|
| 313 |
-
freeze_additional_model=
|
| 314 |
device=device
|
| 315 |
-
)
|
|
|
|
| 316 |
else:
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
model = ConcatModelWithRationale(
|
| 332 |
hatebert_model=hatebert_model,
|
| 333 |
additional_model=rationale_model,
|
| 334 |
projection_mlp=projection_mlp,
|
| 335 |
hidden_size=H,
|
| 336 |
-
freeze_additional_model=
|
| 337 |
cnn_num_filters=cnn_num_filters,
|
| 338 |
cnn_kernel_sizes=cnn_kernel_sizes,
|
| 339 |
cnn_dropout=cnn_dropout
|
| 340 |
-
)
|
| 341 |
-
|
| 342 |
-
# Load state dict with strict checking and error reporting
|
| 343 |
-
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 344 |
-
state_dict_to_load = checkpoint['model_state_dict']
|
| 345 |
-
else:
|
| 346 |
-
state_dict_to_load = checkpoint
|
| 347 |
-
|
| 348 |
-
# Check for missing and unexpected keys
|
| 349 |
-
model_keys = set(model.state_dict().keys())
|
| 350 |
-
checkpoint_keys = set(state_dict_to_load.keys())
|
| 351 |
-
|
| 352 |
-
missing_keys = model_keys - checkpoint_keys
|
| 353 |
-
unexpected_keys = checkpoint_keys - model_keys
|
| 354 |
-
|
| 355 |
-
if missing_keys:
|
| 356 |
-
print(f"WARNING: Missing keys in checkpoint: {missing_keys}")
|
| 357 |
-
if unexpected_keys:
|
| 358 |
-
print(f"WARNING: Unexpected keys in checkpoint: {unexpected_keys}")
|
| 359 |
-
|
| 360 |
-
# Load with strict=False to handle any minor mismatches, but log warnings
|
| 361 |
-
incompatible_keys = model.load_state_dict(state_dict_to_load, strict=True)
|
| 362 |
-
|
| 363 |
-
if incompatible_keys.missing_keys:
|
| 364 |
-
print(f"Missing keys after load: {incompatible_keys.missing_keys}")
|
| 365 |
-
if incompatible_keys.unexpected_keys:
|
| 366 |
-
print(f"Unexpected keys after load: {incompatible_keys.unexpected_keys}")
|
| 367 |
-
|
| 368 |
-
if isinstance(checkpoint, dict) and 'epoch' in checkpoint:
|
| 369 |
-
print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
|
| 370 |
-
print(f"Dataset: {checkpoint.get('dataset', 'unknown')}, Seed: {checkpoint.get('seed', 'unknown')}")
|
| 371 |
-
|
| 372 |
-
# CRITICAL: Set to eval mode and ensure no gradient computation
|
| 373 |
-
model.eval()
|
| 374 |
-
|
| 375 |
-
# Disable dropout explicitly by setting training mode to False for all modules
|
| 376 |
-
for module in model.modules():
|
| 377 |
-
if isinstance(module, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)):
|
| 378 |
-
module.p = 0 # Set dropout probability to 0
|
| 379 |
-
|
| 380 |
-
model = model.to(device)
|
| 381 |
-
|
| 382 |
-
# Verify model is in eval mode
|
| 383 |
-
print(f"Model training mode: {model.training}")
|
| 384 |
-
print(f"Dropout layers found: {sum(1 for _ in model.modules() if isinstance(_, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)))}")
|
| 385 |
-
|
| 386 |
-
# Create a unified config dict with max_length at top level for compatibility
|
| 387 |
-
unified_config = config.copy()
|
| 388 |
-
if 'max_length' not in unified_config and 'training_config' in config:
|
| 389 |
-
unified_config['max_length'] = training_config.get('max_length', 128)
|
| 390 |
-
|
| 391 |
-
return model, tokenizer_hatebert, tokenizer_rationale, unified_config, device
|
| 392 |
-
|
| 393 |
|
| 394 |
-
|
| 395 |
-
cls_loss = F.cross_entropy(logits, labels)
|
| 396 |
|
| 397 |
-
|
| 398 |
-
if rationale_mask is not None:
|
| 399 |
-
selector_loss = F.binary_cross_entropy_with_logits(selector_logits, rationale_mask.to(selector_logits.device))
|
| 400 |
-
else:
|
| 401 |
-
selector_loss = torch.tensor(0.0, device=cls_loss.device)
|
| 402 |
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
try:
|
| 407 |
-
last_attn = attns[-1] # (B, H, L, L)
|
| 408 |
-
attn_mass = last_attn.mean(1).mean(1) # (B, L)
|
| 409 |
-
attn_loss = F.mse_loss(attn_mass, rationale_mask.to(attn_mass.device))
|
| 410 |
-
except Exception:
|
| 411 |
-
attn_loss = torch.tensor(0.0, device=cls_loss.device)
|
| 412 |
|
| 413 |
-
|
| 414 |
-
return total_loss, cls_loss.item(), selector_loss.item(), attn_loss.item()
|
| 415 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
-
def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale,
|
| 418 |
-
device='cpu', max_length=128, model_type="altered"):
|
| 419 |
-
# Ensure model is in eval mode (defensive programming)
|
| 420 |
model.eval()
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
inputs_main = tokenizer_hatebert(
|
| 424 |
text,
|
| 425 |
max_length=max_length,
|
| 426 |
-
padding=
|
| 427 |
truncation=True,
|
| 428 |
-
return_tensors=
|
| 429 |
)
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
rationale if rationale else text,
|
| 433 |
max_length=max_length,
|
| 434 |
-
padding=
|
| 435 |
truncation=True,
|
| 436 |
-
return_tensors=
|
| 437 |
)
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
add_input_ids =
|
| 443 |
-
add_attention_mask =
|
| 444 |
-
|
| 445 |
-
|
| 446 |
with torch.no_grad():
|
|
|
|
| 447 |
if model_type.lower() == "base":
|
| 448 |
logits = model(
|
| 449 |
-
input_ids,
|
| 450 |
-
attention_mask,
|
| 451 |
-
add_input_ids,
|
| 452 |
add_attention_mask
|
| 453 |
)
|
|
|
|
| 454 |
else:
|
| 455 |
-
|
| 456 |
-
input_ids,
|
| 457 |
-
attention_mask,
|
| 458 |
-
add_input_ids,
|
| 459 |
add_attention_mask
|
| 460 |
)
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 464 |
prediction = logits.argmax(dim=1).item()
|
| 465 |
confidence = probs[0, prediction].item()
|
| 466 |
-
|
| 467 |
return {
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
|
|
|
| 472 |
}
|
| 473 |
-
|
| 474 |
-
if model_type.lower() != "base":
|
| 475 |
-
result['rationale_scores'] = rationale_probs[0].cpu().numpy() if 'rationale_probs' in locals() else None
|
| 476 |
-
else:
|
| 477 |
-
result['rationale_scores'] = None
|
| 478 |
-
|
| 479 |
-
return result
|
| 480 |
-
|
| 481 |
|
| 482 |
def predict_hatespeech_from_file(
|
| 483 |
text_list,
|
|
@@ -500,11 +556,10 @@ def predict_hatespeech_from_file(
|
|
| 500 |
|
| 501 |
process = psutil.Process(os.getpid())
|
| 502 |
|
| 503 |
-
# 🔥 GPU synchronization BEFORE timing
|
| 504 |
if torch.cuda.is_available():
|
| 505 |
torch.cuda.synchronize()
|
| 506 |
|
| 507 |
-
#
|
| 508 |
with torch.no_grad():
|
| 509 |
_ = predict_text(
|
| 510 |
text=text_list[0],
|
|
@@ -520,10 +575,10 @@ def predict_hatespeech_from_file(
|
|
| 520 |
if torch.cuda.is_available():
|
| 521 |
torch.cuda.synchronize()
|
| 522 |
|
| 523 |
-
# ⏱ Start timer AFTER warmup
|
| 524 |
start_time = time()
|
| 525 |
|
| 526 |
for idx, (text, rationale) in enumerate(zip(text_list, rationale_list)):
|
|
|
|
| 527 |
result = predict_text(
|
| 528 |
text=text,
|
| 529 |
rationale=rationale,
|
|
@@ -538,28 +593,20 @@ def predict_hatespeech_from_file(
|
|
| 538 |
predictions.append(result['prediction'])
|
| 539 |
all_probs.append(result['probabilities'])
|
| 540 |
|
| 541 |
-
# Reduce monitoring overhead
|
| 542 |
if idx % 10 == 0 or idx == len(text_list) - 1:
|
| 543 |
cpu_percent_list.append(process.cpu_percent())
|
| 544 |
memory_percent_list.append(process.memory_info().rss / 1024 / 1024)
|
| 545 |
|
| 546 |
-
# 🔥 GPU synchronization BEFORE stopping timer
|
| 547 |
if torch.cuda.is_available():
|
| 548 |
torch.cuda.synchronize()
|
| 549 |
|
| 550 |
-
|
| 551 |
-
runtime = end_time - start_time
|
| 552 |
|
| 553 |
print(f"Inference completed for {type(model).__name__}")
|
| 554 |
print(f"Total runtime: {runtime:.4f} seconds")
|
| 555 |
|
| 556 |
-
# ---------------- Metrics ----------------
|
| 557 |
all_probs = np.array(all_probs)
|
| 558 |
|
| 559 |
-
print(f"Probability Mean: {all_probs.mean(axis=0)}")
|
| 560 |
-
print(f"Probability Std: {all_probs.std(axis=0)}")
|
| 561 |
-
print(f"Prediction distribution: {np.bincount(predictions, minlength=2)}")
|
| 562 |
-
|
| 563 |
f1 = f1_score(true_label, predictions, zero_division=0)
|
| 564 |
accuracy = accuracy_score(true_label, predictions)
|
| 565 |
precision = precision_score(true_label, predictions, zero_division=0)
|
|
@@ -572,7 +619,7 @@ def predict_hatespeech_from_file(
|
|
| 572 |
peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0
|
| 573 |
|
| 574 |
return {
|
| 575 |
-
'model_name': type(model).__name__,
|
| 576 |
'f1_score': f1,
|
| 577 |
'accuracy': accuracy,
|
| 578 |
'precision': precision,
|
|
@@ -585,14 +632,10 @@ def predict_hatespeech_from_file(
|
|
| 585 |
'runtime': runtime,
|
| 586 |
'all_probabilities': all_probs.tolist()
|
| 587 |
}
|
| 588 |
-
|
| 589 |
-
|
| 590 |
def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
"""
|
| 594 |
-
# Get prediction
|
| 595 |
-
result = predict_text(
|
| 596 |
text=text,
|
| 597 |
rationale=rationale,
|
| 598 |
model=model,
|
|
@@ -601,6 +644,4 @@ def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rat
|
|
| 601 |
device=device,
|
| 602 |
max_length=config.get('max_length', 128),
|
| 603 |
model_type=model_type
|
| 604 |
-
)
|
| 605 |
-
|
| 606 |
-
return result
|
|
|
|
| 1 |
from huggingface_hub import hf_hub_download
|
| 2 |
import torch
|
|
|
|
| 3 |
from torch.nn import functional as F
|
| 4 |
import torch.nn as nn
|
| 5 |
import json
|
|
|
|
| 8 |
from time import time
|
| 9 |
import psutil
|
| 10 |
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import requests
|
| 13 |
+
import json
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
|
| 16 |
+
load_dotenv()
|
| 17 |
+
|
| 18 |
+
API_BASE_URL = os.getenv("CLOUDFLARE_API_BASE_URL")
|
| 19 |
+
HEADERS = {"Authorization": f"Bearer {os.getenv('CLOUDFLARE_API_TOKEN')}"}
|
| 20 |
+
MODEL_NAME = os.getenv("CLOUDFLARE_MODEL_NAME")
|
| 21 |
+
|
| 22 |
+
def create_prompt(text):
|
| 23 |
+
return f"""
|
| 24 |
+
You are a content moderation assistant. Identify the list of [rationales] words or phrases from the text that make it hateful,
|
| 25 |
+
list of [derogatory language], and [list of cuss words] and [hate_classification] such as "hateful" or "non-hateful".
|
| 26 |
+
If there are none, respond exactly with [non-hateful] only.
|
| 27 |
+
Output should be in JSON format only. Text: {text}.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def run_mistral_model(model, inputs):
|
| 31 |
+
payload = {"messages": inputs}
|
| 32 |
+
response = requests.post(f"{API_BASE_URL}{model}", headers=HEADERS, json=payload)
|
| 33 |
+
response.raise_for_status()
|
| 34 |
+
return response.json()
|
| 35 |
+
|
| 36 |
+
def flatten_json_string(json_string):
|
| 37 |
+
try:
|
| 38 |
+
obj = json.loads(json_string)
|
| 39 |
+
return json.dumps(obj, separators=(",", ":"))
|
| 40 |
+
except:
|
| 41 |
+
return json_string
|
| 42 |
+
|
| 43 |
+
def get_rationale_from_mistral(text, retries=10):
|
| 44 |
+
for attempt in range(retries):
|
| 45 |
+
try:
|
| 46 |
+
inputs = [{"role": "user", "content": create_prompt(text)}]
|
| 47 |
+
output = run_mistral_model(MODEL_NAME, inputs)
|
| 48 |
+
|
| 49 |
+
result = output.get("result", {})
|
| 50 |
+
response_text = result.get("response", "").strip()
|
| 51 |
+
|
| 52 |
+
if not response_text or response_text.startswith("I cannot"):
|
| 53 |
+
print(f"⚠️ Model returned 'I cannot...' — retrying ({attempt+1}/{retries})")
|
| 54 |
+
continue # retry
|
| 55 |
+
cleaned_rationale = flatten_json_string(response_text).replace("\n", " ").strip()
|
| 56 |
+
return cleaned_rationale
|
| 57 |
+
|
| 58 |
+
except requests.exceptions.HTTPError as e:
|
| 59 |
+
print(f"⚠️ HTTP Error on attempt {attempt+1}: {e}")
|
| 60 |
+
if "RESOURCE_EXHAUSTED" in str(e) or e.response.status_code == 429:
|
| 61 |
+
raise
|
| 62 |
+
|
| 63 |
+
return "non-hateful"
|
| 64 |
+
|
| 65 |
+
def preprocess_rationale_mistral(raw_rationale):
|
| 66 |
+
try:
|
| 67 |
+
x = str(raw_rationale).strip()
|
| 68 |
+
|
| 69 |
+
if x.startswith("```"):
|
| 70 |
+
x = x.replace("```json", "").replace("```", "").strip()
|
| 71 |
+
|
| 72 |
+
x = x.replace('""', '"')
|
| 73 |
+
|
| 74 |
+
# Extract JSON object
|
| 75 |
+
start = x.find("{")
|
| 76 |
+
end = x.rfind("}") + 1
|
| 77 |
+
if start == -1 or end == -1:
|
| 78 |
+
return x.lower()
|
| 79 |
+
|
| 80 |
+
j = json.loads(x[start:end])
|
| 81 |
+
|
| 82 |
+
keys = ["rationales", "derogatory_language", "cuss_words"]
|
| 83 |
+
|
| 84 |
+
if all(k in j and isinstance(j[k], list) and len(j[k]) == 0 for k in keys):
|
| 85 |
+
return "non-hateful"
|
| 86 |
+
|
| 87 |
+
cleaned = {k: j.get(k, []) for k in keys}
|
| 88 |
+
return json.dumps(cleaned).lower()
|
| 89 |
|
| 90 |
+
except Exception:
|
| 91 |
+
return str(raw_rationale).lower()
|
| 92 |
+
|
| 93 |
class TemporalCNN(nn.Module):
|
| 94 |
+
def __init__(self, input_dim=768, num_filters=32, kernel_sizes=(3,4,5), dropout=0.3):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
super().__init__()
|
| 96 |
self.input_dim = input_dim
|
| 97 |
self.num_filters = num_filters
|
| 98 |
self.kernel_sizes = kernel_sizes
|
|
|
|
|
|
|
| 99 |
self.convs = nn.ModuleList([
|
| 100 |
+
nn.Conv1d(in_channels=input_dim, out_channels=num_filters, kernel_size=k, padding=k//2)
|
| 101 |
for k in kernel_sizes
|
| 102 |
])
|
| 103 |
self.dropout = nn.Dropout(dropout)
|
| 104 |
|
| 105 |
def forward(self, sequence_embeddings, attention_mask=None):
|
| 106 |
+
x = sequence_embeddings.transpose(1, 2).contiguous()
|
| 107 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
pooled_outputs = []
|
| 109 |
for conv in self.convs:
|
| 110 |
+
conv_out = conv(x)
|
| 111 |
conv_out = F.relu(conv_out)
|
| 112 |
L_out = conv_out.size(2)
|
| 113 |
+
|
| 114 |
if attention_mask is not None:
|
|
|
|
| 115 |
mask = attention_mask.float()
|
| 116 |
if mask.size(1) != L_out:
|
| 117 |
mask = F.interpolate(mask.unsqueeze(1), size=L_out, mode='nearest').squeeze(1)
|
| 118 |
mask = mask.unsqueeze(1).to(conv_out.device) # (B,1,L_out)
|
| 119 |
+
|
|
|
|
| 120 |
neg_inf = torch.finfo(conv_out.dtype).min / 2
|
| 121 |
+
max_masked = torch.where(mask.bool(), conv_out, neg_inf*torch.ones_like(conv_out))
|
| 122 |
max_pooled = torch.max(max_masked, dim=2)[0] # (B, num_filters)
|
| 123 |
+
|
|
|
|
| 124 |
sum_masked = (conv_out * mask).sum(dim=2) # (B, num_filters)
|
| 125 |
denom = mask.sum(dim=2).clamp_min(1e-6) # (B,1)
|
| 126 |
mean_pooled = sum_masked / denom # (B, num_filters)
|
| 127 |
else:
|
| 128 |
max_pooled = torch.max(conv_out, dim=2)[0]
|
| 129 |
mean_pooled = conv_out.mean(dim=2)
|
| 130 |
+
|
| 131 |
pooled_outputs.append(max_pooled)
|
| 132 |
pooled_outputs.append(mean_pooled)
|
| 133 |
+
|
| 134 |
+
out = torch.cat(pooled_outputs, dim=1)
|
| 135 |
out = self.dropout(out)
|
| 136 |
return out
|
| 137 |
|
| 138 |
|
| 139 |
class MultiScaleAttentionCNN(nn.Module):
|
| 140 |
+
def __init__(self, hidden_size=768, num_filters=32, kernel_sizes=(3,4,5), dropout=0.3):
|
| 141 |
super().__init__()
|
| 142 |
|
| 143 |
self.hidden_size = hidden_size
|
| 144 |
self.kernel_sizes = kernel_sizes
|
| 145 |
|
| 146 |
self.convs = nn.ModuleList()
|
| 147 |
+
self.pads = nn.ModuleList()
|
| 148 |
|
| 149 |
for k in self.kernel_sizes:
|
| 150 |
+
pad_left = (k - 1) // 2
|
| 151 |
pad_right = k - 1 - pad_left
|
| 152 |
+
|
| 153 |
self.pads.append(nn.ConstantPad1d((pad_left, pad_right), 0.0))
|
| 154 |
+
|
| 155 |
+
self.convs.append(
|
| 156 |
+
nn.Conv1d(hidden_size, num_filters, kernel_size=k, padding=0)
|
| 157 |
+
)
|
| 158 |
|
| 159 |
self.attn = nn.ModuleList([nn.Linear(num_filters, 1) for _ in self.kernel_sizes])
|
| 160 |
self.output_size = num_filters * len(self.kernel_sizes)
|
| 161 |
self.dropout = nn.Dropout(dropout)
|
| 162 |
|
| 163 |
def forward(self, hidden_states, mask):
|
| 164 |
+
x = hidden_states.transpose(1, 2)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
attn_mask = mask.unsqueeze(1).float()
|
| 166 |
|
| 167 |
conv_outs = []
|
| 168 |
|
| 169 |
for pad, conv, att in zip(self.pads, self.convs, self.attn):
|
| 170 |
+
padded = pad(x)
|
| 171 |
+
c = conv(padded)
|
| 172 |
c = F.relu(c)
|
| 173 |
c = c * attn_mask
|
| 174 |
|
| 175 |
+
c_t = c.transpose(1, 2)
|
| 176 |
+
w = att(c_t)
|
| 177 |
w = w.masked_fill(mask.unsqueeze(-1) == 0, -1e9)
|
| 178 |
w = F.softmax(w, dim=1)
|
| 179 |
|
| 180 |
+
pooled = (c_t * w).sum(dim=1)
|
| 181 |
conv_outs.append(pooled)
|
| 182 |
|
| 183 |
+
out = torch.cat(conv_outs, dim=1)
|
| 184 |
return self.dropout(out)
|
| 185 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
class ConcatModelWithRationale(nn.Module):
    """Enhanced Shield classifier: HateBERT fused with a rationale encoder.

    Four views of the input are concatenated before classification:
      1. the HateBERT [CLS] embedding,
      2. masked TemporalCNN features over the HateBERT token states,
      3. MultiScaleAttentionCNN features over the rationale-encoder states,
      4. a soft-rationale-weighted mean of the HateBERT token states, where
         the weights come from a learned per-token selector (Gumbel-sampled
         while training, plain sigmoid at eval time).
    """

    def __init__(self,
                 hatebert_model,
                 additional_model,
                 projection_mlp,
                 hidden_size=768,
                 gumbel_temp=0.5,
                 freeze_additional_model=True,
                 cnn_num_filters=128,
                 cnn_kernel_sizes=(3, 4, 5),
                 cnn_dropout=0.3):
        super().__init__()

        self.hatebert_model = hatebert_model
        self.additional_model = additional_model
        self.projection_mlp = projection_mlp
        self.gumbel_temp = gumbel_temp
        self.hidden_size = hidden_size

        # Freeze HateBERT's embeddings and its first 8 encoder layers so only
        # the top layers are fine-tuned.
        for p in self.hatebert_model.embeddings.parameters():
            p.requires_grad = False
        for enc_layer in self.hatebert_model.encoder.layer[:8]:
            for p in enc_layer.parameters():
                p.requires_grad = False

        # The rationale encoder is normally kept fully frozen.
        if freeze_additional_model:
            for p in self.additional_model.parameters():
                p.requires_grad = False

        # Per-token rationale selector: one scalar logit per position.
        self.selector = nn.Linear(hidden_size, 1)

        self.temporal_cnn = TemporalCNN(
            input_dim=hidden_size,
            num_filters=cnn_num_filters,
            kernel_sizes=cnn_kernel_sizes,
            dropout=cnn_dropout,
        )
        # TemporalCNN yields max+mean pooled features per kernel size.
        self.temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2

        self.msa_cnn = MultiScaleAttentionCNN(
            hidden_size=hidden_size,
            num_filters=cnn_num_filters,
            kernel_sizes=cnn_kernel_sizes,
            dropout=cnn_dropout,
        )
        self.msa_out_dim = self.msa_cnn.output_size

    def gumbel_sigmoid_sample(self, logits):
        """Differentiable near-binary sample from sigmoid(logits) via Gumbel noise."""
        gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-9) + 1e-9)
        return torch.sigmoid((logits + gumbel) / self.gumbel_temp)

    def forward(self,
                input_ids,
                attention_mask,
                additional_input_ids,
                additional_attention_mask,
                return_attentions=False):
        """Return (logits, rationale_probs, selector_logits, attentions).

        attentions is None unless return_attentions=True and the backbone
        exposes them.
        """
        hatebert_out = self.hatebert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=return_attentions,
            return_dict=True,
        )
        token_states = hatebert_out.last_hidden_state
        cls_vector = token_states[:, 0, :]

        # Rationale encoder always runs gradient-free.
        with torch.no_grad():
            rationale_out = self.additional_model(
                input_ids=additional_input_ids,
                attention_mask=additional_attention_mask,
                return_dict=True,
            )
        rationale_states = rationale_out.last_hidden_state

        # Soft token-selection probabilities, zeroed at padding positions.
        selector_logits = self.selector(token_states).squeeze(-1)
        if self.training:
            rationale_probs = self.gumbel_sigmoid_sample(selector_logits)
        else:
            rationale_probs = torch.sigmoid(selector_logits)
        rationale_probs = rationale_probs * attention_mask.float()

        # Probability-weighted mean of the HateBERT token states.
        weighted = token_states * rationale_probs.unsqueeze(-1)
        norm = rationale_probs.sum(dim=1).unsqueeze(-1).clamp_min(1e-6)
        pooled_rationale = weighted.sum(dim=1) / norm

        temporal_features = self.temporal_cnn(token_states, attention_mask)
        rationale_features = self.msa_cnn(rationale_states, additional_attention_mask)

        fused = torch.cat(
            (cls_vector, temporal_features, rationale_features, pooled_rationale),
            dim=1,
        )
        logits = self.projection_mlp(fused)

        attns = None
        if return_attentions and hasattr(hatebert_out, "attentions"):
            attns = hatebert_out.attentions

        return logits, rationale_probs, selector_logits, attns
|
| 307 |
+
|
| 308 |
+
class ProjectionMLP(nn.Module):
|
| 309 |
+
def __init__(self, input_size, hidden_size=128, num_labels=2):
|
| 310 |
+
super().__init__()
|
| 311 |
+
|
| 312 |
+
self.layers = nn.Sequential(
|
| 313 |
+
nn.Linear(input_size, 512),
|
| 314 |
+
nn.LayerNorm(512),
|
| 315 |
+
nn.ReLU(),
|
| 316 |
+
nn.Dropout(0.3),
|
| 317 |
+
|
| 318 |
+
nn.Linear(512, hidden_size),
|
| 319 |
+
nn.ReLU(),
|
| 320 |
+
nn.Dropout(0.3),
|
| 321 |
+
|
| 322 |
+
nn.Linear(hidden_size, num_labels)
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
def forward(self, x):
|
| 326 |
+
return self.layers(x)
|
| 327 |
|
| 328 |
|
| 329 |
+
class ProjectionMLPBase(nn.Module):
|
| 330 |
+
def __init__(self, input_size, output_size):
|
| 331 |
+
super(ProjectionMLPBase, self).__init__()
|
| 332 |
+
self.layers = nn.Sequential(
|
| 333 |
+
nn.Linear(input_size, output_size),
|
| 334 |
+
nn.ReLU(),
|
| 335 |
+
nn.Linear(output_size, 2)
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
def forward(self, x):
|
| 339 |
+
return self.layers(x)
|
| 340 |
+
|
| 341 |
class BaseShield(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
def __init__(self, hatebert_model, additional_model, projection_mlp, device='cpu', freeze_additional_model=True):
|
| 343 |
super().__init__()
|
| 344 |
self.hatebert_model = hatebert_model
|
|
|
|
| 351 |
param.requires_grad = False
|
| 352 |
|
| 353 |
def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
|
| 354 |
+
hatebert_outputs = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 355 |
hatebert_embeddings = hatebert_outputs.last_hidden_state[:, 0, :]
|
| 356 |
|
| 357 |
+
additional_outputs = self.additional_model(input_ids=additional_input_ids, attention_mask=additional_attention_mask)
|
| 358 |
additional_embeddings = additional_outputs.last_hidden_state[:, 0, :]
|
| 359 |
|
| 360 |
concatenated_embeddings = torch.cat((hatebert_embeddings, additional_embeddings), dim=1)
|
| 361 |
logits = self.projection_mlp(concatenated_embeddings)
|
| 362 |
return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
+
|
|
|
|
|
|
|
| 365 |
|
| 366 |
+
def load_model_from_hf(model_type="altered"):
    """Download a Shield checkpoint from the HF Hub and rebuild the matching model.

    Args:
        model_type: "base" for BaseShield or "altered" for the enhanced
            ConcatModelWithRationale (case-insensitive).

    Returns:
        (model, tokenizer_hatebert, tokenizer_rationale, config, device) where
        config currently only carries {"max_length": 128}.

    Raises:
        ValueError: if model_type is neither "base" nor "altered".
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    repo_id = "seffyehl/BetterShield"

    requested = model_type.lower()
    if requested == "altered":
        model_filename = "AlteredShield.pth"
    elif requested == "base":
        model_filename = "BaseShield.pth"
    else:
        raise ValueError("model_type must be 'base' or 'altered'")

    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

    # NOTE(security): weights_only=False unpickles arbitrary objects; this is
    # only acceptable because the checkpoint comes from our own trusted repo.
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)

    # Checkpoints may be either a raw state dict or a wrapper dict.
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    hatebert_name = "GroNLP/hateBERT"
    rationale_name = "bert-base-uncased"

    hatebert_model = AutoModel.from_pretrained(hatebert_name)
    rationale_model = AutoModel.from_pretrained(rationale_name)

    tokenizer_hatebert = AutoTokenizer.from_pretrained(hatebert_name)
    tokenizer_rationale = AutoTokenizer.from_pretrained(rationale_name)

    H = hatebert_model.config.hidden_size

    # Infer head dimensions from the checkpoint so the rebuilt modules match
    # the saved weights exactly. layers.0 exists in both head layouts.
    first_layer_weight = state_dict["projection_mlp.layers.0.weight"]
    input_dim = first_layer_weight.shape[1]

    # Altered checkpoints contain TemporalCNN weights; base ones do not.
    is_altered = any(k.startswith("temporal_cnn.convs") for k in state_dict)

    if not is_altered or requested == "base":
        # BUGFIX: the base head only has layers.0 / layers.2, so its hidden
        # width is layers.0's output size. The previous code unconditionally
        # indexed projection_mlp.layers.4/.7, which would KeyError on a
        # genuine BaseShield checkpoint.
        hidden_dim = first_layer_weight.shape[0]

        projection_mlp = ProjectionMLPBase(
            input_size=input_dim,
            output_size=hidden_dim,
        )

        model = BaseShield(
            hatebert_model=hatebert_model,
            additional_model=rationale_model,
            projection_mlp=projection_mlp,
            freeze_additional_model=True,
            device=device,
        )

    else:
        # Altered layout: Linear layers live at Sequential indices 0, 4 and 7.
        hidden_dim = state_dict["projection_mlp.layers.4.weight"].shape[0]
        num_labels = state_dict["projection_mlp.layers.7.weight"].shape[0]

        conv_weights = [
            v for k, v in state_dict.items()
            if k.startswith("temporal_cnn.convs") and k.endswith("weight")
        ]
        cnn_num_filters = conv_weights[0].shape[0]
        cnn_kernel_sizes = tuple(w.shape[2] for w in conv_weights)
        cnn_dropout = 0.3

        projection_mlp = ProjectionMLP(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_labels=num_labels,
        )

        model = ConcatModelWithRationale(
            hatebert_model=hatebert_model,
            additional_model=rationale_model,
            projection_mlp=projection_mlp,
            hidden_size=H,
            freeze_additional_model=True,
            cnn_num_filters=cnn_num_filters,
            cnn_kernel_sizes=cnn_kernel_sizes,
            cnn_dropout=cnn_dropout,
        )

    model.load_state_dict(state_dict, strict=True)

    # BUGFIX: move the model to the selected device — callers send inputs to
    # `device`, so leaving the model on CPU breaks on CUDA machines.
    model.to(device)
    model.eval()

    config = {
        "max_length": 128
    }

    return model, tokenizer_hatebert, tokenizer_rationale, config, device
|
|
|
|
| 460 |
|
| 461 |
+
def predict_text(
    text,
    rationale,
    model,
    tokenizer_hatebert,
    tokenizer_rationale,
    device="cpu",
    max_length=128,
    model_type="altered"
):
    """Classify a single text (optionally with an LLM rationale) with either model.

    Args:
        text: input string to classify.
        rationale: rationale string for the secondary encoder; falls back to
            `text` when empty/None.
        model: BaseShield or ConcatModelWithRationale instance.
        tokenizer_hatebert / tokenizer_rationale: tokenizers for each branch.
        device: device the model lives on; inputs are moved there.
        max_length: padding/truncation length for both tokenizers.
        model_type: "base" (logits only) or "altered" (logits + rationales).

    Returns:
        dict with prediction (int), confidence (float), probabilities
        (np.ndarray), tokens (list[str]) and rationale_scores
        (np.ndarray or None for the base model).
    """
    model.eval()

    main_inputs = tokenizer_hatebert(
        text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    rationale_inputs = tokenizer_rationale(
        rationale if rationale else text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    input_ids = main_inputs["input_ids"].to(device)
    attention_mask = main_inputs["attention_mask"].to(device)

    add_input_ids = rationale_inputs["input_ids"].to(device)
    add_attention_mask = rationale_inputs["attention_mask"].to(device)

    tokens = tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])

    with torch.no_grad():
        if model_type.lower() == "base":
            # BaseShield returns bare logits and has no rationale selector.
            logits = model(
                input_ids,
                attention_mask,
                add_input_ids,
                add_attention_mask
            )
            rationale_scores = None
        else:
            outputs = model(
                input_ids,
                attention_mask,
                add_input_ids,
                add_attention_mask
            )
            if not (isinstance(outputs, tuple) and len(outputs) == 4):
                raise ValueError(f"Unexpected number of outputs from model: {len(outputs)}")
            # BUGFIX: the original assigned rationale_scores twice (once in the
            # branch and once unconditionally after it); a single assignment
            # here is sufficient.
            logits, rationale_probs, _, _ = outputs
            rationale_scores = rationale_probs[0].cpu().numpy()

        probs = F.softmax(logits, dim=1)

        # Guard against degenerate logits (NaN/Inf) — fall back to uniform.
        if torch.isnan(probs).any() or torch.isinf(probs).any():
            probs = torch.ones_like(logits) / logits.size(1)

    prediction = logits.argmax(dim=1).item()
    confidence = probs[0, prediction].item()

    return {
        "prediction": prediction,
        "confidence": confidence,
        "probabilities": probs[0].cpu().numpy(),
        "tokens": tokens,
        "rationale_scores": rationale_scores
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
|
| 538 |
def predict_hatespeech_from_file(
|
| 539 |
text_list,
|
|
|
|
| 556 |
|
| 557 |
process = psutil.Process(os.getpid())
|
| 558 |
|
|
|
|
| 559 |
if torch.cuda.is_available():
|
| 560 |
torch.cuda.synchronize()
|
| 561 |
|
| 562 |
+
# warmup
|
| 563 |
with torch.no_grad():
|
| 564 |
_ = predict_text(
|
| 565 |
text=text_list[0],
|
|
|
|
| 575 |
if torch.cuda.is_available():
|
| 576 |
torch.cuda.synchronize()
|
| 577 |
|
|
|
|
| 578 |
start_time = time()
|
| 579 |
|
| 580 |
for idx, (text, rationale) in enumerate(zip(text_list, rationale_list)):
|
| 581 |
+
|
| 582 |
result = predict_text(
|
| 583 |
text=text,
|
| 584 |
rationale=rationale,
|
|
|
|
| 593 |
predictions.append(result['prediction'])
|
| 594 |
all_probs.append(result['probabilities'])
|
| 595 |
|
|
|
|
| 596 |
if idx % 10 == 0 or idx == len(text_list) - 1:
|
| 597 |
cpu_percent_list.append(process.cpu_percent())
|
| 598 |
memory_percent_list.append(process.memory_info().rss / 1024 / 1024)
|
| 599 |
|
|
|
|
| 600 |
if torch.cuda.is_available():
|
| 601 |
torch.cuda.synchronize()
|
| 602 |
|
| 603 |
+
runtime = time() - start_time
|
|
|
|
| 604 |
|
| 605 |
print(f"Inference completed for {type(model).__name__}")
|
| 606 |
print(f"Total runtime: {runtime:.4f} seconds")
|
| 607 |
|
|
|
|
| 608 |
all_probs = np.array(all_probs)
|
| 609 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
f1 = f1_score(true_label, predictions, zero_division=0)
|
| 611 |
accuracy = accuracy_score(true_label, predictions)
|
| 612 |
precision = precision_score(true_label, predictions, zero_division=0)
|
|
|
|
| 619 |
peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0
|
| 620 |
|
| 621 |
return {
|
| 622 |
+
'model_name': type(model).__name__,
|
| 623 |
'f1_score': f1,
|
| 624 |
'accuracy': accuracy,
|
| 625 |
'precision': precision,
|
|
|
|
| 632 |
'runtime': runtime,
|
| 633 |
'all_probabilities': all_probs.tolist()
|
| 634 |
}
|
| 635 |
+
|
|
|
|
| 636 |
def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
    """Convenience wrapper: run predict_text with the max_length stored in config."""
    call_kwargs = dict(
        text=text,
        rationale=rationale,
        model=model,
        tokenizer_hatebert=tokenizer_hatebert,
        tokenizer_rationale=tokenizer_rationale,
        device=device,
        max_length=config.get('max_length', 128),
        model_type=model_type,
    )
    return predict_text(**call_kwargs)
|
|
|
|
|
|