Add README, tokenizer, results
Browse files- main.py +354 -0
- predict.py +185 -0
main.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# -*- coding: utf-8 -*-
"""glove+bilstm.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/10fLw7V6G3vV_STF7KcWe8qcTvyLQq0NT
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import permutations

# For train-test split and evaluation
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import roc_auc_score, confusion_matrix
from sklearn.metrics import roc_curve, precision_recall_curve

# Deep learning libraries
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, Bidirectional, LSTM
from tensorflow.keras.layers import Dense, Dropout

base_path = "/root/output"

# Output folders: dataset splits, figures, metric tables, saved models.
os.makedirs(base_path + "/dataset_splits", exist_ok=True)
os.makedirs(base_path + "/figures", exist_ok=True)
os.makedirs(base_path + "/results_tables", exist_ok=True)
os.makedirs(base_path + "/trained_models", exist_ok=True)

data_path = "/root/dataset.csv"

df = pd.read_csv(data_path)

df.head()

# Language distribution of the dataset as a pie chart.
plt.figure(figsize=(6, 4))
df['language'].value_counts().plot.pie(autopct='%1.1f%%')
plt.title("Dataset Language Distribution")
plt.ylabel("")
plt.savefig(base_path + "/figures/language_distribution.png", dpi=300)
plt.show()

X = df["clean_text"]
y = df["hate_label"]
lang = df["language"]

# 30% held out for test, then 0.1428 of the remaining 70% (~10% of the
# whole) for validation -> roughly 60/10/30 train/val/test, label-stratified.
X_temp, X_test, y_temp, y_test, lang_temp, lang_test = train_test_split(
    X, y, lang, test_size=0.30, stratify=y, random_state=42)

X_train, X_val, y_train, y_val, lang_train, lang_val = train_test_split(
    X_temp, y_temp, lang_temp,
    test_size=0.1428,
    stratify=y_temp,
    random_state=42
)

pd.DataFrame({"text": X_train, "label": y_train, "lang": lang_train}).to_csv(
    base_path + "/dataset_splits/train.csv", index=False)

pd.DataFrame({"text": X_val, "label": y_val, "lang": lang_val}).to_csv(
    base_path + "/dataset_splits/val.csv", index=False)

pd.DataFrame({"text": X_test, "label": y_test, "lang": lang_test}).to_csv(
    base_path + "/dataset_splits/test.csv", index=False)

MAX_LEN = 100
VOCAB = 50000

# Fit the vocabulary on the training split only (no test leakage).
tokenizer = Tokenizer(num_words=VOCAB)
tokenizer.fit_on_texts(X_train)

# FIX: persist the fitted tokenizer. predict.py loads "tokenizer.json" from
# the directory next to the saved model; without this save the inference
# script cannot reproduce the training-time vocabulary.
with open(base_path + "/trained_models/tokenizer.json", "w", encoding="utf8") as tok_f:
    tok_f.write(tokenizer.to_json())

X_train_seq = pad_sequences(tokenizer.texts_to_sequences(X_train), maxlen=MAX_LEN)
X_val_seq = pad_sequences(tokenizer.texts_to_sequences(X_val), maxlen=MAX_LEN)
X_test_seq = pad_sequences(tokenizer.texts_to_sequences(X_test), maxlen=MAX_LEN)

EMBEDDING_DIM = 300
glove_path = "/root/glove.6B.300d.txt"

embeddings_index = {}

# Parse the GloVe text format: "<word> v1 v2 ... v300" per line.
with open(glove_path, encoding="utf8") as f:
    for line in f:
        values = line.split()
        word = values[0]
        vector = np.asarray(values[1:], dtype="float32")
        embeddings_index[word] = vector

print("Loaded %s word vectors." % len(embeddings_index))

word_index = tokenizer.word_index
embedding_dim = 300

# Row i holds the GloVe vector for tokenizer index i; index 0 (padding)
# and out-of-vocabulary words stay all-zero.
embedding_matrix = np.zeros((len(word_index) + 1, embedding_dim))

for word, i in word_index.items():
    vector = embeddings_index.get(word)
    if vector is not None:
        embedding_matrix[i] = vector
+
# ============================================================
# Helper: build a fresh model (called once per permutation)
# ============================================================
def build_model():
    """Create and compile a new BiLSTM classifier on frozen GloVe embeddings.

    Reads the module-level ``word_index``, ``embedding_matrix``,
    ``embedding_dim`` and ``MAX_LEN``; returns a compiled ``Sequential``
    model with a single sigmoid output for binary classification.
    """
    network = Sequential([
        Embedding(
            input_dim=len(word_index) + 1,
            output_dim=embedding_dim,
            weights=[embedding_matrix],
            input_length=MAX_LEN,
            trainable=False,  # keep the pretrained GloVe vectors fixed
        ),
        Bidirectional(LSTM(128)),
        Dropout(0.5),
        Dense(64, activation="relu"),
        Dense(1, activation="sigmoid"),
    ])
    network.compile(optimizer="adam", loss="binary_crossentropy",
                    metrics=["accuracy"])
    return network
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def evaluate_metrics(y_true, y_pred_prob):
    """Compute binary-classification metrics from predicted probabilities.

    Args:
        y_true: array-like of 0/1 ground-truth labels.
        y_pred_prob: numpy array of predicted positive-class probabilities
            (thresholded at 0.5 for the label-based metrics).

    Returns:
        Tuple ``(accuracy, balanced_accuracy, precision, recall,
        specificity, f1, roc_auc)``.
    """
    y_pred = (y_pred_prob > 0.5).astype(int)
    acc = accuracy_score(y_true, y_pred)
    bal = balanced_accuracy_score(y_true, y_pred)
    # zero_division=0 keeps these well-defined (and warning-free) when the
    # model never predicts one of the classes on a small per-language split.
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    auc = roc_auc_score(y_true, y_pred_prob)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    # Guard the manual specificity against tn + fp == 0 (no actual negatives
    # in the split) instead of raising ZeroDivisionError.
    spec = tn / (tn + fp) if (tn + fp) else 0.0
    return acc, bal, prec, rec, spec, f1, auc
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def plot_training_curves(history, tag, base_path):
    """Save accuracy and loss curves for one training phase."""
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: train vs. validation accuracy per epoch.
    acc_ax.plot(history.history['accuracy'], label="Train Accuracy")
    acc_ax.plot(history.history['val_accuracy'], label="Val Accuracy")
    acc_ax.set_title(f"{tag} - Accuracy Curve")
    acc_ax.set_xlabel("Epoch")
    acc_ax.set_ylabel("Accuracy")
    acc_ax.legend()
    acc_ax.grid(True)

    # Right panel: train vs. validation loss per epoch.
    loss_ax.plot(history.history['loss'], label="Train Loss")
    loss_ax.plot(history.history['val_loss'], label="Val Loss")
    loss_ax.set_title(f"{tag} - Loss Curve")
    loss_ax.set_xlabel("Epoch")
    loss_ax.set_ylabel("Loss")
    loss_ax.legend()
    loss_ax.grid(True)

    plt.tight_layout()
    # Filesystem-safe name derived from the phase tag.
    fname = tag.replace(" -> ", "_to_").replace(" ", "_")
    plt.savefig(os.path.join(base_path, f"{fname}_curves.png"), dpi=300)
    plt.show()
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def plot_eval_charts(y_test, preds, tag, base_path):
    """Save confusion matrix, ROC, PR, and F1 curves after evaluation.

    Args:
        y_test: ground-truth 0/1 labels for the evaluated split.
        preds: predicted positive-class probabilities, same length as y_test.
        tag: human-readable phase label used in titles and file names.
        base_path: directory the four PNG figures are written into.
    """
    # Filesystem-safe figure-name prefix derived from the phase tag.
    fname = tag.replace(" -> ", "_to_").replace(" ", "_")

    # Confusion Matrix (probabilities hard-thresholded at 0.5)
    cm = confusion_matrix(y_test, (preds > 0.5).astype(int))
    plt.figure(figsize=(6,4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Non-Hate","Hate"],
                yticklabels=["Non-Hate","Hate"])
    plt.title(f"{tag} - Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    plt.savefig(os.path.join(base_path, f"{fname}_cm.png"), dpi=300)
    plt.show()

    # ROC Curve
    fpr, tpr, _ = roc_curve(y_test, preds)
    auc_val = roc_auc_score(y_test, preds)
    plt.figure(figsize=(6,4))
    plt.plot(fpr, tpr, label=f"AUC={auc_val:.4f}")
    plt.plot([0,1],[0,1],'--')  # chance diagonal for reference
    plt.title(f"{tag} - ROC Curve")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(base_path, f"{fname}_roc.png"), dpi=300)
    plt.show()

    # Precision-Recall Curve
    precision, recall, thresholds = precision_recall_curve(y_test, preds)
    plt.figure(figsize=(6,4))
    plt.plot(recall, precision)
    plt.title(f"{tag} - Precision-Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.grid(True)
    plt.savefig(os.path.join(base_path, f"{fname}_pr.png"), dpi=300)
    plt.show()

    # F1 Curve. The 1e-8 epsilon avoids 0/0 where precision and recall are
    # both zero; precision/recall arrays have one more entry than thresholds
    # (sklearn appends the (1, 0) endpoint), hence the [:-1] slice.
    f1_scores = (2 * precision * recall) / (precision + recall + 1e-8)
    plt.figure(figsize=(6,4))
    plt.plot(thresholds, f1_scores[:-1])
    plt.title(f"{tag} - F1 Score vs Threshold")
    plt.xlabel("Threshold")
    plt.ylabel("F1 Score")
    plt.grid(True)
    plt.savefig(os.path.join(base_path, f"{fname}_f1.png"), dpi=300)
    plt.show()
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# ============================================================
# PLAN B: All 6 permutations + final Full (Shuffled) fine-tune
# After each training phase -> evaluate on that language's test set
# After Full phase -> evaluate on full test set
# ============================================================
print("\n" + "="*60)
print("PLAN B: Sequential Transfer Learning + Full Dataset Fine-tune")
print("="*60)

languages = ["english", "hindi", "hinglish"]

# Pre-shuffle full training data once (same shuffle for all permutations)
# so every strategy's final fine-tune phase sees identical batch order.
np.random.seed(42)
shuffle_idx = np.random.permutation(len(X_train_seq))
X_full_shuffled = np.ascontiguousarray(X_train_seq[shuffle_idx], dtype=np.int32)
y_full_shuffled = np.ascontiguousarray(y_train.values[shuffle_idx], dtype=np.float32)

# Pre-build per-language test splits. Boolean masks are positional over
# X_test_seq / y_test.values, which share row order with lang_test.
lang_test_idx = {
    lang: (lang_test.values == lang)
    for lang in languages
}
lang_test_X = {
    lang: X_test_seq[lang_test_idx[lang]]
    for lang in languages
}
lang_test_y = {
    lang: y_test.values[lang_test_idx[lang]]
    for lang in languages
}

cols = ["Strategy", "Phase", "Accuracy", "Balanced Acc",
        "Precision", "Recall", "Specificity", "F1", "ROC-AUC"]

for perm in permutations(languages):
    perm_name = " -> ".join(perm)
    strategy_name = perm_name + " -> Full"
    strategy_results = []  # one row per phase for this strategy's table

    print(f"\n{'='*50}")
    print(f"Strategy: {strategy_name}")
    print(f"{'='*50}")

    # Make a clean folder per strategy for figures
    strat_tag = perm_name.replace(" -> ", "_to_")
    strat_fig_path = base_path + f"/figures/{strat_tag}"
    os.makedirs(strat_fig_path, exist_ok=True)

    # Model built ONCE -> weights carry forward across all phases
    model = build_model()

    # -- Language phases ---------------------------------------------------
    # NOTE(review): this loop variable rebinds the module-level name `lang`
    # (df["language"] defined near the top of the file); harmless as long as
    # nothing below this loop reads the original Series -- confirm.
    for lang in perm:
        # Boolean mask over the training rows for this language; applied
        # positionally to the padded sequence array.
        idx = (lang_train == lang)
        X_lang = X_train_seq[idx]
        y_lang = y_train[idx]

        print(f" Training on: {lang} ({X_lang.shape[0]} samples)")

        history = model.fit(
            X_lang, y_lang,
            validation_data=(X_val_seq, y_val),
            epochs=8,
            batch_size=32,
            verbose=1
        )

        # Train/Val accuracy + loss curves
        plot_training_curves(history, f"{strat_tag} [{lang}]", strat_fig_path)

        # Evaluate on this language's test subset
        preds = model.predict(lang_test_X[lang]).flatten()
        acc, bal, prec, rec, spec, f1, auc = evaluate_metrics(lang_test_y[lang], preds)
        strategy_results.append([strategy_name, lang, acc, bal, prec, rec, spec, f1, auc])

        # Eval plots for this language
        plot_eval_charts(lang_test_y[lang], preds, f"{strat_tag} [{lang}]", strat_fig_path)

        print(f" Acc={acc:.4f} F1={f1:.4f} AUC={auc:.4f}")

    # -- Full phase: fine-tune on the whole (pre-shuffled) training set ----
    print(f" Training on: Full Dataset ({X_full_shuffled.shape[0]} samples, shuffled)")

    history_full = model.fit(
        X_full_shuffled, y_full_shuffled,
        validation_data=(X_val_seq, y_val),
        epochs=8,
        batch_size=64,
        verbose=1
    )

    # Train/Val accuracy + loss curves for full phase
    plot_training_curves(history_full, f"{strat_tag} [Full]", strat_fig_path)

    # Evaluate on full test set
    preds_full = model.predict(X_test_seq).flatten()
    acc, bal, prec, rec, spec, f1, auc = evaluate_metrics(y_test.values, preds_full)
    strategy_results.append([strategy_name, "Full", acc, bal, prec, rec, spec, f1, auc])

    # Eval plots for full phase
    plot_eval_charts(y_test.values, preds_full, f"{strat_tag} [Full]", strat_fig_path)

    print(f" Acc={acc:.4f} F1={f1:.4f} AUC={auc:.4f}")

    # Save per-strategy results table (4 rows: 3 langs + Full)
    strat_df = pd.DataFrame(strategy_results, columns=cols)
    strat_df.to_csv(
        base_path + f"/results_tables/{strat_tag}_results.csv",
        index=False
    )

    print(f"\n Results for strategy: {strategy_name}")
    print(strat_df.to_string(index=False))

    model.save(base_path + f"/trained_models/planB_{strat_tag}_Full.h5")
    print(f" Saved model: planB_{strat_tag}_Full.h5")


# ============================================================
# COMBINED RESULTS TABLE (all 6 strategies x 4 phases = 24 rows)
# ============================================================
# Re-derive each per-strategy CSV path written inside the loop above.
all_csv = [
    base_path + f"/results_tables/{('_to_'.join(perm))}_results.csv"
    for perm in permutations(languages)
]

combined_df = pd.concat([pd.read_csv(f) for f in all_csv], ignore_index=True)
combined_df.to_csv(base_path + "/results_tables/all_strategies_results.csv", index=False)

print("\n" + "="*60)
print("ALL STRATEGIES β COMBINED RESULTS")
print("="*60)
print(combined_df.to_string(index=False))
predict.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
predict.py -- Interactive inference script for the SASC hate speech detection model.

Usage:
    python predict.py                           # fully interactive
    python predict.py --model model.h5          # specify model path
    python predict.py --input texts.csv         # specify input CSV
    python predict.py --text "some text here"   # single text prediction
"""

import os
import sys
import argparse
import json  # NOTE(review): not used anywhere in this script -- confirm before removing


# -- Argument parsing: every flag is optional; anything missing is asked for
# -- interactively further down.
parser = argparse.ArgumentParser(description="SASC Hate Speech Detector")
parser.add_argument("--model", type=str, help="Path to .h5 model file")
parser.add_argument("--tokenizer", type=str, help="Path to tokenizer.json")
parser.add_argument("--input", type=str, help="Path to input CSV file")
parser.add_argument("--text", type=str, help="Single text to classify")
parser.add_argument("--output", type=str, help="Path to save results CSV")
parser.add_argument("--threshold", type=float, default=0.5, help="Decision threshold (default: 0.5)")
parser.add_argument("--col", type=str, default="text", help="Column name in CSV containing text (default: text)")
args = parser.parse_args()
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ββ Interactive prompts if args not provided βββββββββββββββββββββββββββββββββ
|
| 31 |
+
def ask(prompt, default=None):
    """Prompt on stdin; return the stripped reply, or *default* when empty."""
    hint = f" [{default}]" if default else ""
    reply = input(f"{prompt}{hint}: ").strip()
    if reply:
        return reply
    return default
| 35 |
+
|
| 36 |
+
|
| 37 |
+
print("\n=== SASC Hate Speech Detector ===\n")

# Model path: CLI flag wins; otherwise prompt (default "model.h5").
model_path = args.model
if not model_path:
    model_path = ask("Model path (.h5)", "model.h5")

if not os.path.exists(model_path):
    print(f"Model not found: {model_path}")
    sys.exit(1)

# Tokenizer path: CLI flag wins; otherwise prefer a tokenizer.json that
# sits next to the model file before falling back to the cwd default.
tokenizer_path = args.tokenizer
if not tokenizer_path:
    # look next to model file first
    candidate = os.path.join(os.path.dirname(model_path), "tokenizer.json")
    tokenizer_path = ask("Tokenizer path", candidate if os.path.exists(candidate) else "tokenizer.json")

if not os.path.exists(tokenizer_path):
    print(f"Tokenizer not found: {tokenizer_path}")
    sys.exit(1)

# Threshold
threshold = args.threshold
# NOTE(review): args.threshold defaults to 0.5 (truthy), so this falsy check
# only fires when the user passes --threshold 0 explicitly -- the interactive
# prompt is effectively dead code, and an explicit 0.0 gets re-prompted.
# Fixing this properly requires default=None on the argparse flag; confirm
# intended behavior before changing.
if not args.threshold and not args.text and not args.input:
    t = ask("Decision threshold (0.0-1.0)", "0.5")
    try:
        threshold = float(t)
    except ValueError:
        threshold = 0.5

print(f"\nLoading model from {model_path}...")
# TensorFlow import is deferred until after the path checks so that bad
# paths fail fast without paying the heavy import cost.
import tensorflow as tf
model = tf.keras.models.load_model(model_path)

print(f"Loading tokenizer from {tokenizer_path}...")
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
with open(tokenizer_path) as f:
    tokenizer = tokenizer_from_json(f.read())

# Must match the MAX_LEN the model was trained with (see main.py).
MAX_LEN = 100
| 79 |
+
|
| 80 |
+
def predict(texts):
    """Tokenize, pad and score *texts*.

    Returns a pair ``(probs, labels)``: the flattened model probabilities
    and the corresponding "Hate Speech" / "Non-Hate" strings, using the
    module-level ``tokenizer``, ``model``, ``MAX_LEN`` and ``threshold``.
    """
    padded = pad_sequences(tokenizer.texts_to_sequences(texts), maxlen=MAX_LEN)
    probs = model.predict(padded, verbose=0).flatten()
    labels = []
    for p in probs:
        labels.append("Hate Speech" if p > threshold else "Non-Hate")
    return probs, labels
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ββ Single text mode ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 89 |
+
# -- Single text mode: classify one string and exit ---------------------------
if args.text:
    probs, labels = predict([args.text])
    print(f"\nText : {args.text}")
    print(f"Label : {labels[0]}")
    print(f"Confidence: {probs[0]:.4f}")
    sys.exit(0)


# -- CSV mode (or interactive manual entry) -----------------------------------
import pandas as pd

input_path = args.input
if not input_path:
    mode = ask("Input mode β (1) CSV file (2) Type text manually", "1")

    if mode == "2":
        # manual text entry loop; 'done' (any case) terminates input
        print("\nEnter texts one per line. Type 'done' when finished.\n")
        texts = []
        while True:
            t = input(" Text: ").strip()
            if t.lower() == "done":
                break
            if t:
                texts.append(t)

        if not texts:
            print("No texts entered.")
            sys.exit(0)

        probs, labels = predict(texts)
        import pandas as pd  # NOTE(review): redundant -- pandas is already imported above
        results = pd.DataFrame({
            "text": texts,
            "label": labels,
            "confidence": [round(float(p), 4) for p in probs]
        })

        print("\n" + "="*60)
        print(results.to_string(index=False))
        print("="*60)

        # Optional save; an empty reply skips it.
        out = args.output or ask("Save results to CSV? (leave blank to skip)", "")
        if out:
            results.to_csv(out, index=False)
            print(f"Saved to {out}")
        sys.exit(0)

    else:
        input_path = ask("CSV file path")

if not os.path.exists(input_path):
    print(f"File not found: {input_path}")
    sys.exit(1)

df = pd.read_csv(input_path)
print(f"\nLoaded {len(df)} rows from {input_path}")
print(f"Columns: {list(df.columns)}")

# Resolve the text column: use --col if present in the CSV, otherwise ask
# (offering the first column as the default).
text_col = args.col
if text_col not in df.columns:
    print(f"\nColumn '{text_col}' not found.")
    text_col = ask(f"Which column contains the text?", df.columns[0])

print(f"\nRunning inference on column '{text_col}' with threshold={threshold}...")

# NaN cells become empty strings so every row gets a prediction.
texts = df[text_col].fillna("").astype(str).tolist()
probs, labels = predict(texts)

df["predicted_label"] = labels
df["confidence"] = [round(float(p), 4) for p in probs]

# Summary
hate_count = labels.count("Hate Speech")
nonhate_count = labels.count("Non-Hate")
print(f"\n{'='*60}")
print(f"Results Summary")
print(f"{'='*60}")
print(f" Total samples : {len(texts)}")
print(f" Hate Speech : {hate_count} ({hate_count/len(texts)*100:.1f}%)")
print(f" Non-Hate : {nonhate_count} ({nonhate_count/len(texts)*100:.1f}%)")
print(f" Threshold : {threshold}")
print(f"{'='*60}")

# Show sample
print(f"\nSample predictions (first 10):")
print(df[[text_col, "predicted_label", "confidence"]].head(10).to_string(index=False))

# Save: default output name is "<input>_predictions.csv".
# NOTE(review): str.replace silently leaves the name unchanged if the input
# path does not end in ".csv" -- confirm whether that matters to callers.
output_path = args.output
if not output_path:
    default_out = input_path.replace(".csv", "_predictions.csv")
    output_path = ask(f"\nSave full results to CSV", default_out)

if output_path:
    df.to_csv(output_path, index=False)
    print(f"\nSaved {len(df)} predictions to {output_path}")