#!/usr/bin/env python3
"""
predict.py — Interactive inference script for the SASC hate speech detection model.
Usage:
python predict.py # fully interactive
python predict.py --model model.h5 # specify model path
python predict.py --input texts.csv # specify input CSV
python predict.py --text "some text here" # single text prediction
"""
import os
import sys
import argparse
import json  # NOTE(review): appears unused in this script — confirm before removing
# Suppress TF C++ logs and oneDNN notices; these env vars are only honored if
# set BEFORE tensorflow is imported (the TF import is deliberately deferred below).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
from prompt_toolkit import prompt
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.shortcuts import prompt as pt_prompt
# Shared tab-completer for all file-path prompts; expands "~" while completing.
path_completer = PathCompleter(expanduser=True)
# ── Argument parsing ────────────────────────────────────────────────────────
# Every flag is optional: anything omitted is asked for interactively below.
parser = argparse.ArgumentParser(description="SASC Hate Speech Detector")
parser.add_argument("--model", type=str, help="Path to .h5 model file")
parser.add_argument("--tokenizer", type=str, help="Path to tokenizer.json")
parser.add_argument("--input", type=str, help="Path to input CSV file")
parser.add_argument("--text", type=str, help="Single text to classify")
parser.add_argument("--output", type=str, help="Path to save results CSV")
parser.add_argument("--threshold", type=float, default=0.5, help="Decision threshold (default: 0.5)")
parser.add_argument("--col", type=str, default="text", help="Column name in CSV containing text (default: text)")
args = parser.parse_args()
# ── Interactive prompts if args not provided ─────────────────────────────────
def ask(message, default=None, is_path=False):
    """Prompt the user for a value, falling back to *default* on empty input.

    When ``is_path`` is True, use prompt_toolkit with tab path-completion;
    otherwise use plain ``input()``. A truthy result is ``~``-expanded
    before being returned.
    """
    hint = f" [{default}]" if default else ""
    full_message = f"{message}{hint}: "
    if is_path:
        answer = pt_prompt(full_message, completer=path_completer).strip()
    else:
        answer = input(full_message).strip()
    if not answer:
        answer = default
    if answer:
        return os.path.expanduser(answer)
    return answer
print("\n=== SASC Hate Speech Detector ===\n")
# Model path: --model wins; otherwise prompt (with tab completion).
model_path = args.model
if not model_path:
    model_path = ask("Model path (.h5)", "model.h5", is_path=True)
if not os.path.exists(model_path):
    print(f"Model not found: {model_path}")
    sys.exit(1)
# Tokenizer path: --tokenizer wins; otherwise prompt, defaulting to a
# tokenizer.json sitting next to the model file when one exists.
tokenizer_path = args.tokenizer
if not tokenizer_path:
    # look next to model file first
    candidate = os.path.join(os.path.dirname(model_path), "tokenizer.json")
    tokenizer_path = ask("Tokenizer path", candidate if os.path.exists(candidate) else "tokenizer.json", is_path=True)
if not os.path.exists(tokenizer_path):
    print(f"Tokenizer not found: {tokenizer_path}")
    sys.exit(1)
# Threshold
# Prompt only in the fully interactive flow AND only when --threshold was not
# given on the command line. The original test (`not args.threshold`) was a
# bug: argparse's default of 0.5 is truthy, so the prompt could never fire —
# except for an explicit `--threshold 0.0`, which wrongly re-prompted. We must
# inspect sys.argv because the parsed value cannot distinguish "defaulted 0.5"
# from "user passed 0.5".
threshold = args.threshold
threshold_given = any(
    a == "--threshold" or a.startswith("--threshold=") for a in sys.argv[1:]
)
if not threshold_given and not args.text and not args.input:
    t = ask("Decision threshold (0.0-1.0)", "0.5")
    try:
        threshold = float(t)
    except ValueError:
        # Non-numeric entry: silently fall back to the documented default.
        threshold = 0.5
print(f"\nLoading model from {model_path}")
print(f"Loading tokenizer from {tokenizer_path}")
# TensorFlow is imported here — after all the cheap prompting — so startup
# feels responsive; warnings/logging are silenced before any TF call.
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
import logging
tf.get_logger().setLevel(logging.ERROR)
# compile=False: inference only — no optimizer/loss state needs restoring.
model = tf.keras.models.load_model(model_path, compile=False)
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
with open(tokenizer_path) as f:
    tokenizer = tokenizer_from_json(f.read())
print(f"Model loaded — vocab size: {len(tokenizer.word_index)}")
# Sequences are padded/truncated to this length. Assumes it matches the
# training-time max length — TODO confirm against the training script.
MAX_LEN = 100
def predict(texts):
    """Classify a list of strings.

    Returns a pair ``(probs, labels)``: the model's flattened sigmoid scores
    and the corresponding "Hate Speech"/"Non-Hate" strings, decided against
    the module-level ``threshold``.
    """
    token_ids = tokenizer.texts_to_sequences(texts)
    batch = pad_sequences(token_ids, maxlen=MAX_LEN)
    scores = model.predict(batch, verbose=0).flatten()
    verdicts = [
        "Hate Speech" if score > threshold else "Non-Hate" for score in scores
    ]
    return scores, verdicts
# ── Single text mode ──────────────────────────────────────────────────────────
# --text short-circuits everything else: classify one string, print, exit.
if args.text:
    probs, labels = predict([args.text])
    print(f"\nText : {args.text}")
    print(f"Label : {labels[0]}")
    print(f"Confidence: {probs[0]:.4f}")
    sys.exit(0)
# ── CSV mode ──────────────────────────────────────────────────────────────────
# pandas deferred to here: the single-text path above never needs it.
import pandas as pd
input_path = args.input
if not input_path:
    mode = ask("Input mode — (1) CSV file (2) Type text manually", "1")
    if mode == "2":
        # manual text entry loop: collect lines until the 'done' sentinel;
        # blank lines are skipped.
        print("\nEnter texts one per line. Type 'done' when finished.\n")
        texts = []
        while True:
            t = input(" Text: ").strip()
            if t.lower() == "done":
                break
            if t:
                texts.append(t)
        if not texts:
            print("No texts entered.")
            sys.exit(0)
        probs, labels = predict(texts)
        # (the redundant re-import of pandas that used to live here is gone —
        # pd is already bound at the top of this section)
        results = pd.DataFrame({
            "text": texts,
            "label": labels,
            "confidence": [round(float(p), 4) for p in probs]
        })
        print("\n" + "="*60)
        print(results.to_string(index=False))
        print("="*60)
        # Optional save; an empty answer (the default) skips it.
        out = args.output or ask("Save results to CSV? (leave blank to skip)", "", is_path=True)
        if out:
            results.to_csv(out, index=False)
            print(f"Saved to {out}")
        sys.exit(0)
    else:
        # Any answer other than "2" (including the default "1") means CSV.
        input_path = ask("CSV file path", is_path=True)
# ask() returns None when the user enters nothing and there is no default;
# guard it here so we exit cleanly instead of raising TypeError in exists().
if not input_path or not os.path.exists(input_path):
    print(f"File not found: {input_path}")
    sys.exit(1)
df = pd.read_csv(input_path)
print(f"\nLoaded {len(df)} rows from {input_path}")
print(f"Columns: {list(df.columns)}")
# Resolve the text column: --col first, then ask, defaulting to column 0.
text_col = args.col
if text_col not in df.columns:
    print(f"\nColumn '{text_col}' not found.")
    text_col = ask(f"Which column contains the text?", df.columns[0])
print(f"\nRunning inference on column '{text_col}' with threshold={threshold}...")
# NaNs become empty strings so every row survives tokenization.
texts = df[text_col].fillna("").astype(str).tolist()
if not texts:
    # Empty CSV: nothing to classify, and the percentage math below would
    # otherwise divide by zero.
    print("No rows to classify.")
    sys.exit(0)
probs, labels = predict(texts)
df["predicted_label"] = labels
df["confidence"] = [round(float(p), 4) for p in probs]
# Summary
hate_count = labels.count("Hate Speech")
nonhate_count = labels.count("Non-Hate")
print(f"\n{'='*60}")
print(f"Results Summary")
print(f"{'='*60}")
print(f" Total samples : {len(texts)}")
print(f" Hate Speech : {hate_count} ({hate_count/len(texts)*100:.1f}%)")
print(f" Non-Hate : {nonhate_count} ({nonhate_count/len(texts)*100:.1f}%)")
print(f" Threshold : {threshold}")
print(f"{'='*60}")
# Show sample
print(f"\nSample predictions (first 10):")
print(df[[text_col, "predicted_label", "confidence"]].head(10).to_string(index=False))
# Save: default output name derives from the input file. splitext only touches
# the final extension — the old str.replace(".csv", ...) substituted EVERY
# ".csv" occurrence in the path, and when the input had a different extension
# it silently defaulted to the input path itself (overwriting the input).
output_path = args.output
if not output_path:
    root, ext = os.path.splitext(input_path)
    default_out = f"{root}_predictions{ext or '.csv'}"
    output_path = ask(f"\nSave full results to CSV", default_out, is_path=True)
if output_path:
    df.to_csv(output_path, index=False)
    print(f"\nSaved {len(df)} predictions to {output_path}")
|