# SASC / predict.py
# NOTE(review): the lines below are Hugging Face repo-page residue that was
# pasted above the shebang and broke the file as Python; kept here commented:
#   tuklu's picture
#   Add README, tokenizer, results
#   46da6a8 verified
#!/usr/bin/env python3
"""
predict.py β€” Interactive inference script for the SASC hate speech detection model.
Usage:
python predict.py # fully interactive
python predict.py --model model.h5 # specify model path
python predict.py --input texts.csv # specify input CSV
python predict.py --text "some text here" # single text prediction
"""
import os
import sys
import argparse
import json
# suppress TF logs
# NOTE: these environment variables must be set BEFORE TensorFlow is imported
# (the actual `import tensorflow` happens later, after the interactive prompts).
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
# prompt_toolkit gives the path prompts tab-completion.
# NOTE(review): `prompt` and `json` are imported but unused in this file.
from prompt_toolkit import prompt
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.shortcuts import prompt as pt_prompt
# Shared completer for every path question; expanduser handles "~/...".
path_completer = PathCompleter(expanduser=True)
# ── Argument parsing ────────────────────────────────────────────────────────
# Every flag is optional; anything not supplied on the command line is asked
# for interactively further down.
parser = argparse.ArgumentParser(description="SASC Hate Speech Detector")
parser.add_argument("--model", type=str, help="Path to .h5 model file")
parser.add_argument("--tokenizer", type=str, help="Path to tokenizer.json")
parser.add_argument("--input", type=str, help="Path to input CSV file")
parser.add_argument("--text", type=str, help="Single text to classify")
parser.add_argument("--output", type=str, help="Path to save results CSV")
parser.add_argument("--threshold", type=float, default=0.5, help="Decision threshold (default: 0.5)")
parser.add_argument("--col", type=str, default="text", help="Column name in CSV containing text (default: text)")
args = parser.parse_args()
# ── Interactive prompts if args not provided ─────────────────────────────────
def ask(message, default=None, is_path=False):
    """Prompt the user and return the answer, falling back to *default*.

    Path questions (``is_path=True``) get prompt_toolkit tab-completion.
    The returned value has ``~`` expanded; an empty answer with no default
    comes back unchanged (``None`` or ``""``).
    """
    hint = f" [{default}]" if default else ""
    question = f"{message}{hint}: "
    if is_path:
        answer = pt_prompt(question, completer=path_completer).strip()
    else:
        answer = input(question).strip()
    if not answer:
        answer = default
    if not answer:
        return answer
    return os.path.expanduser(answer)
print("\n=== SASC Hate Speech Detector ===\n")

# Model path: CLI flag wins, otherwise ask (empty/None args.model is falsy).
model_path = args.model or ask("Model path (.h5)", "model.h5", is_path=True)
if not os.path.exists(model_path):
    print(f"Model not found: {model_path}")
    sys.exit(1)

# Tokenizer path: prefer a tokenizer.json sitting next to the model file
# as the suggested default when the flag was not given.
tokenizer_path = args.tokenizer
if not tokenizer_path:
    sibling = os.path.join(os.path.dirname(model_path), "tokenizer.json")
    suggestion = sibling if os.path.exists(sibling) else "tokenizer.json"
    tokenizer_path = ask("Tokenizer path", suggestion, is_path=True)
if not os.path.exists(tokenizer_path):
    print(f"Tokenizer not found: {tokenizer_path}")
    sys.exit(1)
# Threshold
# BUG FIX: the old guard was `if not args.threshold ...`. Because the
# argparse default is 0.5 (truthy), the interactive prompt could never fire;
# conversely an explicit `--threshold 0.0` (falsy) wrongly re-prompted.
# Check whether the flag was actually passed on the command line instead.
threshold = args.threshold
_threshold_given = any(
    a == "--threshold" or a.startswith("--threshold=") for a in sys.argv[1:]
)
if not _threshold_given and not args.text and not args.input:
    t = ask("Decision threshold (0.0-1.0)", "0.5")
    try:
        threshold = float(t)
    except ValueError:
        # Unparseable answer: silently fall back to the documented default.
        threshold = 0.5
print(f"\nLoading model from {model_path}")
print(f"Loading tokenizer from {tokenizer_path}")
# TensorFlow is imported late on purpose: the slow import only happens after
# the interactive prompts, and the silencing below must precede model loading.
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
import logging
tf.get_logger().setLevel(logging.ERROR)
# compile=False: inference only β€” no optimizer/loss state is needed, and it
# avoids failures when the training-time loss/metrics are not importable.
model = tf.keras.models.load_model(model_path, compile=False)
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
with open(tokenizer_path) as f:
    tokenizer = tokenizer_from_json(f.read())
print(f"Model loaded β€” vocab size: {len(tokenizer.word_index)}")
# Padding length fed to pad_sequences; presumably the sequence length the
# model was trained with β€” confirm against the training config.
MAX_LEN = 100
def predict(texts):
    """Classify *texts* and return ``(probabilities, labels)``.

    Each text is tokenized with the module-level tokenizer, padded to
    ``MAX_LEN``, and scored by the model. A probability strictly greater
    than the module-level ``threshold`` is labelled "Hate Speech",
    otherwise "Non-Hate".
    """
    padded = pad_sequences(tokenizer.texts_to_sequences(texts), maxlen=MAX_LEN)
    probabilities = model.predict(padded, verbose=0).flatten()
    verdicts = [
        "Hate Speech" if score > threshold else "Non-Hate"
        for score in probabilities
    ]
    return probabilities, verdicts
# ── Single text mode ──────────────────────────────────────────────────────────
if args.text:
    # One-shot classification of the --text argument, then exit.
    scores, tags = predict([args.text])
    print(
        f"\nText : {args.text}"
        f"\nLabel : {tags[0]}"
        f"\nConfidence: {scores[0]:.4f}"
    )
    sys.exit(0)
# ── CSV mode ──────────────────────────────────────────────────────────────────
import pandas as pd
input_path = args.input
if not input_path:
    mode = ask("Input mode β€” (1) CSV file (2) Type text manually", "1")
    if mode == "2":
        # Manual entry loop: collect non-empty lines until the sentinel "done".
        print("\nEnter texts one per line. Type 'done' when finished.\n")
        texts = []
        while True:
            t = input(" Text: ").strip()
            if t.lower() == "done":
                break
            if t:
                texts.append(t)
        if not texts:
            print("No texts entered.")
            sys.exit(0)
        probs, labels = predict(texts)
        # FIX: removed a redundant `import pandas as pd` that shadowed the
        # import at the top of this section.
        results = pd.DataFrame({
            "text": texts,
            "label": labels,
            "confidence": [round(float(p), 4) for p in probs]
        })
        print("\n" + "="*60)
        print(results.to_string(index=False))
        print("="*60)
        # Optional save; an empty answer (falsy) skips it.
        out = args.output or ask("Save results to CSV? (leave blank to skip)", "", is_path=True)
        if out:
            results.to_csv(out, index=False)
            print(f"Saved to {out}")
        sys.exit(0)
    else:
        # Any answer other than "2" (including a bare Enter) means CSV mode.
        # BUG FIX: ask() returns None on empty input when no default is
        # given, and os.path.exists(None) raises TypeError; coerce to ""
        # so the not-found branch below handles it gracefully.
        input_path = ask("CSV file path", is_path=True) or ""
if not os.path.exists(input_path):
    print(f"File not found: {input_path}")
    sys.exit(1)
# Batch inference over one text column of the CSV, with a printed summary.
df = pd.read_csv(input_path)
print(f"\nLoaded {len(df)} rows from {input_path}")
print(f"Columns: {list(df.columns)}")
text_col = args.col
if text_col not in df.columns:
    # Fall back to asking, suggesting the first column as the default.
    print(f"\nColumn '{text_col}' not found.")
    text_col = ask(f"Which column contains the text?", df.columns[0])
print(f"\nRunning inference on column '{text_col}' with threshold={threshold}...")
texts = df[text_col].fillna("").astype(str).tolist()
# BUG FIX: an empty CSV used to crash with ZeroDivisionError in the
# percentage lines of the summary below; bail out early instead.
if not texts:
    print("No rows to classify.")
    sys.exit(0)
probs, labels = predict(texts)
df["predicted_label"] = labels
df["confidence"] = [round(float(p), 4) for p in probs]
# Summary
hate_count = labels.count("Hate Speech")
nonhate_count = labels.count("Non-Hate")
print(f"\n{'='*60}")
print(f"Results Summary")
print(f"{'='*60}")
print(f" Total samples : {len(texts)}")
print(f" Hate Speech : {hate_count} ({hate_count/len(texts)*100:.1f}%)")
print(f" Non-Hate : {nonhate_count} ({nonhate_count/len(texts)*100:.1f}%)")
print(f" Threshold : {threshold}")
print(f"{'='*60}")
# Show sample
print(f"\nSample predictions (first 10):")
print(df[[text_col, "predicted_label", "confidence"]].head(10).to_string(index=False))
# Save
output_path = args.output
if not output_path:
    # BUG FIX: the old `input_path.replace(".csv", "_predictions.csv")`
    # rewrote the FIRST ".csv" occurring anywhere in the path, and when the
    # extension differed it returned input_path unchanged (risking an
    # overwrite of the input on save). Derive the name from the extension.
    root, _ext = os.path.splitext(input_path)
    default_out = f"{root}_predictions.csv"
    output_path = ask(f"\nSave full results to CSV", default_out, is_path=True)
if output_path:
    df.to_csv(output_path, index=False)
    print(f"\nSaved {len(df)} predictions to {output_path}")