Add scripts
Browse files- main_v2.py +303 -0
- predict.py +206 -0
- pyproject.toml +16 -0
main_v2.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Strategy: Hinglish -> Hindi -> English -> Full
|
| 4 |
+
- 50 epochs per phase (200 total)
|
| 5 |
+
- Evaluate on each individual language + full after every phase
|
| 6 |
+
- All figures: figsize=(8,6), dpi=300
|
| 7 |
+
- Output dir: /root/output_v2 (old output_v1 untouched)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import matplotlib
|
| 14 |
+
matplotlib.use("Agg")
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import seaborn as sns
|
| 17 |
+
|
| 18 |
+
from sklearn.model_selection import train_test_split
|
| 19 |
+
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
|
| 20 |
+
precision_score, recall_score, f1_score,
|
| 21 |
+
roc_auc_score, confusion_matrix,
|
| 22 |
+
roc_curve, precision_recall_curve)
|
| 23 |
+
|
| 24 |
+
from tensorflow.keras.preprocessing.text import Tokenizer
|
| 25 |
+
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
| 26 |
+
from tensorflow.keras.models import Sequential
|
| 27 |
+
from tensorflow.keras.layers import Embedding, Bidirectional, LSTM, Dense, Dropout
|
| 28 |
+
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
|
| 29 |
+
|
| 30 |
+
# ── Paths ────────────────────────────────────────────────────────────────────
base_path = "/root/output_v2"
data_path = "/root/dataset.csv"
glove_path = "/root/glove.6B.300d.txt"

# Create the whole output tree up front (idempotent thanks to exist_ok).
_subdirs = ("dataset_splits", "figures", "results_tables", "trained_models")
for _sub in _subdirs:
    os.makedirs(os.path.join(base_path, _sub), exist_ok=True)
| 37 |
+
|
| 38 |
+
# ── Load data ────────────────────────────────────────────────────────────────
df = pd.read_csv(data_path)

# Language distribution pie chart (written to disk; Agg backend, never shown).
plt.figure(figsize=(8, 6))
lang_counts = df['language'].value_counts()
lang_counts.plot.pie(autopct='%1.1f%%')
plt.title("Dataset Language Distribution")
plt.ylabel("")  # suppress the default "language" y-label
pie_path = os.path.join(base_path, "figures", "language_distribution.png")
plt.savefig(pie_path, dpi=300, bbox_inches="tight")
plt.close()
| 48 |
+
|
| 49 |
+
X = df["clean_text"]
y = df["hate_label"]
lang = df["language"]

# ── Splits ───────────────────────────────────────────────────────────────────
# 70/30 train+val vs test, then carve ~10% of the total out of the 70%
# for validation (0.1428 * 0.70 ≈ 0.10). Both splits stratify on the label.
X_temp, X_test, y_temp, y_test, lang_temp, lang_test = train_test_split(
    X, y, lang, test_size=0.30, stratify=y, random_state=42)

X_train, X_val, y_train, y_val, lang_train, lang_val = train_test_split(
    X_temp, y_temp, lang_temp,
    test_size=0.1428, stratify=y_temp, random_state=42)

# Persist every split as a text/label/lang CSV for reproducibility.
for _name, _texts, _labels, _langs in (
        ("train", X_train, y_train, lang_train),
        ("val", X_val, y_val, lang_val),
        ("test", X_test, y_test, lang_test)):
    pd.DataFrame({"text": _texts, "label": _labels, "lang": _langs}).to_csv(
        os.path.join(base_path, "dataset_splits", f"{_name}.csv"), index=False)
| 67 |
+
|
| 68 |
+
# ── Tokenise & pad ───────────────────────────────────────────────────────────
MAX_LEN = 100    # sequence length after padding/truncation
VOCAB = 50000    # cap on tokenizer vocabulary size

tokenizer = Tokenizer(num_words=VOCAB)
tokenizer.fit_on_texts(X_train)  # fit on training text only — no leakage


def _to_padded(texts):
    """Tokenise *texts* and pad/truncate each sequence to MAX_LEN."""
    return pad_sequences(tokenizer.texts_to_sequences(texts), maxlen=MAX_LEN)


X_train_seq = _to_padded(X_train)
X_val_seq = _to_padded(X_val)
X_test_seq = _to_padded(X_test)
| 78 |
+
|
| 79 |
+
# ── GloVe embeddings ─────────────────────────────────────────────────────────
EMBEDDING_DIM = 300
print("Loading GloVe …")
embeddings_index = {}
with open(glove_path, encoding="utf8") as f:
    for line in f:
        token, *coeffs = line.split()
        embeddings_index[token] = np.asarray(coeffs, dtype="float32")
print(f"Loaded {len(embeddings_index):,} word vectors.")

word_index = tokenizer.word_index
# Row 0 (padding) and any out-of-vocabulary word keep an all-zero vector.
embedding_matrix = np.zeros((len(word_index) + 1, EMBEDDING_DIM))
for word, i in word_index.items():
    vec = embeddings_index.get(word)
    if vec is not None:
        embedding_matrix[i] = vec
| 95 |
+
|
| 96 |
+
# ── Per-language test subsets ────────────────────────────────────────────────
# Index the padded test matrix / label vector by language so each phase can
# be evaluated per language as well as on the full test set.
languages = ["english", "hindi", "hinglish"]
lang_test_X = {}
lang_test_y = {}
for la in languages:
    mask = lang_test.values == la
    lang_test_X[la] = X_test_seq[mask]
    lang_test_y[la] = y_test.values[mask]
| 100 |
+
|
| 101 |
+
# ── Helpers ──────────────────────────────────────────────────────────────────
def build_model():
    """Build and compile the frozen-GloVe BiLSTM binary classifier.

    Uses the module-level `word_index`, `embedding_matrix`, `EMBEDDING_DIM`
    and `MAX_LEN`. The embedding layer is frozen; the head outputs a single
    sigmoid probability for the hate/non-hate decision.
    """
    model = Sequential()
    model.add(Embedding(len(word_index) + 1, EMBEDDING_DIM,
                        weights=[embedding_matrix], input_length=MAX_LEN,
                        trainable=False))
    model.add(Bidirectional(LSTM(128)))
    model.add(Dropout(0.5))
    model.add(Dense(64, activation="relu"))
    model.add(Dense(1, activation="sigmoid"))
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def evaluate_metrics(y_true, y_pred_prob):
    """Compute the scalar metrics reported in the results table.

    Parameters
    ----------
    y_true : array-like of 0/1 ground-truth labels.
    y_pred_prob : array-like of predicted probabilities in [0, 1]
        (thresholded at 0.5 for the label-based metrics).

    Returns
    -------
    tuple
        (accuracy, balanced_accuracy, precision, recall, specificity,
        f1, roc_auc).
    """
    y_pred = (y_pred_prob > 0.5).astype(int)
    acc = accuracy_score(y_true, y_pred)
    bal = balanced_accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    auc = roc_auc_score(y_true, y_pred_prob)
    # labels=[0, 1] forces a full 2x2 matrix even if a per-language test
    # subset (or its predictions) contains only one class; without it,
    # .ravel() can yield fewer than four values and the unpack crashes.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    # Guard the specificity division: with no actual negatives the
    # original produced nan (0/0) instead of a usable number.
    spec = tn / (tn + fp) if (tn + fp) else 0.0
    return acc, bal, prec, rec, spec, f1, auc
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def safe_tag(s):
    """Return *s* made safe for use in file names.

    Arrow separators become ``_to_`` and remaining spaces become
    underscores.
    """
    out = s.replace(" -> ", "_to_")
    return out.replace(" ", "_")
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def plot_training_curves(history, tag, fig_dir):
    """Save a side-by-side accuracy/loss figure for one training phase.

    Writes ``<safe_tag(tag)>_curves.png`` (8x6 in, 300 dpi) to *fig_dir*.
    """
    panels = (
        ((("accuracy", "Train Acc"), ("val_accuracy", "Val Acc")), "Accuracy"),
        ((("loss", "Train Loss"), ("val_loss", "Val Loss")), "Loss"),
    )
    fig, axes = plt.subplots(1, 2, figsize=(8, 6))
    for ax, (series, metric) in zip(axes, panels):
        for key, label in series:
            ax.plot(history.history[key], label=label)
        ax.set_title(f"{tag} — {metric}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(fig_dir, f"{safe_tag(tag)}_curves.png"),
                dpi=300, bbox_inches="tight")
    plt.close()
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def plot_eval_charts(y_true, preds, tag, fig_dir):
    """Save four evaluation figures for one (phase, eval-subset) pair.

    Writes confusion-matrix, ROC, precision-recall and F1-vs-threshold
    PNGs into *fig_dir*, each 8x6 in at 300 dpi, named from the
    sanitised *tag*.

    Parameters
    ----------
    y_true : array-like of 0/1 ground-truth labels.
    preds : array-like of predicted probabilities in [0, 1].
    tag : str used in figure titles and file names.
    fig_dir : destination directory (must already exist).
    """
    ftag = safe_tag(tag)

    # Confusion matrix (fixed 0.5 decision threshold)
    cm = confusion_matrix(y_true, (preds > 0.5).astype(int))
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Non-Hate", "Hate"],
                yticklabels=["Non-Hate", "Hate"])
    plt.title(f"{tag} — Confusion Matrix")
    plt.xlabel("Predicted"); plt.ylabel("Actual")
    plt.savefig(os.path.join(fig_dir, f"{ftag}_cm.png"), dpi=300, bbox_inches="tight")
    plt.close()

    # ROC
    fpr, tpr, _ = roc_curve(y_true, preds)
    auc_val = roc_auc_score(y_true, preds)
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f"AUC={auc_val:.4f}")
    plt.plot([0, 1], [0, 1], '--')  # chance diagonal
    plt.title(f"{tag} — ROC Curve")
    plt.xlabel("FPR"); plt.ylabel("TPR")
    plt.legend(); plt.grid(True)
    plt.savefig(os.path.join(fig_dir, f"{ftag}_roc.png"), dpi=300, bbox_inches="tight")
    plt.close()

    # Precision-Recall
    precision, recall, thresholds = precision_recall_curve(y_true, preds)
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision)
    plt.title(f"{tag} — Precision-Recall Curve")
    plt.xlabel("Recall"); plt.ylabel("Precision")
    plt.grid(True)
    plt.savefig(os.path.join(fig_dir, f"{ftag}_pr.png"), dpi=300, bbox_inches="tight")
    plt.close()

    # F1 vs Threshold
    # 1e-8 avoids 0/0 where precision and recall are both zero.
    f1_scores = (2 * precision * recall) / (precision + recall + 1e-8)
    plt.figure(figsize=(8, 6))
    # precision/recall carry one more point than thresholds, so the final
    # F1 value is dropped to align the arrays.
    plt.plot(thresholds, f1_scores[:-1])
    plt.title(f"{tag} — F1 Score vs Threshold")
    plt.xlabel("Threshold"); plt.ylabel("F1 Score")
    plt.grid(True)
    plt.savefig(os.path.join(fig_dir, f"{ftag}_f1.png"), dpi=300, bbox_inches="tight")
    plt.close()
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ── Strategy ─────────────────────────────────────────────────────────────────
# Sequential transfer learning: one model instance is trained on each
# language in turn, then on the full training set (four phases total).
STRATEGY = ("hinglish", "hindi", "english")
EPOCHS = 50        # epochs per phase
BATCH_LANG = 32    # batch size for the per-language phases
BATCH_FULL = 64    # batch size for the final full-data phase

strategy_name = " -> ".join(STRATEGY) + " -> Full"
print("\n" + "=" * 60)
print(f"Strategy: {strategy_name}")
print(f"Epochs per phase: {EPOCHS} (Total: {EPOCHS * 4})")
print("=" * 60)

# One figure directory per strategy, named from the sanitised phase order.
fig_dir = os.path.join(base_path, "figures", safe_tag(" -> ".join(STRATEGY)))
os.makedirs(fig_dir, exist_ok=True)

# Full training data (pre-shuffled, used in final phase)
# Fixed seed so the shuffle order is reproducible across runs.
np.random.seed(42)
shuffle_idx = np.random.permutation(len(X_train_seq))
X_full_shuffled = np.ascontiguousarray(X_train_seq[shuffle_idx], dtype=np.int32)
y_full_shuffled = np.ascontiguousarray(y_train.values[shuffle_idx], dtype=np.float32)

# Results accumulator: one row per (training phase, evaluation subset).
cols = ["Phase", "Eval_On", "Accuracy", "Balanced_Acc",
        "Precision", "Recall", "Specificity", "F1", "ROC_AUC"]
all_rows = []

# The same model object carries its weights through every phase below.
model = build_model()
model.summary()
| 225 |
+
|
| 226 |
+
# ── Language phases ──────────────────────────────────────────────────────────
# Train the SAME model sequentially on each language's slice of the
# training split; after every phase, evaluate on each per-language test
# subset and on the full test set, logging metrics and figures.
for phase_lang in STRATEGY:
    # Boolean mask selecting this phase's language from the training split.
    idx = (lang_train == phase_lang)
    X_lang = X_train_seq[idx]
    y_lang = y_train[idx]

    print(f"\n{'─'*50}")
    print(f"Phase: training on '{phase_lang}' ({X_lang.shape[0]} samples, {EPOCHS} epochs)")
    print(f"{'─'*50}")

    # NOTE(review): no EarlyStopping/ModelCheckpoint callbacks are used here
    # although both are imported at the top of the file — confirm running the
    # full 50 epochs per phase is intentional.
    history = model.fit(
        X_lang, y_lang,
        validation_data=(X_val_seq, y_val),
        epochs=EPOCHS,
        batch_size=BATCH_LANG,
        verbose=1,
    )

    plot_training_curves(history, f"Phase_{phase_lang}", fig_dir)

    # Evaluate on every individual language + full
    for eval_lang in languages:
        preds = model.predict(lang_test_X[eval_lang]).flatten()
        metrics = evaluate_metrics(lang_test_y[eval_lang], preds)
        all_rows.append([phase_lang, eval_lang] + list(metrics))
        plot_eval_charts(lang_test_y[eval_lang], preds,
                         f"Phase_{phase_lang}_eval_{eval_lang}", fig_dir)
        print(f" eval on {eval_lang:10s} | Acc={metrics[0]:.4f} F1={metrics[5]:.4f} AUC={metrics[6]:.4f}")

    # Full test set
    preds_full = model.predict(X_test_seq).flatten()
    metrics_full = evaluate_metrics(y_test.values, preds_full)
    all_rows.append([phase_lang, "full"] + list(metrics_full))
    plot_eval_charts(y_test.values, preds_full,
                     f"Phase_{phase_lang}_eval_full", fig_dir)
    print(f" eval on {'full':10s} | Acc={metrics_full[0]:.4f} F1={metrics_full[5]:.4f} AUC={metrics_full[6]:.4f}")
| 262 |
+
|
| 263 |
+
# ── Full dataset phase ───────────────────────────────────────────────────────
# Final phase: continue training the same model on the entire pre-shuffled
# training set, then run the identical per-language + full evaluation.
print(f"\n{'─'*50}")
print(f"Phase: training on Full dataset ({X_full_shuffled.shape[0]} samples, {EPOCHS} epochs)")
print(f"{'─'*50}")

history_full = model.fit(
    X_full_shuffled, y_full_shuffled,
    validation_data=(X_val_seq, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_FULL,
    verbose=1,
)

plot_training_curves(history_full, "Phase_Full", fig_dir)

# Same evaluation sweep as the language phases, tagged "Full".
for eval_lang in languages:
    preds = model.predict(lang_test_X[eval_lang]).flatten()
    metrics = evaluate_metrics(lang_test_y[eval_lang], preds)
    all_rows.append(["Full", eval_lang] + list(metrics))
    plot_eval_charts(lang_test_y[eval_lang], preds,
                     f"Phase_Full_eval_{eval_lang}", fig_dir)
    print(f" eval on {eval_lang:10s} | Acc={metrics[0]:.4f} F1={metrics[5]:.4f} AUC={metrics[6]:.4f}")

preds_full = model.predict(X_test_seq).flatten()
metrics_full = evaluate_metrics(y_test.values, preds_full)
all_rows.append(["Full", "full"] + list(metrics_full))
plot_eval_charts(y_test.values, preds_full, "Phase_Full_eval_full", fig_dir)
print(f" eval on {'full':10s} | Acc={metrics_full[0]:.4f} F1={metrics_full[5]:.4f} AUC={metrics_full[6]:.4f}")
| 291 |
+
|
| 292 |
+
# ── Save results ─────────────────────────────────────────────────────────────
results_df = pd.DataFrame(all_rows, columns=cols)
results_path = os.path.join(base_path, "results_tables",
                            "hinglish_hindi_english_full_results.csv")
results_df.to_csv(results_path, index=False)

banner = "=" * 60
print("\n" + banner)
print("FINAL RESULTS TABLE")
print(banner)
print(results_df.to_string(index=False))

model.save(os.path.join(base_path, "trained_models",
                        "hinglish_hindi_english_full.h5"))
print("\nModel saved.")
print("Done.")
|
predict.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
predict.py — Interactive inference script for the SASC hate speech detection model.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python predict.py # fully interactive
|
| 7 |
+
python predict.py --model model.h5 # specify model path
|
| 8 |
+
python predict.py --input texts.csv # specify input CSV
|
| 9 |
+
python predict.py --text "some text here" # single text prediction
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
|
| 17 |
+
# suppress TF logs
|
| 18 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
| 19 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
| 20 |
+
|
| 21 |
+
from prompt_toolkit import prompt
|
| 22 |
+
from prompt_toolkit.completion import PathCompleter
|
| 23 |
+
from prompt_toolkit.shortcuts import prompt as pt_prompt
|
| 24 |
+
|
| 25 |
+
path_completer = PathCompleter(expanduser=True)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ── Argument parsing ────────────────────────────────────────────────────────
parser = argparse.ArgumentParser(description="SASC Hate Speech Detector")
for _flag, _kwargs in (
    ("--model", dict(type=str, help="Path to .h5 model file")),
    ("--tokenizer", dict(type=str, help="Path to tokenizer.json")),
    ("--input", dict(type=str, help="Path to input CSV file")),
    ("--text", dict(type=str, help="Single text to classify")),
    ("--output", dict(type=str, help="Path to save results CSV")),
    ("--threshold", dict(type=float, default=0.5,
                         help="Decision threshold (default: 0.5)")),
    ("--col", dict(type=str, default="text",
                   help="Column name in CSV containing text (default: text)")),
):
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ── Interactive prompts if args not provided ─────────────────────────────────
def ask(message, default=None, is_path=False):
    """Prompt for a value, falling back to *default* on empty input.

    Path prompts get tab completion via prompt_toolkit; any returned
    value is ~-expanded. Returns *default* (possibly None) when the
    user enters nothing.
    """
    hint = f" [{default}]" if default else ""
    if is_path:
        answer = pt_prompt(f"{message}{hint}: ", completer=path_completer)
    else:
        answer = input(f"{message}{hint}: ")
    answer = answer.strip() or default
    return os.path.expanduser(answer) if answer else answer
| 49 |
+
|
| 50 |
+
|
| 51 |
+
print("\n=== SASC Hate Speech Detector ===\n")
|
| 52 |
+
|
| 53 |
+
# Model path
# CLI flag wins; otherwise prompt interactively (with path completion).
# Bail out before loading anything heavy if the file does not exist.
model_path = args.model
if not model_path:
    model_path = ask("Model path (.h5)", "model.h5", is_path=True)

if not os.path.exists(model_path):
    print(f"Model not found: {model_path}")
    sys.exit(1)
| 61 |
+
|
| 62 |
+
# Tokenizer path
# CLI flag wins; otherwise suggest a tokenizer.json sitting next to the
# model file when one exists, and prompt with that as the default.
tokenizer_path = args.tokenizer
if not tokenizer_path:
    # look next to model file first
    candidate = os.path.join(os.path.dirname(model_path), "tokenizer.json")
    tokenizer_path = ask("Tokenizer path", candidate if os.path.exists(candidate) else "tokenizer.json", is_path=True)

if not os.path.exists(tokenizer_path):
    print(f"Tokenizer not found: {tokenizer_path}")
    sys.exit(1)
| 72 |
+
|
| 73 |
+
# Threshold
# BUG FIX: the original gated the prompt on `not args.threshold`, which is
# only true when the user explicitly passed --threshold 0 — exactly the
# value that should be honoured — and never fired in the intended
# fully-interactive flow (the 0.5 default is truthy). Prompt instead when
# the flag was left at its parser default and we are interactive
# (no --text / --input given); an explicit 0.0 is now kept as-is.
threshold = args.threshold
if (args.threshold == parser.get_default("threshold")
        and not args.text and not args.input):
    t = ask("Decision threshold (0.0-1.0)", "0.5")
    try:
        threshold = float(t)
    except ValueError:
        # Fall back to the documented default on unparsable input.
        threshold = 0.5
| 81 |
+
|
| 82 |
+
print(f"\nLoading model from {model_path}")
print(f"Loading tokenizer from {tokenizer_path}")
# Heavy imports are deferred until after the interactive prompts so the
# CLI feels responsive; warnings and TF logging are silenced for clean output.
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
import logging
tf.get_logger().setLevel(logging.ERROR)

# compile=False: inference only, no optimizer state needed.
model = tf.keras.models.load_model(model_path, compile=False)

from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
with open(tokenizer_path) as f:
    tokenizer = tokenizer_from_json(f.read())

print(f"Model loaded — vocab size: {len(tokenizer.word_index)}")

# Must match the sequence length used at training time
# (main_v2.py pads to 100).
MAX_LEN = 100
| 100 |
+
|
| 101 |
+
def predict(texts):
    """Classify a list of raw strings.

    Tokenises and pads with the loaded tokenizer, runs the model, and
    thresholds each sigmoid probability against the global `threshold`.
    Returns (probabilities, labels).
    """
    padded = pad_sequences(tokenizer.texts_to_sequences(texts), maxlen=MAX_LEN)
    probs = model.predict(padded, verbose=0).flatten()
    labels = []
    for p in probs:
        labels.append("Hate Speech" if p > threshold else "Non-Hate")
    return probs, labels
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# ── Single text mode ──────────────────────────────────────────────────────────
# --text short-circuits everything else: classify one string and exit.
if args.text:
    probs, labels = predict([args.text])
    print(f"\nText : {args.text}")
    print(f"Label : {labels[0]}")
    # "Confidence" here is the raw sigmoid probability of the Hate class,
    # not the confidence in the printed label.
    print(f"Confidence: {probs[0]:.4f}")
    sys.exit(0)
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ── CSV mode ──────────────────────────────────────────────────────────────────
import pandas as pd

# Resolve the input: CLI --input wins; otherwise offer CSV vs manual entry.
input_path = args.input
if not input_path:
    mode = ask("Input mode — (1) CSV file (2) Type text manually", "1")

    if mode == "2":
        # manual text entry loop — collect lines until the sentinel 'done'.
        print("\nEnter texts one per line. Type 'done' when finished.\n")
        texts = []
        while True:
            t = input(" Text: ").strip()
            if t.lower() == "done":
                break
            if t:
                texts.append(t)

        if not texts:
            print("No texts entered.")
            sys.exit(0)

        probs, labels = predict(texts)
        # NOTE(review): redundant — pandas is already imported above.
        import pandas as pd
        results = pd.DataFrame({
            "text": texts,
            "label": labels,
            "confidence": [round(float(p), 4) for p in probs]
        })

        print("\n" + "="*60)
        print(results.to_string(index=False))
        print("="*60)

        # An empty answer from ask() is falsy, so leaving the prompt blank
        # skips saving.
        out = args.output or ask("Save results to CSV? (leave blank to skip)", "", is_path=True)
        if out:
            results.to_csv(out, index=False)
            print(f"Saved to {out}")
        sys.exit(0)

    else:
        input_path = ask("CSV file path", is_path=True)

if not os.path.exists(input_path):
    print(f"File not found: {input_path}")
    sys.exit(1)
| 164 |
+
|
| 165 |
+
# Batch inference over the chosen CSV column, with a console summary and
# an optional predictions CSV written next to the input by default.
df = pd.read_csv(input_path)
print(f"\nLoaded {len(df)} rows from {input_path}")
print(f"Columns: {list(df.columns)}")

# Resolve the text column: fall back to asking when --col is not present.
text_col = args.col
if text_col not in df.columns:
    print(f"\nColumn '{text_col}' not found.")
    text_col = ask(f"Which column contains the text?", df.columns[0])

print(f"\nRunning inference on column '{text_col}' with threshold={threshold}...")

texts = df[text_col].fillna("").astype(str).tolist()
# Guard: an empty CSV previously crashed with ZeroDivisionError in the
# percentage summary below.
if not texts:
    print("No rows to classify.")
    sys.exit(0)
probs, labels = predict(texts)

df["predicted_label"] = labels
df["confidence"] = [round(float(p), 4) for p in probs]

# Summary
hate_count = labels.count("Hate Speech")
nonhate_count = labels.count("Non-Hate")
print(f"\n{'='*60}")
print(f"Results Summary")
print(f"{'='*60}")
print(f" Total samples : {len(texts)}")
print(f" Hate Speech : {hate_count} ({hate_count/len(texts)*100:.1f}%)")
print(f" Non-Hate : {nonhate_count} ({nonhate_count/len(texts)*100:.1f}%)")
print(f" Threshold : {threshold}")
print(f"{'='*60}")

# Show sample
print(f"\nSample predictions (first 10):")
print(df[[text_col, "predicted_label", "confidence"]].head(10).to_string(index=False))

# Save
output_path = args.output
if not output_path:
    # BUG FIX: str.replace rewrote the FIRST ".csv" anywhere in the path
    # (e.g. "/data.csv.d/file.csv" -> "/data_predictions.csv.d/file.csv");
    # build the default from the actual extension instead.
    root, ext = os.path.splitext(input_path)
    default_out = f"{root}_predictions{ext or '.csv'}"
    output_path = ask(f"\nSave full results to CSV", default_out, is_path=True)

if output_path:
    df.to_csv(output_path, index=False)
    print(f"\nSaved {len(df)} predictions to {output_path}")
|
pyproject.toml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "sasc"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Multilingual Hate Speech Detection — GloVe + BiLSTM with Sequential Transfer Learning"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.11,<3.14"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"tensorflow>=2.13.0",
|
| 9 |
+
"numpy>=1.24.0",
|
| 10 |
+
"pandas>=2.0.0",
|
| 11 |
+
"scikit-learn>=1.3.0",
|
| 12 |
+
"matplotlib>=3.7.0",
|
| 13 |
+
"seaborn>=0.12.0",
|
| 14 |
+
"huggingface-hub>=0.20.0",
|
| 15 |
+
"prompt-toolkit>=3.0.52",
|
| 16 |
+
]
|