Spaces:
Running
Running
fix
Browse files- notebooks/combined-baseline (1).ipynb +0 -0
- notebooks/proposed-new-model.ipynb +0 -0
- src/app.py +17 -10
- src/hatespeech_model.py +539 -320
notebooks/combined-baseline (1).ipynb
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
notebooks/proposed-new-model.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
-
from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file,
|
| 3 |
import plotly.graph_objects as go
|
| 4 |
import plotly.express as px
|
| 5 |
import pandas as pd
|
|
@@ -154,12 +154,22 @@ classify_button = st.button("🔍 Analyze Text", type="primary", use_container_w
|
|
| 154 |
|
| 155 |
if classify_button:
|
| 156 |
if user_input and user_input.strip():
|
| 157 |
-
with st.spinner('🔄
|
| 158 |
-
#
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
enhanced_model_result = predict_hatespeech(
|
| 161 |
text=user_input,
|
| 162 |
-
rationale=
|
| 163 |
model=enhanced_model,
|
| 164 |
tokenizer_hatebert=enhanced_tokenizer_hatebert,
|
| 165 |
tokenizer_rationale=enhanced_tokenizer_rationale,
|
|
@@ -169,10 +179,11 @@ if classify_button:
|
|
| 169 |
)
|
| 170 |
enhanced_end = time.time()
|
| 171 |
|
|
|
|
| 172 |
base_start = time.time()
|
| 173 |
base_model_result = predict_hatespeech(
|
| 174 |
text=user_input,
|
| 175 |
-
rationale=
|
| 176 |
model=base_model,
|
| 177 |
tokenizer_hatebert=base_tokenizer_hatebert,
|
| 178 |
tokenizer_rationale=base_tokenizer_rationale,
|
|
@@ -358,10 +369,6 @@ if classify_button:
|
|
| 358 |
}
|
| 359 |
})
|
| 360 |
if is_file_uploader_visible and uploaded_file is not None:
|
| 361 |
-
st.markdown(f"**Filename:** {uploaded_file.name}")
|
| 362 |
-
st.markdown(f"**Size:** {uploaded_file.size / 1024:.2f} KB")
|
| 363 |
-
file_rows = len(file_content)
|
| 364 |
-
st.metric("Rows in File", file_rows)
|
| 365 |
st.markdown("**Preview:**")
|
| 366 |
st.dataframe(file_content.head(3), use_container_width=True)
|
| 367 |
with st.spinner('🔄 Analyzing file with both models... This may take a while for large files.'):
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from hatespeech_model import predict_hatespeech, load_model_from_hf, predict_hatespeech_from_file, get_rationale_from_mistral, preprocess_rationale_mistral
|
| 3 |
import plotly.graph_objects as go
|
| 4 |
import plotly.express as px
|
| 5 |
import pandas as pd
|
|
|
|
| 154 |
|
| 155 |
if classify_button:
|
| 156 |
if user_input and user_input.strip():
|
| 157 |
+
with st.spinner('🔄 Generating rationale from Mistral AI...'):
|
| 158 |
+
# --- Step 1: Get rationale from Mistral ---
|
| 159 |
+
try:
|
| 160 |
+
raw_rationale = get_rationale_from_mistral(user_input)
|
| 161 |
+
cleaned_rationale = preprocess_rationale_mistral(raw_rationale)
|
| 162 |
+
print(f"Raw rationale from Mistral: {raw_rationale}")
|
| 163 |
+
except Exception as e:
|
| 164 |
+
st.error(f"❌ Error generating/processing rationale: {str(e)}")
|
| 165 |
+
cleaned_rationale = user_input # fallback to raw input
|
| 166 |
+
|
| 167 |
+
with st.spinner('🔄 Analyzing text with models...'):
|
| 168 |
+
# Run enhanced model
|
| 169 |
+
enhanced_start = time.time()
|
| 170 |
enhanced_model_result = predict_hatespeech(
|
| 171 |
text=user_input,
|
| 172 |
+
rationale=cleaned_rationale, # use cleaned rationale
|
| 173 |
model=enhanced_model,
|
| 174 |
tokenizer_hatebert=enhanced_tokenizer_hatebert,
|
| 175 |
tokenizer_rationale=enhanced_tokenizer_rationale,
|
|
|
|
| 179 |
)
|
| 180 |
enhanced_end = time.time()
|
| 181 |
|
| 182 |
+
# Run base model
|
| 183 |
base_start = time.time()
|
| 184 |
base_model_result = predict_hatespeech(
|
| 185 |
text=user_input,
|
| 186 |
+
rationale=cleaned_rationale, # use cleaned rationale
|
| 187 |
model=base_model,
|
| 188 |
tokenizer_hatebert=base_tokenizer_hatebert,
|
| 189 |
tokenizer_rationale=base_tokenizer_rationale,
|
|
|
|
| 369 |
}
|
| 370 |
})
|
| 371 |
if is_file_uploader_visible and uploaded_file is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
st.markdown("**Preview:**")
|
| 373 |
st.dataframe(file_content.head(3), use_container_width=True)
|
| 374 |
with st.spinner('🔄 Analyzing file with both models... This may take a while for large files.'):
|
src/hatespeech_model.py
CHANGED
|
@@ -9,164 +9,339 @@ from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_sc
|
|
| 9 |
from time import time
|
| 10 |
import psutil
|
| 11 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# Model Architecture Classes
|
| 14 |
class TemporalCNN(nn.Module):
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
super().__init__()
|
|
|
|
|
|
|
| 17 |
self.kernel_sizes = kernel_sizes
|
| 18 |
-
|
|
|
|
| 19 |
self.convs = nn.ModuleList([
|
| 20 |
-
nn.Conv1d(
|
| 21 |
-
for
|
| 22 |
])
|
| 23 |
self.dropout = nn.Dropout(dropout)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
x =
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
for
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
class MultiScaleAttentionCNN(nn.Module):
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
class ProjectionMLP(nn.Module):
|
| 70 |
-
def __init__(self, input_size, hidden_size, num_labels):
|
| 71 |
super().__init__()
|
| 72 |
self.layers = nn.Sequential(
|
| 73 |
nn.Linear(input_size, hidden_size),
|
| 74 |
nn.ReLU(),
|
| 75 |
nn.Linear(hidden_size, num_labels)
|
| 76 |
)
|
| 77 |
-
|
| 78 |
def forward(self, x):
|
| 79 |
return self.layers(x)
|
| 80 |
|
| 81 |
-
class GumbelTokenSelector(nn.Module):
|
| 82 |
-
def __init__(self, hidden_size, tau=1.0):
|
| 83 |
-
super().__init__()
|
| 84 |
-
self.tau = tau
|
| 85 |
-
self.proj = nn.Linear(hidden_size * 2, 1)
|
| 86 |
-
|
| 87 |
-
def forward(self, token_embeddings, cls_embedding, training=True):
|
| 88 |
-
B, L, H = token_embeddings.size()
|
| 89 |
-
cls_exp = cls_embedding.unsqueeze(1).expand(-1, L, -1)
|
| 90 |
-
x = torch.cat([token_embeddings, cls_exp], dim=-1)
|
| 91 |
-
logits = self.proj(x).squeeze(-1)
|
| 92 |
-
|
| 93 |
-
if training:
|
| 94 |
-
probs = F.gumbel_softmax(
|
| 95 |
-
torch.stack([logits, torch.zeros_like(logits)], dim=-1),
|
| 96 |
-
tau=self.tau,
|
| 97 |
-
hard=False
|
| 98 |
-
)[..., 0]
|
| 99 |
-
else:
|
| 100 |
-
probs = torch.sigmoid(logits)
|
| 101 |
-
return probs, logits
|
| 102 |
|
| 103 |
-
class
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
super().__init__()
|
| 110 |
self.hatebert_model = hatebert_model
|
| 111 |
self.additional_model = additional_model
|
| 112 |
self.projection_mlp = projection_mlp
|
| 113 |
-
self.
|
| 114 |
-
|
|
|
|
| 115 |
if freeze_additional_model:
|
| 116 |
for param in self.additional_model.parameters():
|
| 117 |
param.requires_grad = False
|
| 118 |
-
|
| 119 |
-
def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
|
| 120 |
-
hatebert_outputs = self.hatebert_model(input_ids=input_ids, attention_mask=attention_mask)
|
| 121 |
-
hatebert_embeddings = hatebert_outputs.last_hidden_state[:, 0, :]
|
| 122 |
-
hatebert_embeddings = torch.nn.LayerNorm(hatebert_embeddings.size()[1:]).to(self.device)(hatebert_embeddings.to(self.device)).to(self.device)
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
# Return 4 values to match ConcatModel interface (rationale_probs, selector_logits, attentions are None)
|
| 132 |
-
return projected_embeddings
|
| 133 |
|
| 134 |
-
class
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
super().__init__()
|
| 137 |
self.hatebert_model = hatebert_model
|
| 138 |
self.additional_model = additional_model
|
| 139 |
-
self.temporal_cnn = temporal_cnn
|
| 140 |
-
self.msa_cnn = msa_cnn
|
| 141 |
-
self.selector = selector
|
| 142 |
self.projection_mlp = projection_mlp
|
|
|
|
| 143 |
|
| 144 |
if freeze_additional_model:
|
| 145 |
-
for
|
| 146 |
-
|
| 147 |
-
if freeze_hatebert:
|
| 148 |
-
for p in self.hatebert_model.parameters():
|
| 149 |
-
p.requires_grad = False
|
| 150 |
|
| 151 |
def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
with torch.no_grad():
|
| 163 |
-
add_outputs = self.additional_model(input_ids=additional_input_ids, attention_mask=additional_attention_mask)
|
| 164 |
-
add_seq = add_outputs.last_hidden_state
|
| 165 |
-
|
| 166 |
-
msa_feat = self.msa_cnn(add_seq, additional_attention_mask)
|
| 167 |
-
concat = torch.cat([cls_emb, temporal_feat, msa_feat, H_r], dim=1)
|
| 168 |
-
logits = self.projection_mlp(concat)
|
| 169 |
-
return logits, token_probs, token_logits, hate_outputs.attentions if hasattr(hate_outputs, "attentions") else None
|
| 170 |
|
| 171 |
def load_model_from_hf(model_type="altered"):
|
| 172 |
"""
|
|
@@ -178,14 +353,13 @@ def load_model_from_hf(model_type="altered"):
|
|
| 178 |
|
| 179 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 180 |
repo_id = "seffyehl/BetterShield"
|
| 181 |
-
# repo_type = "e5912f6e8c34a10629cfd5a7971ac71ac76d0e9d"
|
| 182 |
|
| 183 |
# Choose model and config files based on model_type
|
| 184 |
if model_type.lower() == "altered":
|
| 185 |
model_filename = "AlteredShield.pth"
|
| 186 |
config_filename = "alter_config.json"
|
| 187 |
elif model_type.lower() == "base":
|
| 188 |
-
model_filename = "
|
| 189 |
config_filename = "base_config.json"
|
| 190 |
else:
|
| 191 |
raise ValueError(f"model_type must be 'altered' or 'base', got '{model_type}'")
|
|
@@ -193,22 +367,24 @@ def load_model_from_hf(model_type="altered"):
|
|
| 193 |
# Download files
|
| 194 |
model_path = hf_hub_download(
|
| 195 |
repo_id=repo_id,
|
| 196 |
-
# revision=repo_type,
|
| 197 |
filename=model_filename
|
| 198 |
)
|
| 199 |
|
| 200 |
config_path = hf_hub_download(
|
| 201 |
repo_id=repo_id,
|
| 202 |
filename=config_filename,
|
| 203 |
-
# revision=repo_type
|
| 204 |
)
|
| 205 |
|
| 206 |
# Load config
|
| 207 |
with open(config_path, 'r') as f:
|
| 208 |
config = json.load(f)
|
| 209 |
|
| 210 |
-
# Load checkpoint
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
# Handle nested config structure (base model uses model_config, altered uses flat structure)
|
| 214 |
if 'model_config' in config:
|
|
@@ -225,50 +401,144 @@ def load_model_from_hf(model_type="altered"):
|
|
| 225 |
tokenizer_hatebert = AutoTokenizer.from_pretrained(model_config['hatebert_model'])
|
| 226 |
tokenizer_rationale = AutoTokenizer.from_pretrained(model_config['rationale_model'])
|
| 227 |
|
| 228 |
-
# Rebuild architecture based on model type
|
| 229 |
H = hatebert_model.config.hidden_size
|
| 230 |
max_length = training_config.get('max_length', 128)
|
| 231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
if model_type.lower() == "base":
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
proj_input_dim = H * 2 # 1536
|
| 236 |
-
# The saved model uses 512, not what's in projection_config
|
| 237 |
-
adapter_dim = 512 # hardcoded to match saved weights
|
| 238 |
-
projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim,
|
| 239 |
-
num_labels=2)
|
| 240 |
-
|
| 241 |
model = BaseShield(
|
| 242 |
hatebert_model=hatebert_model,
|
| 243 |
additional_model=rationale_model,
|
| 244 |
projection_mlp=projection_mlp,
|
| 245 |
-
freeze_additional_model=
|
| 246 |
device=device
|
| 247 |
).to(device)
|
| 248 |
else:
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
|
|
|
| 263 |
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
|
| 266 |
print(f"Dataset: {checkpoint.get('dataset', 'unknown')}, Seed: {checkpoint.get('seed', 'unknown')}")
|
| 267 |
-
|
| 268 |
-
|
| 269 |
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
model = model.to(device)
|
| 271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
# Create a unified config dict with max_length at top level for compatibility
|
| 273 |
unified_config = config.copy()
|
| 274 |
if 'max_length' not in unified_config and 'training_config' in config:
|
|
@@ -276,26 +546,33 @@ def load_model_from_hf(model_type="altered"):
|
|
| 276 |
|
| 277 |
return model, tokenizer_hatebert, tokenizer_rationale, unified_config, device
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale,
|
| 280 |
device='cpu', max_length=128, model_type="altered"):
|
| 281 |
-
|
| 282 |
-
Predict hate speech for a given text and rationale
|
| 283 |
-
|
| 284 |
-
Args:
|
| 285 |
-
text: Input text to classify
|
| 286 |
-
rationale: Rationale/explanation text
|
| 287 |
-
model: Loaded model
|
| 288 |
-
tokenizer_hatebert: HateBERT tokenizer
|
| 289 |
-
tokenizer_rationale: Rationale model tokenizer
|
| 290 |
-
device: 'cpu' or 'cuda'
|
| 291 |
-
max_length: Maximum sequence length
|
| 292 |
-
model_type: Either "altered" or "base" to determine how to process inputs
|
| 293 |
-
|
| 294 |
-
Returns:
|
| 295 |
-
prediction: 0 or 1
|
| 296 |
-
probability: Confidence score
|
| 297 |
-
rationale_scores: Token-level rationale scores
|
| 298 |
-
"""
|
| 299 |
model.eval()
|
| 300 |
|
| 301 |
# Tokenize inputs
|
|
@@ -321,79 +598,99 @@ def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale
|
|
| 321 |
add_input_ids = inputs_rationale['input_ids'].to(device)
|
| 322 |
add_attention_mask = inputs_rationale['attention_mask'].to(device)
|
| 323 |
|
| 324 |
-
# Inference
|
| 325 |
-
|
| 326 |
-
|
| 327 |
logits = model(
|
| 328 |
input_ids,
|
| 329 |
attention_mask,
|
| 330 |
add_input_ids,
|
| 331 |
add_attention_mask
|
| 332 |
)
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
|
|
|
|
|
|
| 338 |
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
input_ids,
|
| 350 |
-
attention_mask,
|
| 351 |
-
add_input_ids,
|
| 352 |
-
add_attention_mask
|
| 353 |
-
)
|
| 354 |
|
| 355 |
-
# Get probabilities
|
| 356 |
-
probs = torch.softmax(logits, dim=1)
|
| 357 |
prediction = logits.argmax(dim=1).item()
|
| 358 |
confidence = probs[0, prediction].item()
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
-
|
| 361 |
'prediction': prediction,
|
| 362 |
'confidence': confidence,
|
| 363 |
'probabilities': probs[0].cpu().numpy(),
|
| 364 |
-
'rationale_scores': rationale_probs[0].cpu().numpy(),
|
| 365 |
'tokens': tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])
|
| 366 |
}
|
| 367 |
-
|
| 368 |
-
def predict_hatespeech_from_file(text_list, rationale_list, true_label, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
|
| 369 |
-
"""
|
| 370 |
-
Predict hate speech for text read from a file
|
| 371 |
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
|
|
|
|
|
|
|
|
|
| 391 |
predictions = []
|
|
|
|
| 392 |
cpu_percent_list = []
|
| 393 |
memory_percent_list = []
|
| 394 |
|
| 395 |
process = psutil.Process(os.getpid())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
start_time = time()
|
|
|
|
| 397 |
for idx, (text, rationale) in enumerate(zip(text_list, rationale_list)):
|
| 398 |
result = predict_text(
|
| 399 |
text=text,
|
|
@@ -405,27 +702,45 @@ def predict_hatespeech_from_file(text_list, rationale_list, true_label, model, t
|
|
| 405 |
max_length=config.get('max_length', 128),
|
| 406 |
model_type=model_type
|
| 407 |
)
|
|
|
|
| 408 |
predictions.append(result['prediction'])
|
| 409 |
-
|
|
|
|
|
|
|
| 410 |
if idx % 10 == 0 or idx == len(text_list) - 1:
|
| 411 |
cpu_percent_list.append(process.cpu_percent())
|
| 412 |
memory_percent_list.append(process.memory_info().rss / 1024 / 1024)
|
| 413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
end_time = time()
|
| 415 |
runtime = end_time - start_time
|
| 416 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
f1 = f1_score(true_label, predictions, zero_division=0)
|
| 418 |
accuracy = accuracy_score(true_label, predictions)
|
| 419 |
precision = precision_score(true_label, predictions, zero_division=0)
|
| 420 |
recall = recall_score(true_label, predictions, zero_division=0)
|
| 421 |
cm = confusion_matrix(true_label, predictions).tolist()
|
| 422 |
-
|
| 423 |
avg_cpu = sum(cpu_percent_list) / len(cpu_percent_list) if cpu_percent_list else 0
|
| 424 |
-
avg_memory = sum(memory_percent_list) / len(memory_percent_list) if memory_percent_list else 0
|
| 425 |
peak_memory = max(memory_percent_list) if memory_percent_list else 0
|
| 426 |
peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0
|
| 427 |
|
| 428 |
return {
|
|
|
|
| 429 |
'f1_score': f1,
|
| 430 |
'accuracy': accuracy,
|
| 431 |
'precision': precision,
|
|
@@ -435,25 +750,14 @@ def predict_hatespeech_from_file(text_list, rationale_list, true_label, model, t
|
|
| 435 |
'memory_usage': avg_memory,
|
| 436 |
'peak_cpu_usage': peak_cpu,
|
| 437 |
'peak_memory_usage': peak_memory,
|
| 438 |
-
'runtime': runtime
|
|
|
|
| 439 |
}
|
| 440 |
|
| 441 |
|
| 442 |
def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
|
| 443 |
"""
|
| 444 |
Predict hate speech for given text
|
| 445 |
-
|
| 446 |
-
Args:
|
| 447 |
-
text: Input text to classify
|
| 448 |
-
rationale: Optional rationale text
|
| 449 |
-
model: Loaded model
|
| 450 |
-
tokenizer_hatebert: HateBERT tokenizer
|
| 451 |
-
tokenizer_rationale: Rationale tokenizer
|
| 452 |
-
config: Model configuration
|
| 453 |
-
device: Device to run on
|
| 454 |
-
|
| 455 |
-
Returns:
|
| 456 |
-
Dictionary with prediction results
|
| 457 |
"""
|
| 458 |
# Get prediction
|
| 459 |
result = predict_text(
|
|
@@ -468,88 +772,3 @@ def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rat
|
|
| 468 |
)
|
| 469 |
|
| 470 |
return result
|
| 471 |
-
|
| 472 |
-
def predict_hatespeech_from_file_mock():
|
| 473 |
-
"""
|
| 474 |
-
Mock function for predict_hatespeech_from_file that returns hardcoded data for testing
|
| 475 |
-
|
| 476 |
-
Args:
|
| 477 |
-
text_list: List of input texts to classify (not used in mock)
|
| 478 |
-
rationale_list: List of rationale/explanation texts (not used in mock)
|
| 479 |
-
true_label: True label for evaluation (not used in mock)
|
| 480 |
-
model: Loaded model (not used in mock)
|
| 481 |
-
tokenizer_hatebert: HateBERT tokenizer (not used in mock)
|
| 482 |
-
tokenizer_rationale: Rationale tokenizer (not used in mock)
|
| 483 |
-
config: Model configuration (not used in mock)
|
| 484 |
-
device: Device to run on (not used in mock)
|
| 485 |
-
Returns:
|
| 486 |
-
Dictionary with hardcoded metrics for testing
|
| 487 |
-
"""
|
| 488 |
-
# Hardcoded predictions matching the number of samples
|
| 489 |
-
predictions = [0, 1, 1, 0, 1, 0, 0, 1, 1, 0]
|
| 490 |
-
true_labels = [0, 1, 1, 0, 0, 0, 1, 1, 1, 0]
|
| 491 |
-
|
| 492 |
-
# Hardcoded resource usage metrics
|
| 493 |
-
cpu_percent_list = [25.3, 28.1, 26.5, 27.2, 26.8, 27.9, 25.5, 28.3, 26.2, 27.1]
|
| 494 |
-
memory_percent_list = [145.3, 152.1, 148.5, 151.2, 149.8, 153.2, 146.5, 154.3, 150.2, 152.1]
|
| 495 |
-
|
| 496 |
-
f1 = f1_score(true_labels, predictions, zero_division=0)
|
| 497 |
-
accuracy = accuracy_score(true_labels, predictions)
|
| 498 |
-
precision = precision_score(true_labels, predictions, zero_division=0)
|
| 499 |
-
recall = recall_score(true_labels, predictions, zero_division=0)
|
| 500 |
-
cm = confusion_matrix(true_labels, predictions).tolist()
|
| 501 |
-
|
| 502 |
-
avg_cpu = sum(cpu_percent_list) / len(cpu_percent_list) if cpu_percent_list else 0
|
| 503 |
-
avg_memory = sum(memory_percent_list) / len(memory_percent_list) if memory_percent_list else 0
|
| 504 |
-
peak_memory = max(memory_percent_list) if memory_percent_list else 0
|
| 505 |
-
peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0
|
| 506 |
-
|
| 507 |
-
# Hardcoded runtime
|
| 508 |
-
runtime = 12.543
|
| 509 |
-
|
| 510 |
-
return {
|
| 511 |
-
'f1_score': f1,
|
| 512 |
-
'accuracy': accuracy,
|
| 513 |
-
'precision': precision,
|
| 514 |
-
'recall': recall,
|
| 515 |
-
'confusion_matrix': cm,
|
| 516 |
-
'cpu_usage': avg_cpu,
|
| 517 |
-
'memory_usage': avg_memory,
|
| 518 |
-
'peak_cpu_usage': peak_cpu,
|
| 519 |
-
'peak_memory_usage': peak_memory,
|
| 520 |
-
'runtime': runtime,
|
| 521 |
-
'predictions': predictions # Added for visibility
|
| 522 |
-
}
|
| 523 |
-
|
| 524 |
-
def predict_text_mock(text, max_length=128):
|
| 525 |
-
import numpy as np
|
| 526 |
-
|
| 527 |
-
# Simple whitespace tokenization for mock output
|
| 528 |
-
raw_tokens = (text or "").split()
|
| 529 |
-
mock_tokens = raw_tokens[:max_length]
|
| 530 |
-
|
| 531 |
-
# Build a simple attention mask (1 for tokens)
|
| 532 |
-
attention_mask = [1] * len(mock_tokens)
|
| 533 |
-
|
| 534 |
-
# Generate random rationale scores matching token count
|
| 535 |
-
mock_rationale_scores = np.random.rand(len(mock_tokens)).astype(np.float32)
|
| 536 |
-
|
| 537 |
-
# Randomized probabilities [class_0, class_1]
|
| 538 |
-
# Class 0 = not hate speech, Class 1 = hate speech
|
| 539 |
-
mock_probabilities = np.random.rand(2).astype(np.float32)
|
| 540 |
-
mock_probabilities = mock_probabilities / mock_probabilities.sum()
|
| 541 |
-
|
| 542 |
-
# Prediction (argmax of probabilities)
|
| 543 |
-
mock_prediction = int(np.argmax(mock_probabilities)) # Class 1: hate speech
|
| 544 |
-
|
| 545 |
-
# Confidence score
|
| 546 |
-
mock_confidence = float(np.max(mock_probabilities))
|
| 547 |
-
|
| 548 |
-
return {
|
| 549 |
-
'prediction': mock_prediction,
|
| 550 |
-
'confidence': mock_confidence,
|
| 551 |
-
'probabilities': mock_probabilities,
|
| 552 |
-
'rationale_scores': mock_rationale_scores,
|
| 553 |
-
'tokens': mock_tokens,
|
| 554 |
-
'attention_mask': attention_mask
|
| 555 |
-
}
|
|
|
|
| 9 |
from time import time
|
| 10 |
import psutil
|
| 11 |
import os
|
| 12 |
+
import numpy as np
|
| 13 |
+
import requests
|
| 14 |
+
import json
|
| 15 |
+
|
| 16 |
+
API_BASE_URL = "https://api.cloudflare.com/client/v4/accounts/8fcfcf97aa4c166eee626b79a67f902d/ai/run/"
|
| 17 |
+
HEADERS = {"Authorization": "Bearer 2Qb-uZ6M8yzkKZmGmcxZGRveNvk3YXBJwhlQyOfP"}
|
| 18 |
+
MODEL_NAME = "@cf/mistralai/mistral-small-3.1-24b-instruct"
|
| 19 |
+
|
| 20 |
+
def create_prompt(text):
|
| 21 |
+
return f"""
|
| 22 |
+
You are a content moderation assistant. Identify the list of [rationales] words or phrases from the text that make it hateful,
|
| 23 |
+
list of [derogatory language], and [list of cuss words] and [hate_classification] such as "hateful" or "non-hateful".
|
| 24 |
+
If there are none, respond exactly with [non-hateful] only.
|
| 25 |
+
Output should be in JSON format only. Text: {text}.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def run_mistral_model(model, inputs):
|
| 29 |
+
payload = {"messages": inputs}
|
| 30 |
+
response = requests.post(f"{API_BASE_URL}{model}", headers=HEADERS, json=payload)
|
| 31 |
+
response.raise_for_status()
|
| 32 |
+
return response.json()
|
| 33 |
+
|
| 34 |
+
def flatten_json_string(json_string):
|
| 35 |
+
try:
|
| 36 |
+
obj = json.loads(json_string)
|
| 37 |
+
return json.dumps(obj, separators=(",", ":"))
|
| 38 |
+
except:
|
| 39 |
+
return json_string
|
| 40 |
+
|
| 41 |
+
def get_rationale_from_mistral(text, retries=10):
|
| 42 |
+
"""
|
| 43 |
+
Sends text to Mistral AI and returns a cleaned JSON rationale string.
|
| 44 |
+
Retries if the model returns invalid output or starts with "I cannot".
|
| 45 |
+
"""
|
| 46 |
+
for attempt in range(retries):
|
| 47 |
+
try:
|
| 48 |
+
inputs = [{"role": "user", "content": create_prompt(text)}]
|
| 49 |
+
output = run_mistral_model(MODEL_NAME, inputs)
|
| 50 |
+
|
| 51 |
+
result = output.get("result", {})
|
| 52 |
+
response_text = result.get("response", "").strip()
|
| 53 |
+
|
| 54 |
+
if not response_text or response_text.startswith("I cannot"):
|
| 55 |
+
print(f"⚠️ Model returned 'I cannot...' — retrying ({attempt+1}/{retries})")
|
| 56 |
+
continue # retry
|
| 57 |
+
|
| 58 |
+
# Flatten JSON response and clean
|
| 59 |
+
cleaned_rationale = flatten_json_string(response_text).replace("\n", " ").strip()
|
| 60 |
+
return cleaned_rationale
|
| 61 |
+
|
| 62 |
+
except requests.exceptions.HTTPError as e:
|
| 63 |
+
print(f"⚠️ HTTP Error on attempt {attempt+1}: {e}")
|
| 64 |
+
# If resource exhausted or rate limited, raise
|
| 65 |
+
if "RESOURCE_EXHAUSTED" in str(e) or e.response.status_code == 429:
|
| 66 |
+
raise
|
| 67 |
+
|
| 68 |
+
# Fallback if all retries fail
|
| 69 |
+
return "non-hateful"
|
| 70 |
+
|
| 71 |
+
def preprocess_rationale_mistral(raw_rationale):
|
| 72 |
+
"""
|
| 73 |
+
Cleans and standardizes rationale text from Mistral AI.
|
| 74 |
+
- Removes ```json fences
|
| 75 |
+
- Fixes escaped quotes
|
| 76 |
+
- Extracts JSON content
|
| 77 |
+
- Returns 'non-hateful' if all rationale lists are empty
|
| 78 |
+
- Otherwise returns a clean, one-line JSON of rationales
|
| 79 |
+
"""
|
| 80 |
+
try:
|
| 81 |
+
x = str(raw_rationale).strip()
|
| 82 |
+
|
| 83 |
+
# Remove ```json fences
|
| 84 |
+
if x.startswith("```"):
|
| 85 |
+
x = x.replace("```json", "").replace("```", "").strip()
|
| 86 |
+
|
| 87 |
+
# Fix double quotes
|
| 88 |
+
x = x.replace('""', '"')
|
| 89 |
+
|
| 90 |
+
# Extract JSON object
|
| 91 |
+
start = x.find("{")
|
| 92 |
+
end = x.rfind("}") + 1
|
| 93 |
+
if start == -1 or end == -1:
|
| 94 |
+
return x.lower() # fallback
|
| 95 |
+
|
| 96 |
+
j = json.loads(x[start:end])
|
| 97 |
+
|
| 98 |
+
keys = ["rationales", "derogatory_language", "cuss_words"]
|
| 99 |
+
|
| 100 |
+
# If all lists exist and are empty → non-hateful
|
| 101 |
+
if all(k in j and isinstance(j[k], list) and len(j[k]) == 0 for k in keys):
|
| 102 |
+
return "non-hateful"
|
| 103 |
+
|
| 104 |
+
# Otherwise, return clean JSON of relevant keys
|
| 105 |
+
cleaned = {k: j.get(k, []) for k in keys}
|
| 106 |
+
return json.dumps(cleaned).lower()
|
| 107 |
+
|
| 108 |
+
except Exception:
|
| 109 |
+
return str(raw_rationale).lower()
|
| 110 |
|
| 111 |
# Model Architecture Classes
|
| 112 |
class TemporalCNN(nn.Module):
|
| 113 |
+
"""
|
| 114 |
+
Temporal CNN applied across the sequence (time) dimension.
|
| 115 |
+
Input: sequence_embeddings (B, L, H), attention_mask (B, L)
|
| 116 |
+
Output: pooled vector (B, output_dim) where output_dim = num_filters * len(kernel_sizes) * 2
|
| 117 |
+
(we concatenate max-pooled and mean-pooled features for each kernel size)
|
| 118 |
+
"""
|
| 119 |
+
def __init__(self, input_dim=768, num_filters=256, kernel_sizes=(2, 3, 4), dropout=0.3):
|
| 120 |
super().__init__()
|
| 121 |
+
self.input_dim = input_dim
|
| 122 |
+
self.num_filters = num_filters
|
| 123 |
self.kernel_sizes = kernel_sizes
|
| 124 |
+
|
| 125 |
+
# Convs expect (B, C_in, L) where C_in = input_dim
|
| 126 |
self.convs = nn.ModuleList([
|
| 127 |
+
nn.Conv1d(in_channels=input_dim, out_channels=num_filters, kernel_size=k, padding=k // 2)
|
| 128 |
+
for k in kernel_sizes
|
| 129 |
])
|
| 130 |
self.dropout = nn.Dropout(dropout)
|
| 131 |
+
|
| 132 |
+
def forward(self, sequence_embeddings, attention_mask=None):
|
| 133 |
+
"""
|
| 134 |
+
sequence_embeddings: (B, L, H)
|
| 135 |
+
attention_mask: (B, L) with 1 for valid tokens, 0 for padding
|
| 136 |
+
returns: (B, num_filters * len(kernel_sizes) * 2) # max + mean pooled per conv
|
| 137 |
+
"""
|
| 138 |
+
# transpose to (B, H, L)
|
| 139 |
+
x = sequence_embeddings.transpose(1, 2).contiguous() # (B, H, L)
|
| 140 |
+
|
| 141 |
+
pooled_outputs = []
|
| 142 |
+
for conv in self.convs:
|
| 143 |
+
conv_out = conv(x) # (B, num_filters, L_out)
|
| 144 |
+
conv_out = F.relu(conv_out)
|
| 145 |
+
L_out = conv_out.size(2)
|
| 146 |
+
|
| 147 |
+
if attention_mask is not None:
|
| 148 |
+
# resize mask to match L_out
|
| 149 |
+
mask = attention_mask.float()
|
| 150 |
+
if mask.size(1) != L_out:
|
| 151 |
+
mask = F.interpolate(mask.unsqueeze(1), size=L_out, mode='nearest').squeeze(1)
|
| 152 |
+
mask = mask.unsqueeze(1).to(conv_out.device) # (B,1,L_out)
|
| 153 |
+
|
| 154 |
+
# max pool with masking
|
| 155 |
+
neg_inf = torch.finfo(conv_out.dtype).min / 2
|
| 156 |
+
max_masked = torch.where(mask.bool(), conv_out, neg_inf * torch.ones_like(conv_out))
|
| 157 |
+
max_pooled = torch.max(max_masked, dim=2)[0] # (B, num_filters)
|
| 158 |
+
|
| 159 |
+
# mean pool with masking
|
| 160 |
+
sum_masked = (conv_out * mask).sum(dim=2) # (B, num_filters)
|
| 161 |
+
denom = mask.sum(dim=2).clamp_min(1e-6) # (B,1)
|
| 162 |
+
mean_pooled = sum_masked / denom # (B, num_filters)
|
| 163 |
+
else:
|
| 164 |
+
max_pooled = torch.max(conv_out, dim=2)[0]
|
| 165 |
+
mean_pooled = conv_out.mean(dim=2)
|
| 166 |
+
|
| 167 |
+
pooled_outputs.append(max_pooled)
|
| 168 |
+
pooled_outputs.append(mean_pooled)
|
| 169 |
+
|
| 170 |
+
out = torch.cat(pooled_outputs, dim=1) # (B, num_filters * len(kernel_sizes) * 2)
|
| 171 |
+
out = self.dropout(out)
|
| 172 |
+
return out
|
| 173 |
+
|
| 174 |
|
| 175 |
class MultiScaleAttentionCNN(nn.Module):
    """Multi-kernel CNN with per-scale attention pooling.

    Each kernel size gets an asymmetric zero-padding layer (so the conv output
    keeps the input length), a Conv1d, and a per-token linear scorer. The
    attention-weighted sums over tokens are concatenated across all scales.
    """

    def __init__(self, hidden_size=768, num_filters=64, kernel_sizes=(2, 3, 4, 5, 6, 7), dropout=0.3):
        super().__init__()

        self.hidden_size = hidden_size
        self.kernel_sizes = kernel_sizes

        self.convs = nn.ModuleList()
        self.pads = nn.ModuleList()
        for width in self.kernel_sizes:
            # Split (width - 1) zeros between left and right so output length == input length.
            left = (width - 1) // 2
            right = width - 1 - left
            self.pads.append(nn.ConstantPad1d((left, right), 0.0))
            self.convs.append(nn.Conv1d(hidden_size, num_filters, kernel_size=width, padding=0))

        # One scalar attention scorer per scale.
        self.attn = nn.ModuleList([nn.Linear(num_filters, 1) for _ in self.kernel_sizes])
        self.output_size = num_filters * len(self.kernel_sizes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states, mask):
        """Pool (B, L, H) hidden states into (B, num_filters * len(kernel_sizes)).

        Args:
            hidden_states: (B, L, H) token embeddings.
            mask: (B, L) attention mask, 1 for valid tokens.
        """
        channels_first = hidden_states.transpose(1, 2)  # (B, H, L)
        keep = mask.unsqueeze(1).float()                # (B, 1, L) zeroes padded activations
        score_fill = mask.unsqueeze(-1) == 0            # (B, L, 1) positions excluded from softmax

        pooled_per_scale = []
        for pad, conv, scorer in zip(self.pads, self.convs, self.attn):
            feats = F.relu(conv(pad(channels_first))) * keep  # (B, F, L)
            feats_t = feats.transpose(1, 2)                   # (B, L, F)
            # Mask padded tokens with -1e9 before softmax so they get ~zero weight.
            scores = scorer(feats_t).masked_fill(score_fill, -1e9)
            weights = F.softmax(scores, dim=1)                # (B, L, 1)
            pooled_per_scale.append((feats_t * weights).sum(dim=1))  # (B, F)

        return self.dropout(torch.cat(pooled_per_scale, dim=1))
|
| 221 |
+
|
| 222 |
|
| 223 |
class ProjectionMLP(nn.Module):
    """Two-layer classification head: Linear -> ReLU -> Linear.

    Kept as an ``nn.Sequential`` named ``layers`` so checkpoint keys stay
    ``layers.0.weight`` / ``layers.2.weight``.
    """

    def __init__(self, input_size, hidden_size=256, num_labels=2):
        super().__init__()
        stack = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_labels),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Project features of shape (..., input_size) to (..., num_labels) logits."""
        return self.layers(x)
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
+
class ConcatModelWithRationale(nn.Module):
    """Rationale-aware hate-speech classifier.

    Concatenates four feature groups before the projection head:
      1. HateBERT [CLS] embedding of the main text,
      2. TemporalCNN features over the HateBERT token embeddings,
      3. MultiScaleAttentionCNN features over the (frozen) rationale-encoder
         token embeddings,
      4. a soft-rationale pooled summary of the HateBERT embeddings, weighted
         by per-token keep-probabilities from a Gumbel-Sigmoid selector head.
    """

    def __init__(self,
                 hatebert_model,
                 additional_model,
                 projection_mlp,
                 hidden_size=768,
                 gumbel_temp=0.5,
                 freeze_additional_model=True,
                 cnn_num_filters=64,
                 cnn_kernel_sizes=(2, 3, 4, 5, 6, 7),
                 cnn_dropout=0.0):
        """
        Args:
            hatebert_model: main-text encoder; must expose ``last_hidden_state``
                (and ``attentions`` when ``output_attentions=True``).
            additional_model: rationale encoder, run under ``torch.no_grad()``.
            projection_mlp: classification head over the concatenated features.
            hidden_size: encoder hidden dimension; sizes the selector and CNNs.
            gumbel_temp: temperature for Gumbel-Sigmoid rationale sampling.
            freeze_additional_model: if True, disable grads on the rationale encoder.
            cnn_num_filters: filters per kernel size in both CNN branches.
            cnn_kernel_sizes: kernel widths shared by both CNN branches.
            cnn_dropout: dropout applied inside the CNN branches.
        """
        super().__init__()
        self.hatebert_model = hatebert_model
        self.additional_model = additional_model
        self.projection_mlp = projection_mlp
        self.gumbel_temp = gumbel_temp
        self.hidden_size = hidden_size

        # Freeze the rationale encoder so only HateBERT + heads receive gradients.
        if freeze_additional_model:
            for param in self.additional_model.parameters():
                param.requires_grad = False

        # Selector head: one logit per token (rationale keep-probability).
        self.selector = nn.Linear(hidden_size, 1)

        # Temporal CNN over HateBERT embeddings (main text).
        self.temporal_cnn = TemporalCNN(input_dim=hidden_size,
                                        num_filters=cnn_num_filters,
                                        kernel_sizes=cnn_kernel_sizes,
                                        dropout=cnn_dropout)
        # TemporalCNN emits max + mean pooled features per kernel size.
        self.temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2

        # MultiScaleAttentionCNN over rationale embeddings (frozen BERT).
        self.msa_cnn = MultiScaleAttentionCNN(hidden_size=hidden_size,
                                              num_filters=cnn_num_filters,
                                              kernel_sizes=cnn_kernel_sizes,
                                              dropout=cnn_dropout)
        self.msa_out_dim = self.msa_cnn.output_size

    def gumbel_sigmoid_sample(self, logits):
        """Soft Bernoulli sample: sigmoid((logits + Gumbel noise) / temperature).

        Stochastic on every call (uses ``torch.rand_like``), including at eval
        time — NOTE(review): confirm whether eval should use plain sigmoid.
        """
        noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-9) + 1e-9)
        y = logits + noise
        return torch.sigmoid(y / self.gumbel_temp)

    def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask, return_attentions=False):
        """Return ``(logits, rationale_probs, selector_logits, attns)``.

        ``attns`` is the HateBERT attention tuple when ``return_attentions``
        is True and the encoder exposes it; otherwise None.
        """
        # Main text through HateBERT.
        hatebert_out = self.hatebert_model(input_ids=input_ids,
                                           attention_mask=attention_mask,
                                           output_attentions=return_attentions,
                                           return_dict=True)
        hatebert_emb = hatebert_out.last_hidden_state  # (B, L, H)
        cls_emb = hatebert_emb[:, 0, :]  # (B, H)

        # Rationale text through the frozen encoder — no grads needed.
        with torch.no_grad():
            add_out = self.additional_model(input_ids=additional_input_ids,
                                            attention_mask=additional_attention_mask,
                                            return_dict=True)
        rationale_emb = add_out.last_hidden_state  # (B, L_rationale, H)

        # Selector logits & Gumbel-Sigmoid sampling on HateBERT tokens.
        selector_logits = self.selector(hatebert_emb).squeeze(-1)  # (B, L)
        rationale_probs = self.gumbel_sigmoid_sample(selector_logits)  # (B, L)
        # Zero out padding positions so they never contribute to pooling.
        rationale_probs = rationale_probs * attention_mask.float().to(rationale_probs.device)

        # Pooled rationale summary: probability-weighted mean of token embeddings.
        masked_hidden = hatebert_emb * rationale_probs.unsqueeze(-1)
        denom = rationale_probs.sum(1).unsqueeze(-1).clamp_min(1e-6)  # guard div-by-zero
        pooled_rationale = masked_hidden.sum(1) / denom  # (B, H)

        # CNN branches.
        temporal_features = self.temporal_cnn(hatebert_emb, attention_mask)  # (B, temporal_out_dim)
        rationale_features = self.msa_cnn(rationale_emb, additional_attention_mask)  # (B, msa_out_dim)

        # Concat CLS + CNN features + pooled rationale, then classify.
        concat_emb = torch.cat((cls_emb, temporal_features, rationale_features, pooled_rationale), dim=1)

        logits = self.projection_mlp(concat_emb)

        attns = hatebert_out.attentions if (return_attentions and hasattr(hatebert_out, "attentions")) else None
        return logits, rationale_probs, selector_logits, attns
|
| 317 |
|
|
|
|
|
|
|
| 318 |
|
| 319 |
+
class BaseShield(nn.Module):
    """
    Baseline classifier: concatenate the [CLS] embeddings from the HateBERT
    encoder and the rationale encoder, then project to label logits with a
    small MLP head.
    """

    def __init__(self, hatebert_model, additional_model, projection_mlp, device='cpu', freeze_additional_model=True):
        super().__init__()
        self.hatebert_model = hatebert_model
        self.additional_model = additional_model
        self.projection_mlp = projection_mlp
        self.device = device

        # Optionally freeze the rationale encoder so only HateBERT + head train.
        if freeze_additional_model:
            for weight in self.additional_model.parameters():
                weight.requires_grad = False

    def forward(self, input_ids, attention_mask, additional_input_ids, additional_attention_mask):
        """Return (B, num_labels) logits from main-text and rationale token ids."""
        main_cls = self.hatebert_model(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state[:, 0, :]

        rationale_cls = self.additional_model(
            input_ids=additional_input_ids, attention_mask=additional_attention_mask, return_dict=True
        ).last_hidden_state[:, 0, :]

        fused = torch.cat((main_cls, rationale_cls), dim=1)
        return self.projection_mlp(fused)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
|
| 346 |
def load_model_from_hf(model_type="altered"):
|
| 347 |
"""
|
|
|
|
| 353 |
|
| 354 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 355 |
repo_id = "seffyehl/BetterShield"
|
|
|
|
| 356 |
|
| 357 |
# Choose model and config files based on model_type
|
| 358 |
if model_type.lower() == "altered":
|
| 359 |
model_filename = "AlteredShield.pth"
|
| 360 |
config_filename = "alter_config.json"
|
| 361 |
elif model_type.lower() == "base":
|
| 362 |
+
model_filename = "BaselineShield.pth"
|
| 363 |
config_filename = "base_config.json"
|
| 364 |
else:
|
| 365 |
raise ValueError(f"model_type must be 'altered' or 'base', got '{model_type}'")
|
|
|
|
| 367 |
# Download files
|
| 368 |
model_path = hf_hub_download(
|
| 369 |
repo_id=repo_id,
|
|
|
|
| 370 |
filename=model_filename
|
| 371 |
)
|
| 372 |
|
| 373 |
config_path = hf_hub_download(
|
| 374 |
repo_id=repo_id,
|
| 375 |
filename=config_filename,
|
|
|
|
| 376 |
)
|
| 377 |
|
| 378 |
# Load config
|
| 379 |
with open(config_path, 'r') as f:
|
| 380 |
config = json.load(f)
|
| 381 |
|
| 382 |
+
# Load checkpoint with proper handling for numpy dtypes (PyTorch 2.6+ compatibility)
|
| 383 |
+
try:
|
| 384 |
+
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
| 385 |
+
except TypeError:
|
| 386 |
+
# Fallback for older PyTorch versions
|
| 387 |
+
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
| 388 |
|
| 389 |
# Handle nested config structure (base model uses model_config, altered uses flat structure)
|
| 390 |
if 'model_config' in config:
|
|
|
|
| 401 |
tokenizer_hatebert = AutoTokenizer.from_pretrained(model_config['hatebert_model'])
|
| 402 |
tokenizer_rationale = AutoTokenizer.from_pretrained(model_config['rationale_model'])
|
| 403 |
|
| 404 |
+
# Rebuild architecture based on model type using training_config values when available
|
| 405 |
H = hatebert_model.config.hidden_size
|
| 406 |
max_length = training_config.get('max_length', 128)
|
| 407 |
+
|
| 408 |
+
# common params from training config (use None to allow inference from checkpoint)
|
| 409 |
+
adapter_dim = training_config.get('adapter_dim', training_config.get('adapter_size', None))
|
| 410 |
+
cnn_num_filters = training_config.get('cnn_num_filters', None)
|
| 411 |
+
cnn_kernel_sizes = training_config.get('cnn_kernel_sizes', None)
|
| 412 |
+
cnn_dropout = training_config.get('cnn_dropout', 0.3)
|
| 413 |
+
freeze_rationale = training_config.get('freeze_additional_model', True)
|
| 414 |
+
num_labels = training_config.get('num_labels', 2)
|
| 415 |
+
|
| 416 |
+
# Infer architecture params from checkpoint state_dict when possible to match saved weights
|
| 417 |
+
state_dict = None
|
| 418 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 419 |
+
state_dict = checkpoint['model_state_dict']
|
| 420 |
+
elif isinstance(checkpoint, dict):
|
| 421 |
+
# sometimes checkpoint is a raw state_dict saved as dict
|
| 422 |
+
state_dict = checkpoint
|
| 423 |
+
|
| 424 |
+
if state_dict is not None:
|
| 425 |
+
# infer temporal convs count and filters if present
|
| 426 |
+
temporal_keys = [k for k in state_dict.keys() if k.startswith('temporal_cnn.convs.') and k.endswith('.weight')]
|
| 427 |
+
if temporal_keys:
|
| 428 |
+
try:
|
| 429 |
+
sample = state_dict[temporal_keys[0]]
|
| 430 |
+
inferred_num_filters = sample.shape[0]
|
| 431 |
+
inferred_kernel_count = len(temporal_keys)
|
| 432 |
+
if cnn_num_filters is None:
|
| 433 |
+
cnn_num_filters = int(inferred_num_filters)
|
| 434 |
+
if cnn_kernel_sizes is None:
|
| 435 |
+
cnn_kernel_sizes = training_config.get('cnn_kernel_sizes', (2,3,4,5,6,7))
|
| 436 |
+
except Exception:
|
| 437 |
+
pass
|
| 438 |
+
|
| 439 |
+
# infer projection dims/adapt size
|
| 440 |
+
proj_w_key = None
|
| 441 |
+
for key in ('projection_mlp.layers.0.weight', 'projection_mlp.0.weight', 'projection_mlp.layers.0.weight_orig'):
|
| 442 |
+
if key in state_dict:
|
| 443 |
+
proj_w_key = key
|
| 444 |
+
break
|
| 445 |
+
if proj_w_key is not None:
|
| 446 |
+
try:
|
| 447 |
+
proj_w = state_dict[proj_w_key]
|
| 448 |
+
inferred_adapter_dim = proj_w.shape[0]
|
| 449 |
+
if adapter_dim is None:
|
| 450 |
+
adapter_dim = int(inferred_adapter_dim)
|
| 451 |
+
except Exception:
|
| 452 |
+
pass
|
| 453 |
+
|
| 454 |
+
# sensible defaults when neither config nor checkpoint provided values
|
| 455 |
+
if cnn_num_filters is None:
|
| 456 |
+
cnn_num_filters = 64 # Changed from 128 to match typical training configs
|
| 457 |
+
if cnn_kernel_sizes is None:
|
| 458 |
+
cnn_kernel_sizes = (2, 3, 4, 5, 6, 7)
|
| 459 |
+
if adapter_dim is None:
|
| 460 |
+
adapter_dim = 128
|
| 461 |
+
|
| 462 |
if model_type.lower() == "base":
|
| 463 |
+
proj_input_dim = H * 2
|
| 464 |
+
projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim, num_labels=num_labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
model = BaseShield(
|
| 466 |
hatebert_model=hatebert_model,
|
| 467 |
additional_model=rationale_model,
|
| 468 |
projection_mlp=projection_mlp,
|
| 469 |
+
freeze_additional_model=freeze_rationale,
|
| 470 |
device=device
|
| 471 |
).to(device)
|
| 472 |
else:
|
| 473 |
+
# For altered model, let ConcatModelWithRationale initialize its own CNN modules
|
| 474 |
+
# The CNN modules are created inside __init__, so we just need to create the model
|
| 475 |
+
# and then load the state dict
|
| 476 |
+
|
| 477 |
+
# First, create a dummy projection_mlp - we'll replace it after calculating dims
|
| 478 |
+
# Actually, we need to calculate dims first to create the correct projection_mlp
|
| 479 |
+
|
| 480 |
+
# Calculate dimensions based on inferred parameters
|
| 481 |
+
temporal_out_dim = cnn_num_filters * len(cnn_kernel_sizes) * 2
|
| 482 |
+
msa_out_dim = cnn_num_filters * len(cnn_kernel_sizes)
|
| 483 |
+
proj_input_dim = H + temporal_out_dim + msa_out_dim + H
|
| 484 |
+
|
| 485 |
+
projection_mlp = ProjectionMLP(input_size=proj_input_dim, hidden_size=adapter_dim, num_labels=num_labels)
|
| 486 |
+
|
| 487 |
+
model = ConcatModelWithRationale(
|
| 488 |
+
hatebert_model=hatebert_model,
|
| 489 |
+
additional_model=rationale_model,
|
| 490 |
+
projection_mlp=projection_mlp,
|
| 491 |
+
hidden_size=H,
|
| 492 |
+
freeze_additional_model=freeze_rationale,
|
| 493 |
+
cnn_num_filters=cnn_num_filters,
|
| 494 |
+
cnn_kernel_sizes=cnn_kernel_sizes,
|
| 495 |
+
cnn_dropout=cnn_dropout
|
| 496 |
+
).to(device)
|
| 497 |
|
| 498 |
+
# Load state dict with strict checking and error reporting
|
| 499 |
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 500 |
+
state_dict_to_load = checkpoint['model_state_dict']
|
| 501 |
+
else:
|
| 502 |
+
state_dict_to_load = checkpoint
|
| 503 |
+
|
| 504 |
+
# Check for missing and unexpected keys
|
| 505 |
+
model_keys = set(model.state_dict().keys())
|
| 506 |
+
checkpoint_keys = set(state_dict_to_load.keys())
|
| 507 |
+
|
| 508 |
+
missing_keys = model_keys - checkpoint_keys
|
| 509 |
+
unexpected_keys = checkpoint_keys - model_keys
|
| 510 |
+
|
| 511 |
+
if missing_keys:
|
| 512 |
+
print(f"WARNING: Missing keys in checkpoint: {missing_keys}")
|
| 513 |
+
if unexpected_keys:
|
| 514 |
+
print(f"WARNING: Unexpected keys in checkpoint: {unexpected_keys}")
|
| 515 |
+
|
| 516 |
+
# Load with strict=False to handle any minor mismatches, but log warnings
|
| 517 |
+
incompatible_keys = model.load_state_dict(state_dict_to_load, strict=True)
|
| 518 |
+
|
| 519 |
+
if incompatible_keys.missing_keys:
|
| 520 |
+
print(f"Missing keys after load: {incompatible_keys.missing_keys}")
|
| 521 |
+
if incompatible_keys.unexpected_keys:
|
| 522 |
+
print(f"Unexpected keys after load: {incompatible_keys.unexpected_keys}")
|
| 523 |
+
|
| 524 |
+
if isinstance(checkpoint, dict) and 'epoch' in checkpoint:
|
| 525 |
print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
|
| 526 |
print(f"Dataset: {checkpoint.get('dataset', 'unknown')}, Seed: {checkpoint.get('seed', 'unknown')}")
|
| 527 |
+
|
| 528 |
+
# CRITICAL: Set to eval mode and ensure no gradient computation
|
| 529 |
model.eval()
|
| 530 |
+
|
| 531 |
+
# Disable dropout explicitly by setting training mode to False for all modules
|
| 532 |
+
for module in model.modules():
|
| 533 |
+
if isinstance(module, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)):
|
| 534 |
+
module.p = 0 # Set dropout probability to 0
|
| 535 |
+
|
| 536 |
model = model.to(device)
|
| 537 |
|
| 538 |
+
# Verify model is in eval mode
|
| 539 |
+
print(f"Model training mode: {model.training}")
|
| 540 |
+
print(f"Dropout layers found: {sum(1 for _ in model.modules() if isinstance(_, (nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d)))}")
|
| 541 |
+
|
| 542 |
# Create a unified config dict with max_length at top level for compatibility
|
| 543 |
unified_config = config.copy()
|
| 544 |
if 'max_length' not in unified_config and 'training_config' in config:
|
|
|
|
| 546 |
|
| 547 |
return model, tokenizer_hatebert, tokenizer_rationale, unified_config, device
|
| 548 |
|
| 549 |
+
|
| 550 |
+
def combined_loss(logits, labels, rationale_probs, selector_logits, rationale_mask=None, attns=None, attn_weight=0.0, rationale_weight=1.0):
    """Classification loss plus optional selector and attention-alignment terms.

    Args:
        logits: (B, num_labels) classification logits.
        labels: (B,) integer class labels.
        rationale_probs: sampled rationale probabilities (unused here; kept for API parity).
        selector_logits: per-token selector logits, supervised with BCE.
        rationale_mask: optional gold rationale mask for selector supervision.
        attns: optional tuple of attention tensors; only the last layer is used.
        attn_weight: weight of the attention-alignment MSE term.
        rationale_weight: weight of the selector BCE term.

    Returns:
        (total_loss tensor, cls_loss float, selector_loss float, attn_loss float)
    """
    classification = F.cross_entropy(logits, labels)

    def _zero():
        return torch.tensor(0.0, device=classification.device)

    # Selector head supervised against the gold rationale mask when provided.
    if rationale_mask is None:
        selection = _zero()
    else:
        selection = F.binary_cross_entropy_with_logits(selector_logits, rationale_mask.to(selector_logits.device))

    # Optional alignment of mean attention mass with the rationale mask
    # (disabled by default via attn_weight=0.0).
    alignment = _zero()
    if attns is not None and attn_weight > 0.0:
        try:
            # (B, H, L, L) -> average over heads, then query positions -> (B, L)
            mass = attns[-1].mean(1).mean(1)
            alignment = F.mse_loss(mass, rationale_mask.to(mass.device))
        except Exception:
            # Best-effort term: shape mismatches or missing mask fall back to zero.
            alignment = _zero()

    total = classification + rationale_weight * selection + attn_weight * alignment
    return total, classification.item(), selection.item(), alignment.item()
|
| 571 |
+
|
| 572 |
+
|
| 573 |
def predict_text(text, rationale, model, tokenizer_hatebert, tokenizer_rationale,
|
| 574 |
device='cpu', max_length=128, model_type="altered"):
|
| 575 |
+
# Ensure model is in eval mode (defensive programming)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 576 |
model.eval()
|
| 577 |
|
| 578 |
# Tokenize inputs
|
|
|
|
| 598 |
add_input_ids = inputs_rationale['input_ids'].to(device)
|
| 599 |
add_attention_mask = inputs_rationale['attention_mask'].to(device)
|
| 600 |
|
| 601 |
+
# Inference with no gradient computation
|
| 602 |
+
with torch.no_grad():
|
| 603 |
+
if model_type.lower() == "base":
|
| 604 |
logits = model(
|
| 605 |
input_ids,
|
| 606 |
attention_mask,
|
| 607 |
add_input_ids,
|
| 608 |
add_attention_mask
|
| 609 |
)
|
| 610 |
+
else:
|
| 611 |
+
logits, rationale_probs, selector_logits, _ = model(
|
| 612 |
+
input_ids,
|
| 613 |
+
attention_mask,
|
| 614 |
+
add_input_ids,
|
| 615 |
+
add_attention_mask
|
| 616 |
+
)
|
| 617 |
|
| 618 |
+
temperature = 1 # Adjust this if needed (e.g., 2.0 for less confidence)
|
| 619 |
+
scaled_logits = logits / temperature
|
| 620 |
+
|
| 621 |
+
# Get probabilities with numerical stability
|
| 622 |
+
probs = F.softmax(scaled_logits, dim=1)
|
| 623 |
+
|
| 624 |
+
if torch.isnan(probs).any() or torch.isinf(probs).any():
|
| 625 |
+
print(f"WARNING: NaN or Inf in probabilities. Logits: {logits}")
|
| 626 |
+
# Fallback to uniform distribution
|
| 627 |
+
probs = torch.ones_like(logits) / logits.size(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 628 |
|
|
|
|
|
|
|
| 629 |
prediction = logits.argmax(dim=1).item()
|
| 630 |
confidence = probs[0, prediction].item()
|
| 631 |
+
|
| 632 |
+
# Debug: Print logits and probs for first few predictions
|
| 633 |
+
print(f"Debug - Logits: {logits[0].cpu().numpy()}, Probs: {probs[0].cpu().numpy()}")
|
| 634 |
|
| 635 |
+
result = {
|
| 636 |
'prediction': prediction,
|
| 637 |
'confidence': confidence,
|
| 638 |
'probabilities': probs[0].cpu().numpy(),
|
|
|
|
| 639 |
'tokens': tokenizer_hatebert.convert_ids_to_tokens(input_ids[0])
|
| 640 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 641 |
|
| 642 |
+
if model_type.lower() != "base":
|
| 643 |
+
result['rationale_scores'] = rationale_probs[0].cpu().numpy() if 'rationale_probs' in locals() else None
|
| 644 |
+
else:
|
| 645 |
+
result['rationale_scores'] = None
|
| 646 |
+
|
| 647 |
+
return result
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def predict_hatespeech_from_file(
|
| 651 |
+
text_list,
|
| 652 |
+
rationale_list,
|
| 653 |
+
true_label,
|
| 654 |
+
model,
|
| 655 |
+
tokenizer_hatebert,
|
| 656 |
+
tokenizer_rationale,
|
| 657 |
+
config,
|
| 658 |
+
device,
|
| 659 |
+
model_type="altered"
|
| 660 |
+
):
|
| 661 |
+
|
| 662 |
+
print(f"\nStarting inference for model: {type(model).__name__}")
|
| 663 |
+
|
| 664 |
predictions = []
|
| 665 |
+
all_probs = []
|
| 666 |
cpu_percent_list = []
|
| 667 |
memory_percent_list = []
|
| 668 |
|
| 669 |
process = psutil.Process(os.getpid())
|
| 670 |
+
|
| 671 |
+
# 🔥 GPU synchronization BEFORE timing
|
| 672 |
+
if torch.cuda.is_available():
|
| 673 |
+
torch.cuda.synchronize()
|
| 674 |
+
|
| 675 |
+
# 🔥 Optional warmup (prevents first-batch timing bias)
|
| 676 |
+
with torch.no_grad():
|
| 677 |
+
_ = predict_text(
|
| 678 |
+
text=text_list[0],
|
| 679 |
+
rationale=rationale_list[0],
|
| 680 |
+
model=model,
|
| 681 |
+
tokenizer_hatebert=tokenizer_hatebert,
|
| 682 |
+
tokenizer_rationale=tokenizer_rationale,
|
| 683 |
+
device=device,
|
| 684 |
+
max_length=config.get('max_length', 128),
|
| 685 |
+
model_type=model_type
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
if torch.cuda.is_available():
|
| 689 |
+
torch.cuda.synchronize()
|
| 690 |
+
|
| 691 |
+
# ⏱ Start timer AFTER warmup
|
| 692 |
start_time = time()
|
| 693 |
+
|
| 694 |
for idx, (text, rationale) in enumerate(zip(text_list, rationale_list)):
|
| 695 |
result = predict_text(
|
| 696 |
text=text,
|
|
|
|
| 702 |
max_length=config.get('max_length', 128),
|
| 703 |
model_type=model_type
|
| 704 |
)
|
| 705 |
+
|
| 706 |
predictions.append(result['prediction'])
|
| 707 |
+
all_probs.append(result['probabilities'])
|
| 708 |
+
|
| 709 |
+
# Reduce monitoring overhead
|
| 710 |
if idx % 10 == 0 or idx == len(text_list) - 1:
|
| 711 |
cpu_percent_list.append(process.cpu_percent())
|
| 712 |
memory_percent_list.append(process.memory_info().rss / 1024 / 1024)
|
| 713 |
|
| 714 |
+
# 🔥 GPU synchronization BEFORE stopping timer
|
| 715 |
+
if torch.cuda.is_available():
|
| 716 |
+
torch.cuda.synchronize()
|
| 717 |
+
|
| 718 |
end_time = time()
|
| 719 |
runtime = end_time - start_time
|
| 720 |
+
|
| 721 |
+
print(f"Inference completed for {type(model).__name__}")
|
| 722 |
+
print(f"Total runtime: {runtime:.4f} seconds")
|
| 723 |
+
|
| 724 |
+
# ---------------- Metrics ----------------
|
| 725 |
+
all_probs = np.array(all_probs)
|
| 726 |
+
|
| 727 |
+
print(f"Probability Mean: {all_probs.mean(axis=0)}")
|
| 728 |
+
print(f"Probability Std: {all_probs.std(axis=0)}")
|
| 729 |
+
print(f"Prediction distribution: {np.bincount(predictions, minlength=2)}")
|
| 730 |
+
|
| 731 |
f1 = f1_score(true_label, predictions, zero_division=0)
|
| 732 |
accuracy = accuracy_score(true_label, predictions)
|
| 733 |
precision = precision_score(true_label, predictions, zero_division=0)
|
| 734 |
recall = recall_score(true_label, predictions, zero_division=0)
|
| 735 |
cm = confusion_matrix(true_label, predictions).tolist()
|
| 736 |
+
|
| 737 |
avg_cpu = sum(cpu_percent_list) / len(cpu_percent_list) if cpu_percent_list else 0
|
| 738 |
+
avg_memory = sum(memory_percent_list) / len(memory_percent_list) if memory_percent_list else 0
|
| 739 |
peak_memory = max(memory_percent_list) if memory_percent_list else 0
|
| 740 |
peak_cpu = max(cpu_percent_list) if cpu_percent_list else 0
|
| 741 |
|
| 742 |
return {
|
| 743 |
+
'model_name': type(model).__name__, # 👈 makes logs clearer
|
| 744 |
'f1_score': f1,
|
| 745 |
'accuracy': accuracy,
|
| 746 |
'precision': precision,
|
|
|
|
| 750 |
'memory_usage': avg_memory,
|
| 751 |
'peak_cpu_usage': peak_cpu,
|
| 752 |
'peak_memory_usage': peak_memory,
|
| 753 |
+
'runtime': runtime,
|
| 754 |
+
'all_probabilities': all_probs.tolist()
|
| 755 |
}
|
| 756 |
|
| 757 |
|
| 758 |
def predict_hatespeech(text, rationale, model, tokenizer_hatebert, tokenizer_rationale, config, device, model_type="altered"):
|
| 759 |
"""
|
| 760 |
Predict hate speech for given text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 761 |
"""
|
| 762 |
# Get prediction
|
| 763 |
result = predict_text(
|
|
|
|
| 772 |
)
|
| 773 |
|
| 774 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|