tuklu commited on
Commit
47bafb1
·
verified ·
1 Parent(s): 1cf1d84

Add README, tokenizer, results

Browse files
Files changed (2) hide show
  1. main.py +354 -0
  2. predict.py +185 -0
main.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """glove+bilstm.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/10fLw7V6G3vV_STF7KcWe8qcTvyLQq0NT
8
+ """
9
+
10
+ import os
11
+ import numpy as np
12
+ import pandas as pd
13
+ import matplotlib.pyplot as plt
14
+ import seaborn as sns
15
+ from itertools import permutations
16
+
17
+ # For train-test split and evaluation
18
+ from sklearn.model_selection import train_test_split
19
+ from sklearn.metrics import accuracy_score, balanced_accuracy_score
20
+ from sklearn.metrics import precision_score, recall_score, f1_score
21
+ from sklearn.metrics import roc_auc_score, confusion_matrix
22
+ from sklearn.metrics import roc_curve, precision_recall_curve
23
+
24
+ # Deep learning libraries
25
+ from tensorflow.keras.preprocessing.text import Tokenizer
26
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
27
+ from tensorflow.keras.models import Sequential
28
+ from tensorflow.keras.layers import Embedding, Bidirectional, LSTM
29
+ from tensorflow.keras.layers import Dense, Dropout
30
+
31
# Root directory for every artifact this run produces.
base_path = "/root/output"

os.makedirs(base_path+"/dataset_splits", exist_ok=True)
os.makedirs(base_path+"/figures", exist_ok=True)
os.makedirs(base_path+"/results_tables", exist_ok=True)
os.makedirs(base_path+"/trained_models", exist_ok=True)

# Input dataset; the column accesses below establish that it contains
# 'clean_text', 'hate_label' and 'language' columns.
data_path = "/root/dataset.csv"

df = pd.read_csv(data_path)

df.head()  # no-op outside a notebook; kept from the Colab original

# Pie chart of the language distribution, saved under figures/.
plt.figure(figsize=(6,4))
df['language'].value_counts().plot.pie(autopct='%1.1f%%')
plt.title("Dataset Language Distribution")
plt.ylabel("")
plt.savefig(base_path+"/figures/language_distribution.png", dpi=300)
plt.show()

X = df["clean_text"]
y = df["hate_label"]
lang = df["language"]

# 70/30 (train+val)/test split, stratified on the label.
X_temp, X_test, y_temp, y_test, lang_temp, lang_test = train_test_split(
    X, y, lang, test_size=0.30, stratify=y, random_state=42)

# 0.1428 of the remaining 70% is roughly 10% of the full data,
# giving an overall ~60/10/30 train/val/test split.
X_train, X_val, y_train, y_val, lang_train, lang_val = train_test_split(
    X_temp, y_temp, lang_temp,
    test_size=0.1428,
    stratify=y_temp,
    random_state=42
)

# Persist the exact splits so the experiment is reproducible.
pd.DataFrame({"text":X_train,"label":y_train,"lang":lang_train}).to_csv(
    base_path+"/dataset_splits/train.csv", index=False)

pd.DataFrame({"text":X_val,"label":y_val,"lang":lang_val}).to_csv(
    base_path+"/dataset_splits/val.csv", index=False)

pd.DataFrame({"text":X_test,"label":y_test,"lang":lang_test}).to_csv(
    base_path+"/dataset_splits/test.csv", index=False)

# Padded sequence length and vocabulary cap for the Keras tokenizer.
MAX_LEN = 100
VOCAB = 50000

# Fit the tokenizer on TRAINING text only (no vocabulary leakage from val/test).
tokenizer = Tokenizer(num_words=VOCAB)
tokenizer.fit_on_texts(X_train)

X_train_seq = pad_sequences(tokenizer.texts_to_sequences(X_train), maxlen=MAX_LEN)
X_val_seq = pad_sequences(tokenizer.texts_to_sequences(X_val), maxlen=MAX_LEN)
X_test_seq = pad_sequences(tokenizer.texts_to_sequences(X_test), maxlen=MAX_LEN)

# Pre-trained 300-d GloVe vectors (glove.6B).
EMBEDDING_DIM = 300
glove_path = "/root/glove.6B.300d.txt"

embeddings_index = {}

# Parse the GloVe text file: each line is a token followed by its vector.
with open(glove_path, encoding="utf8") as f:
    for line in f:
        values = line.split()
        word = values[0]
        vector = np.asarray(values[1:], dtype="float32")
        embeddings_index[word] = vector

print("Loaded %s word vectors." % len(embeddings_index))

word_index = tokenizer.word_index
embedding_dim = 300  # NOTE(review): duplicates EMBEDDING_DIM above

# Row i holds the GloVe vector for the word with tokenizer index i;
# words with no GloVe entry keep an all-zero row.
embedding_matrix = np.zeros((len(word_index)+1, embedding_dim))

for word, i in word_index.items():
    vector = embeddings_index.get(word)
    if vector is not None:
        embedding_matrix[i] = vector
107
+
108
+
109
+ # ============================================================
110
+ # Helper: build a fresh model (called once per permutation)
111
+ # ============================================================
112
def build_model():
    """Create and compile a fresh BiLSTM classifier over frozen GloVe embeddings.

    Uses the module-level `word_index`, `embedding_matrix`, `embedding_dim`
    and `MAX_LEN` computed during preprocessing.
    """
    layers = [
        Embedding(
            input_dim=len(word_index) + 1,
            output_dim=embedding_dim,
            weights=[embedding_matrix],   # pre-trained GloVe weights
            input_length=MAX_LEN,
            trainable=False,              # embeddings stay frozen during training
        ),
        Bidirectional(LSTM(128)),
        Dropout(0.5),
        Dense(64, activation="relu"),
        Dense(1, activation="sigmoid"),   # binary hate / non-hate output
    ]
    model = Sequential(layers)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return model
128
+
129
+
130
def evaluate_metrics(y_true, y_pred_prob):
    """Compute binary-classification metrics from predicted probabilities.

    Parameters
    ----------
    y_true : array-like of 0/1 ground-truth labels.
    y_pred_prob : array-like of predicted positive-class probabilities.

    Returns
    -------
    tuple
        (accuracy, balanced_accuracy, precision, recall, specificity,
        f1, roc_auc).
    """
    y_pred = (y_pred_prob > 0.5).astype(int)
    acc = accuracy_score(y_true, y_pred)
    bal = balanced_accuracy_score(y_true, y_pred)
    # zero_division=0 avoids warnings/undefined values when one class
    # is never predicted.
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    # NOTE(review): roc_auc_score still raises if y_true contains a single
    # class; callers pass stratified test splits so both classes are expected.
    auc = roc_auc_score(y_true, y_pred_prob)
    # labels=[0, 1] forces a 2x2 matrix even when a class is absent, so
    # ravel() always yields four values (the original crashed on a 1x1 matrix).
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    # Guard the specificity denominator (no actual negatives -> define as 0).
    spec = tn / (tn + fp) if (tn + fp) > 0 else 0.0
    return acc, bal, prec, rec, spec, f1, auc
141
+
142
+
143
def plot_training_curves(history, tag, base_path):
    """Save side-by-side train/val accuracy and loss curves for one training phase."""
    hist = history.history
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(14, 5))

    # Same decoration for both panels; only the metric differs.
    for ax, metric, label in ((acc_ax, "accuracy", "Accuracy"),
                              (loss_ax, "loss", "Loss")):
        ax.plot(hist[metric], label=f"Train {label}")
        ax.plot(hist[f"val_{metric}"], label=f"Val {label}")
        ax.set_title(f"{tag} - {label} Curve")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    fname = tag.replace(" -> ", "_to_").replace(" ", "_")
    plt.savefig(os.path.join(base_path, f"{fname}_curves.png"), dpi=300)
    plt.show()
167
+
168
+
169
def plot_eval_charts(y_test, preds, tag, base_path):
    """Save confusion matrix, ROC, PR, and F1-vs-threshold plots for one evaluation."""
    fname = tag.replace(" -> ", "_to_").replace(" ", "_")

    def _save(suffix):
        # Shared save-and-show tail used by every chart below.
        plt.savefig(os.path.join(base_path, f"{fname}_{suffix}.png"), dpi=300)
        plt.show()

    # Confusion matrix at the fixed 0.5 decision threshold.
    hard_preds = (preds > 0.5).astype(int)
    cm = confusion_matrix(y_test, hard_preds)
    plt.figure(figsize=(6,4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=["Non-Hate","Hate"],
                yticklabels=["Non-Hate","Hate"])
    plt.title(f"{tag} - Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    _save("cm")

    # ROC curve with AUC in the legend; dashed diagonal = chance level.
    fpr, tpr, _ = roc_curve(y_test, preds)
    auc_val = roc_auc_score(y_test, preds)
    plt.figure(figsize=(6,4))
    plt.plot(fpr, tpr, label=f"AUC={auc_val:.4f}")
    plt.plot([0,1],[0,1],'--')
    plt.title(f"{tag} - ROC Curve")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.legend()
    plt.grid(True)
    _save("roc")

    # Precision-recall curve.
    precision, recall, thresholds = precision_recall_curve(y_test, preds)
    plt.figure(figsize=(6,4))
    plt.plot(recall, precision)
    plt.title(f"{tag} - Precision-Recall Curve")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.grid(True)
    _save("pr")

    # F1 across thresholds; epsilon avoids 0/0, and the final PR point has no
    # threshold, hence the [:-1] slice to align lengths.
    f1_scores = (2 * precision * recall) / (precision + recall + 1e-8)
    plt.figure(figsize=(6,4))
    plt.plot(thresholds, f1_scores[:-1])
    plt.title(f"{tag} - F1 Score vs Threshold")
    plt.xlabel("Threshold")
    plt.ylabel("F1 Score")
    plt.grid(True)
    _save("f1")
220
+
221
+
222
# ============================================================
# PLAN B: All 6 permutations + final Full (Shuffled) fine-tune
# After each training phase -> evaluate on that language's test set
# After Full phase -> evaluate on full test set
# ============================================================
print("\n" + "="*60)
print("PLAN B: Sequential Transfer Learning + Full Dataset Fine-tune")
print("="*60)

languages = ["english", "hindi", "hinglish"]

# Pre-shuffle full training data once (same shuffle for all permutations),
# so every strategy's final fine-tune phase sees an identical sample order.
np.random.seed(42)
shuffle_idx = np.random.permutation(len(X_train_seq))
X_full_shuffled = np.ascontiguousarray(X_train_seq[shuffle_idx], dtype=np.int32)
y_full_shuffled = np.ascontiguousarray(y_train.values[shuffle_idx], dtype=np.float32)

# Pre-build per-language test splits: boolean masks over the shared test set,
# then the matching sequence / label subsets.
lang_test_idx = {
    lang: (lang_test.values == lang)
    for lang in languages
}
lang_test_X = {
    lang: X_test_seq[lang_test_idx[lang]]
    for lang in languages
}
lang_test_y = {
    lang: y_test.values[lang_test_idx[lang]]
    for lang in languages
}

# Column headers shared by every per-strategy results table.
cols = ["Strategy", "Phase", "Accuracy", "Balanced Acc",
        "Precision", "Recall", "Specificity", "F1", "ROC-AUC"]

for perm in permutations(languages):
    perm_name = " -> ".join(perm)
    strategy_name = perm_name + " -> Full"
    strategy_results = []  # one metrics row per phase (3 languages + Full)

    print(f"\n{'='*50}")
    print(f"Strategy: {strategy_name}")
    print(f"{'='*50}")

    # Make a clean folder per strategy for figures
    strat_tag = perm_name.replace(" -> ", "_to_")
    strat_fig_path = base_path + f"/figures/{strat_tag}"
    os.makedirs(strat_fig_path, exist_ok=True)

    # Model built ONCE -- weights carry forward across all phases
    # (this is the "sequential transfer" part of the strategy).
    model = build_model()

    # ── Language phases ──────────────────────────────────────
    for lang in perm:
        # Select this language's slice of the training data.
        # NOTE(review): `lang` shadows the module-level Series of the same name.
        idx = (lang_train == lang)
        X_lang = X_train_seq[idx]
        y_lang = y_train[idx]

        print(f" Training on: {lang} ({X_lang.shape[0]} samples)")

        history = model.fit(
            X_lang, y_lang,
            validation_data=(X_val_seq, y_val),
            epochs=8,
            batch_size=32,
            verbose=1
        )

        # Train/Val accuracy + loss curves
        plot_training_curves(history, f"{strat_tag} [{lang}]", strat_fig_path)

        # Evaluate on this language's test subset
        preds = model.predict(lang_test_X[lang]).flatten()
        acc, bal, prec, rec, spec, f1, auc = evaluate_metrics(lang_test_y[lang], preds)
        strategy_results.append([strategy_name, lang, acc, bal, prec, rec, spec, f1, auc])

        # Eval plots for this language
        plot_eval_charts(lang_test_y[lang], preds, f"{strat_tag} [{lang}]", strat_fig_path)

        print(f" Acc={acc:.4f} F1={f1:.4f} AUC={auc:.4f}")

    # ── Full phase ───────────────────────────────────────────
    print(f" Training on: Full Dataset ({X_full_shuffled.shape[0]} samples, shuffled)")

    history_full = model.fit(
        X_full_shuffled, y_full_shuffled,
        validation_data=(X_val_seq, y_val),
        epochs=8,
        batch_size=64,  # larger batch for the bigger, mixed-language dataset
        verbose=1
    )

    # Train/Val accuracy + loss curves for full phase
    plot_training_curves(history_full, f"{strat_tag} [Full]", strat_fig_path)

    # Evaluate on full test set
    preds_full = model.predict(X_test_seq).flatten()
    acc, bal, prec, rec, spec, f1, auc = evaluate_metrics(y_test.values, preds_full)
    strategy_results.append([strategy_name, "Full", acc, bal, prec, rec, spec, f1, auc])

    # Eval plots for full phase
    plot_eval_charts(y_test.values, preds_full, f"{strat_tag} [Full]", strat_fig_path)

    print(f" Acc={acc:.4f} F1={f1:.4f} AUC={auc:.4f}")

    # Save per-strategy results table (4 rows: 3 langs + Full)
    strat_df = pd.DataFrame(strategy_results, columns=cols)
    strat_df.to_csv(
        base_path + f"/results_tables/{strat_tag}_results.csv",
        index=False
    )

    print(f"\n Results for strategy: {strategy_name}")
    print(strat_df.to_string(index=False))

    model.save(base_path + f"/trained_models/planB_{strat_tag}_Full.h5")
    print(f" Saved model: planB_{strat_tag}_Full.h5")


# ============================================================
# COMBINED RESULTS TABLE (all 6 strategies x 4 phases = 24 rows)
# ============================================================
# Re-derive the per-strategy CSV paths in the same permutation order used above.
all_csv = [
    base_path + f"/results_tables/{('_to_'.join(perm))}_results.csv"
    for perm in permutations(languages)
]

combined_df = pd.concat([pd.read_csv(f) for f in all_csv], ignore_index=True)
combined_df.to_csv(base_path + "/results_tables/all_strategies_results.csv", index=False)

print("\n" + "="*60)
print("ALL STRATEGIES — COMBINED RESULTS")
print("="*60)
print(combined_df.to_string(index=False))
predict.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ predict.py β€” Interactive inference script for the SASC hate speech detection model.
4
+
5
+ Usage:
6
+ python predict.py # fully interactive
7
+ python predict.py --model model.h5 # specify model path
8
+ python predict.py --input texts.csv # specify input CSV
9
+ python predict.py --text "some text here" # single text prediction
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import argparse
15
+ import json
16
+
17
+
18
# ── Argument parsing ────────────────────────────────────────────────────────
# Every flag is optional; missing values are collected interactively below.
parser = argparse.ArgumentParser(description="SASC Hate Speech Detector")
parser.add_argument("--model", type=str, help="Path to .h5 model file")
parser.add_argument("--tokenizer", type=str, help="Path to tokenizer.json")
parser.add_argument("--input", type=str, help="Path to input CSV file")
parser.add_argument("--text", type=str, help="Single text to classify")
parser.add_argument("--output", type=str, help="Path to save results CSV")
parser.add_argument("--threshold", type=float, default=0.5, help="Decision threshold (default: 0.5)")
parser.add_argument("--col", type=str, default="text", help="Column name in CSV containing text (default: text)")
args = parser.parse_args()
28
+
29
+
30
+ # ── Interactive prompts if args not provided ─────────────────────────────────
31
def ask(prompt, default=None):
    """Prompt on stdin; return the stripped reply, or *default* if it is empty."""
    hint = f" [{default}]" if default else ""
    answer = input(f"{prompt}{hint}: ").strip()
    return answer or default
35
+
36
+
37
print("\n=== SASC Hate Speech Detector ===\n")

# ── Resolve model path (flag or interactive prompt) ──────────────────────────
model_path = args.model
if not model_path:
    model_path = ask("Model path (.h5)", "model.h5")

if not os.path.exists(model_path):
    print(f"Model not found: {model_path}")
    sys.exit(1)

# ── Resolve tokenizer path ───────────────────────────────────────────────────
tokenizer_path = args.tokenizer
if not tokenizer_path:
    # look next to model file first
    candidate = os.path.join(os.path.dirname(model_path), "tokenizer.json")
    tokenizer_path = ask("Tokenizer path", candidate if os.path.exists(candidate) else "tokenizer.json")

if not os.path.exists(tokenizer_path):
    print(f"Tokenizer not found: {tokenizer_path}")
    sys.exit(1)

# ── Decision threshold ───────────────────────────────────────────────────────
# BUG FIX: the original guard was `not args.threshold`, which is False for
# the 0.5 default — so the interactive prompt could only ever fire when the
# user explicitly passed --threshold 0. Detect whether the flag was actually
# supplied on the command line instead.
threshold = args.threshold
threshold_given = any(a == "--threshold" or a.startswith("--threshold=")
                      for a in sys.argv[1:])
if not threshold_given and not args.text and not args.input:
    t = ask("Decision threshold (0.0-1.0)", "0.5")
    try:
        threshold = float(t)
    except ValueError:
        # Fall back to the documented default on unparsable input.
        threshold = 0.5

print(f"\nLoading model from {model_path}...")
# TensorFlow is imported lazily so path errors surface before the slow import.
import tensorflow as tf
model = tf.keras.models.load_model(model_path)

print(f"Loading tokenizer from {tokenizer_path}...")
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
with open(tokenizer_path) as f:
    tokenizer = tokenizer_from_json(f.read())

# Must match the sequence length the model was trained with.
MAX_LEN = 100
79
+
80
def predict(texts):
    """Tokenize, pad, and classify a list of strings.

    Returns a pair (probs, labels): the flattened sigmoid probabilities and
    the corresponding thresholded string labels.
    """
    padded = pad_sequences(tokenizer.texts_to_sequences(texts), maxlen=MAX_LEN)
    probs = model.predict(padded, verbose=0).flatten()
    labels = []
    for p in probs:
        # Strictly-greater comparison, matching the CLI threshold semantics.
        labels.append("Hate Speech" if p > threshold else "Non-Hate")
    return probs, labels
86
+
87
+
88
# ── Single text mode ──────────────────────────────────────────────────────────
if args.text:
    probs, labels = predict([args.text])
    print(f"\nText : {args.text}")
    print(f"Label : {labels[0]}")
    print(f"Confidence: {probs[0]:.4f}")
    sys.exit(0)


# ── CSV mode ──────────────────────────────────────────────────────────────────
import pandas as pd

input_path = args.input
if not input_path:
    mode = ask("Input mode — (1) CSV file (2) Type text manually", "1")

    if mode == "2":
        # Manual entry loop: collect non-empty lines until the user types 'done'.
        print("\nEnter texts one per line. Type 'done' when finished.\n")
        texts = []
        while True:
            t = input(" Text: ").strip()
            if t.lower() == "done":
                break
            if t:
                texts.append(t)

        if not texts:
            print("No texts entered.")
            sys.exit(0)

        probs, labels = predict(texts)
        # (redundant in-branch `import pandas as pd` removed — pandas is
        # already imported at the top of CSV mode above)
        results = pd.DataFrame({
            "text": texts,
            "label": labels,
            "confidence": [round(float(p), 4) for p in probs]
        })

        print("\n" + "="*60)
        print(results.to_string(index=False))
        print("="*60)

        out = args.output or ask("Save results to CSV? (leave blank to skip)", "")
        if out:
            results.to_csv(out, index=False)
            print(f"Saved to {out}")
        sys.exit(0)

    else:
        input_path = ask("CSV file path")

# BUG FIX: ask() returns None when the user presses Enter with no default,
# and os.path.exists(None) raises TypeError — guard for a missing path first.
if not input_path or not os.path.exists(input_path):
    print(f"File not found: {input_path}")
    sys.exit(1)

df = pd.read_csv(input_path)
print(f"\nLoaded {len(df)} rows from {input_path}")
print(f"Columns: {list(df.columns)}")

# Fall back to an interactive column choice when --col is absent from the CSV.
text_col = args.col
if text_col not in df.columns:
    print(f"\nColumn '{text_col}' not found.")
    text_col = ask(f"Which column contains the text?", df.columns[0])

print(f"\nRunning inference on column '{text_col}' with threshold={threshold}...")

# NaN cells become empty strings so every row gets a prediction.
texts = df[text_col].fillna("").astype(str).tolist()
probs, labels = predict(texts)

df["predicted_label"] = labels
df["confidence"] = [round(float(p), 4) for p in probs]

# Summary
hate_count = labels.count("Hate Speech")
nonhate_count = labels.count("Non-Hate")
print(f"\n{'='*60}")
print(f"Results Summary")
print(f"{'='*60}")
print(f" Total samples : {len(texts)}")
print(f" Hate Speech : {hate_count} ({hate_count/len(texts)*100:.1f}%)")
print(f" Non-Hate : {nonhate_count} ({nonhate_count/len(texts)*100:.1f}%)")
print(f" Threshold : {threshold}")
print(f"{'='*60}")

# Show sample
print(f"\nSample predictions (first 10):")
print(df[[text_col, "predicted_label", "confidence"]].head(10).to_string(index=False))

# Save (prompt with a derived default when --output was not given)
output_path = args.output
if not output_path:
    default_out = input_path.replace(".csv", "_predictions.csv")
    output_path = ask(f"\nSave full results to CSV", default_out)

if output_path:
    df.to_csv(output_path, index=False)
    print(f"\nSaved {len(df)} predictions to {output_path}")