#!/usr/bin/env python3
"""
predict.py — Interactive inference script for the SASC hate speech detection model.

Usage:
    python predict.py                          # fully interactive
    python predict.py --model model.h5         # specify model path
    python predict.py --input texts.csv        # specify input CSV
    python predict.py --text "some text here"  # single text prediction
"""

import os
import sys
import argparse

# suppress TF logs — must be set in the environment BEFORE tensorflow is imported
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.shortcuts import prompt as pt_prompt

path_completer = PathCompleter(expanduser=True)

# ── Argument parsing ────────────────────────────────────────────────────────
parser = argparse.ArgumentParser(description="SASC Hate Speech Detector")
parser.add_argument("--model", type=str, help="Path to .h5 model file")
parser.add_argument("--tokenizer", type=str, help="Path to tokenizer.json")
parser.add_argument("--input", type=str, help="Path to input CSV file")
parser.add_argument("--text", type=str, help="Single text to classify")
parser.add_argument("--output", type=str, help="Path to save results CSV")
# default=None (not 0.5) so we can tell whether the user actually supplied a
# value: with default=0.5 the "prompt for threshold" branch below could never
# fire (0.5 is truthy), and an explicit --threshold 0.0 was wrongly treated
# as "not given" and overridden. The effective default is still 0.5.
parser.add_argument("--threshold", type=float, default=None,
                    help="Decision threshold (default: 0.5)")
parser.add_argument("--col", type=str, default="text",
                    help="Column name in CSV containing text (default: text)")
args = parser.parse_args()


# ── Interactive prompts if args not provided ─────────────────────────────────
def ask(message, default=None, is_path=False):
    """Prompt the user and return the stripped answer, or *default* if blank.

    When is_path=True the prompt gets filesystem tab-completion and the
    returned value has ``~`` expanded.
    """
    suffix = f" [{default}]" if default else ""
    if is_path:
        val = pt_prompt(f"{message}{suffix}: ", completer=path_completer).strip()
    else:
        val = input(f"{message}{suffix}: ").strip()
    val = val if val else default
    return os.path.expanduser(val) if val else val


print("\n=== SASC Hate Speech Detector ===\n")

# Model path
model_path = args.model
if not model_path:
    model_path = ask("Model path (.h5)", "model.h5", is_path=True)
if not os.path.exists(model_path):
    print(f"Model not found: {model_path}")
    sys.exit(1)

# Tokenizer path
tokenizer_path = args.tokenizer
if not tokenizer_path:
    # look next to model file first
    candidate = os.path.join(os.path.dirname(model_path), "tokenizer.json")
    tokenizer_path = ask(
        "Tokenizer path",
        candidate if os.path.exists(candidate) else "tokenizer.json",
        is_path=True,
    )
if not os.path.exists(tokenizer_path):
    print(f"Tokenizer not found: {tokenizer_path}")
    sys.exit(1)

# Threshold: ask interactively only when it was not given on the CLI and we
# are in fully interactive mode (no --text / --input shortcut).
threshold = args.threshold if args.threshold is not None else 0.5
if args.threshold is None and not args.text and not args.input:
    t = ask("Decision threshold (0.0-1.0)", "0.5")
    try:
        threshold = float(t)
    except ValueError:
        threshold = 0.5

print(f"\nLoading model from {model_path}")
print(f"Loading tokenizer from {tokenizer_path}")

# TensorFlow is imported late so --help / path errors stay fast and quiet.
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
import logging
tf.get_logger().setLevel(logging.ERROR)

model = tf.keras.models.load_model(model_path, compile=False)

from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences

with open(tokenizer_path) as f:
    tokenizer = tokenizer_from_json(f.read())

print(f"Model loaded — vocab size: {len(tokenizer.word_index)}")

# NOTE(review): must match the sequence length the model was trained with.
MAX_LEN = 100


def predict(texts):
    """Tokenize, pad and classify *texts*; return (probabilities, labels)."""
    seqs = tokenizer.texts_to_sequences(texts)
    padded = pad_sequences(seqs, maxlen=MAX_LEN)
    probs = model.predict(padded, verbose=0).flatten()
    labels = ["Hate Speech" if p > threshold else "Non-Hate" for p in probs]
    return probs, labels


# ── Single text mode ──────────────────────────────────────────────────────────
if args.text:
    probs, labels = predict([args.text])
    print(f"\nText : {args.text}")
    print(f"Label : {labels[0]}")
    print(f"Confidence: {probs[0]:.4f}")
    sys.exit(0)

# ── CSV mode ──────────────────────────────────────────────────────────────────
import pandas as pd

input_path = args.input
if not input_path:
    mode = ask("Input mode — (1) CSV file (2) Type text manually", "1")
    if mode == "2":
        # manual text entry loop
        print("\nEnter texts one per line. Type 'done' when finished.\n")
        texts = []
        while True:
            t = input(" Text: ").strip()
            if t.lower() == "done":
                break
            if t:
                texts.append(t)
        if not texts:
            print("No texts entered.")
            sys.exit(0)
        probs, labels = predict(texts)
        results = pd.DataFrame({
            "text": texts,
            "label": labels,
            "confidence": [round(float(p), 4) for p in probs],
        })
        print("\n" + "=" * 60)
        print(results.to_string(index=False))
        print("=" * 60)
        out = args.output or ask("Save results to CSV? (leave blank to skip)", "", is_path=True)
        if out:
            results.to_csv(out, index=False)
            print(f"Saved to {out}")
        sys.exit(0)
    else:
        input_path = ask("CSV file path", is_path=True)

if not os.path.exists(input_path):
    print(f"File not found: {input_path}")
    sys.exit(1)

df = pd.read_csv(input_path)
print(f"\nLoaded {len(df)} rows from {input_path}")
print(f"Columns: {list(df.columns)}")

text_col = args.col
if text_col not in df.columns:
    print(f"\nColumn '{text_col}' not found.")
    text_col = ask(f"Which column contains the text?", df.columns[0])

print(f"\nRunning inference on column '{text_col}' with threshold={threshold}...")
texts = df[text_col].fillna("").astype(str).tolist()
# Guard: an empty CSV would make the percentage summary divide by zero.
if not texts:
    print("Input CSV contains no rows — nothing to classify.")
    sys.exit(0)
probs, labels = predict(texts)
df["predicted_label"] = labels
df["confidence"] = [round(float(p), 4) for p in probs]

# Summary
hate_count = labels.count("Hate Speech")
nonhate_count = labels.count("Non-Hate")
print(f"\n{'='*60}")
print(f"Results Summary")
print(f"{'='*60}")
print(f" Total samples : {len(texts)}")
print(f" Hate Speech : {hate_count} ({hate_count/len(texts)*100:.1f}%)")
print(f" Non-Hate : {nonhate_count} ({nonhate_count/len(texts)*100:.1f}%)")
print(f" Threshold : {threshold}")
print(f"{'='*60}")

# Show sample
print(f"\nSample predictions (first 10):")
print(df[[text_col, "predicted_label", "confidence"]].head(10).to_string(index=False))

# Save
output_path = args.output
if not output_path:
    # splitext is robust when the input file name doesn't end in ".csv"
    # (str.replace would have left the path unchanged and defaulted to
    # overwriting the input file).
    base, _ext = os.path.splitext(input_path)
    default_out = base + "_predictions.csv"
    output_path = ask(f"\nSave full results to CSV", default_out, is_path=True)
if output_path:
    df.to_csv(output_path, index=False)
    print(f"\nSaved {len(df)} predictions to {output_path}")