| |
| """ |
| predict.py — Interactive inference script for the SASC hate speech detection model. |
| |
| Usage: |
| python predict.py # fully interactive |
| python predict.py --model model.h5 # specify model path |
| python predict.py --input texts.csv # specify input CSV |
| python predict.py --text "some text here" # single text prediction |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import json |
|
|
| |
# TensorFlow runtime switches. They are set here, ahead of the TensorFlow
# import further down the script, so they are in place when it loads.
os.environ.update(
    {
        "TF_CPP_MIN_LOG_LEVEL": "3",
        "TF_ENABLE_ONEDNN_OPTS": "0",
    }
)
|
|
| from prompt_toolkit import prompt |
| from prompt_toolkit.completion import PathCompleter |
| from prompt_toolkit.shortcuts import prompt as pt_prompt |
|
|
# Shared completer for every path prompt issued through prompt_toolkit.
path_completer = PathCompleter(expanduser=True)

# Command-line interface. All options are optional; the script prompts
# later for the model and tokenizer paths when they are not given here.
parser = argparse.ArgumentParser(description="SASC Hate Speech Detector")

# Plain string options without a default, registered in display order.
for flag, description in (
    ("--model", "Path to .h5 model file"),
    ("--tokenizer", "Path to tokenizer.json"),
    ("--input", "Path to input CSV file"),
    ("--text", "Single text to classify"),
    ("--output", "Path to save results CSV"),
):
    parser.add_argument(flag, type=str, help=description)

parser.add_argument(
    "--threshold",
    type=float,
    default=0.5,
    help="Decision threshold (default: 0.5)",
)
parser.add_argument(
    "--col",
    type=str,
    default="text",
    help="Column name in CSV containing text (default: text)",
)

args = parser.parse_args()
|
|
|
|
| |
def ask(message, default=None, is_path=False):
    """Prompt the user for a value, falling back to ``default`` on blank input.

    Args:
        message: Prompt text; ``": "`` is appended automatically.
        default: Value returned when the user enters nothing. When truthy it
            is also shown in square brackets after the message.
        is_path: When true, the answer is read through prompt_toolkit with
            path tab-completion and a leading ``~`` is expanded.

    Returns:
        The stripped answer, or ``default`` when the answer is empty (which
        may be ``None`` or an empty string).
    """
    suffix = f" [{default}]" if default else ""
    question = f"{message}{suffix}: "

    if is_path:
        answer = pt_prompt(question, completer=path_completer)
    else:
        answer = input(question)
    answer = answer.strip() or default

    # Tilde expansion only makes sense for filesystem answers; applying it to
    # free-form input (column names, menu choices, numbers) would silently
    # rewrite a value such as "~" into the user's home directory.
    if is_path and answer:
        return os.path.expanduser(answer)
    return answer
|
|
|
|
print("\n=== SASC Hate Speech Detector ===\n")

# --- Model file: the CLI flag wins, otherwise prompt (default: model.h5). ---
model_path = args.model
if not model_path:
    model_path = ask("Model path (.h5)", "model.h5", is_path=True)

if not os.path.exists(model_path):
    print(f"Model not found: {model_path}")
    sys.exit(1)

# --- Tokenizer file: the CLI flag wins, otherwise prompt. ---
tokenizer_path = args.tokenizer
if not tokenizer_path:
    # Offer a tokenizer.json located next to the model when one exists.
    candidate = os.path.join(os.path.dirname(model_path), "tokenizer.json")
    tokenizer_path = ask(
        "Tokenizer path",
        candidate if os.path.exists(candidate) else "tokenizer.json",
        is_path=True,
    )

if not os.path.exists(tokenizer_path):
    print(f"Tokenizer not found: {tokenizer_path}")
    sys.exit(1)

# --- Decision threshold ---
# argparse always supplies a value for --threshold (its default), so a
# truthiness test cannot tell "not given" from "given"; it would only fire on
# an explicit 0. Comparing against the parser default makes the prompt appear
# in fully interactive runs, i.e. when neither --text nor --input was passed.
threshold = args.threshold
threshold_left_at_default = threshold == parser.get_default("threshold")
if threshold_left_at_default and not args.text and not args.input:
    answer = ask("Decision threshold (0.0-1.0)", "0.5")
    try:
        threshold = float(answer)
        if not 0.0 <= threshold <= 1.0:
            raise ValueError("threshold outside [0, 1]")
    except ValueError:
        # Keep the run going, but tell the user their input was not used.
        print(f"Invalid threshold '{answer}'; using 0.5")
        threshold = 0.5
|
|
print(f"\nLoading model from {model_path}")
print(f"Loading tokenizer from {tokenizer_path}")

# TensorFlow is imported only at this point, after the paths were validated.
# The warnings filter is installed first so it also covers import-time
# warnings.
import logging
import warnings

warnings.filterwarnings("ignore")
import tensorflow as tf

tf.get_logger().setLevel(logging.ERROR)

# compile=False: the model is loaded for inference only.
model = tf.keras.models.load_model(model_path, compile=False)

from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Read the tokenizer JSON with an explicit encoding so the result does not
# depend on the platform's default locale encoding.
with open(tokenizer_path, encoding="utf-8") as f:
    tokenizer = tokenizer_from_json(f.read())

print(f"Model loaded — vocab size: {len(tokenizer.word_index)}")

# Length every token sequence is padded/truncated to before prediction.
# NOTE(review): presumably this must match the length used at training time;
# confirm against the training script.
MAX_LEN = 100
|
|
def predict(texts):
    """Score a batch of texts with the loaded model.

    Uses the module-level ``tokenizer``, ``model``, ``MAX_LEN`` and
    ``threshold``.

    Args:
        texts: Iterable of raw text strings.

    Returns:
        A ``(probs, labels)`` pair. ``probs`` is the flattened model output,
        one value per text; ``labels`` is a list with ``"Hate Speech"`` where
        the value is strictly greater than ``threshold`` and ``"Non-Hate"``
        otherwise.
    """
    texts = list(texts)
    if not texts:
        # Nothing to score: return empty results of the usual types instead
        # of sending a zero-length batch to the model (Keras is expected to
        # reject empty batches).
        import numpy as np

        return np.empty(0, dtype="float32"), []

    seqs = tokenizer.texts_to_sequences(texts)
    padded = pad_sequences(seqs, maxlen=MAX_LEN)
    probs = model.predict(padded, verbose=0).flatten()
    labels = ["Hate Speech" if p > threshold else "Non-Hate" for p in probs]
    return probs, labels
|
|
|
|
| |
# Single-text mode: classify the --text argument, print the result and stop.
single_text = args.text
if single_text:
    probs, labels = predict([single_text])
    report_lines = (
        f"\nText : {single_text}",
        f"Label : {labels[0]}",
        f"Confidence: {probs[0]:.4f}",
    )
    print("\n".join(report_lines))
    sys.exit(0)
|
|
|
|
| |
import pandas as pd

# Decide where the texts come from: --input wins; otherwise ask whether to
# read a CSV or to type texts by hand.
input_path = args.input
if not input_path:
    mode = ask("Input mode — (1) CSV file (2) Type text manually", "1")

    if mode == "2":
        # Manual mode: collect lines until the user types 'done'.
        print("\nEnter texts one per line. Type 'done' when finished.\n")
        texts = []
        while True:
            try:
                line = input(" Text: ").strip()
            except EOFError:
                # Ctrl-D / closed stdin ends entry just like typing 'done'.
                break
            if line.lower() == "done":
                break
            if line:
                texts.append(line)

        if not texts:
            print("No texts entered.")
            sys.exit(0)

        probs, labels = predict(texts)
        results = pd.DataFrame({
            "text": texts,
            "label": labels,
            "confidence": [round(float(p), 4) for p in probs],
        })

        print("\n" + "=" * 60)
        print(results.to_string(index=False))
        print("=" * 60)

        # --output wins; otherwise a blank answer skips saving.
        out = args.output or ask(
            "Save results to CSV? (leave blank to skip)", "", is_path=True
        )
        if out:
            results.to_csv(out, index=False)
            print(f"Saved to {out}")
        sys.exit(0)

    else:
        # Any answer other than "2" is treated as CSV mode.
        input_path = ask("CSV file path", is_path=True)
|
|
# A blank answer to the CSV prompt leaves input_path empty/None; report that
# explicitly instead of passing None to os.path.exists (TypeError).
if not input_path:
    print("No input CSV path provided.")
    sys.exit(1)

if not os.path.exists(input_path):
    print(f"File not found: {input_path}")
    sys.exit(1)

df = pd.read_csv(input_path)
print(f"\nLoaded {len(df)} rows from {input_path}")
print(f"Columns: {list(df.columns)}")

# A header-only file has nothing to classify.
if df.empty:
    print("Input CSV contains no rows; nothing to classify.")
    sys.exit(0)

# Resolve the text column: --col (default "text"), else ask, offering the
# first column as the default.
text_col = args.col
if text_col not in df.columns:
    print(f"\nColumn '{text_col}' not found.")
    text_col = ask("Which column contains the text?", df.columns[0])
    if text_col not in df.columns:
        # Fail with a clear message rather than a KeyError on df[text_col].
        print(f"Column '{text_col}' not found.")
        sys.exit(1)

print(f"\nRunning inference on column '{text_col}' with threshold={threshold}...")

# Missing cells become empty strings so every row gets a prediction.
texts = df[text_col].fillna("").astype(str).tolist()
probs, labels = predict(texts)

df["predicted_label"] = labels
df["confidence"] = [round(float(p), 4) for p in probs]
|
|
| |
# Summary of the batch run: counts and percentage per label.
total = len(texts)
hate_count = labels.count("Hate Speech")
nonhate_count = labels.count("Non-Hate")


def _pct(count):
    """Return ``count`` as a percentage of ``total`` (0.0 for an empty batch)."""
    return count / total * 100 if total else 0.0


rule = "=" * 60
print(f"\n{rule}")
print("Results Summary")
print(rule)
print(f" Total samples : {total}")
print(f" Hate Speech : {hate_count} ({_pct(hate_count):.1f}%)")
print(f" Non-Hate : {nonhate_count} ({_pct(nonhate_count):.1f}%)")
print(f" Threshold : {threshold}")
print(rule)

# Preview of the first rows with their predictions.
print("\nSample predictions (first 10):")
print(df[[text_col, "predicted_label", "confidence"]].head(10).to_string(index=False))
|
|
| |
# Save the annotated DataFrame: --output wins, otherwise prompt with a
# default derived from the input file name.
output_path = args.output
if not output_path:
    # Build "<name>_predictions.csv" from the input path without its
    # extension. A plain str.replace(".csv", ...) leaves names such as
    # "data.CSV" or "data.tsv" unchanged, which would make the default
    # output path equal to the input and overwrite it.
    stem, _ = os.path.splitext(input_path)
    default_out = f"{stem}_predictions.csv"
    output_path = ask("\nSave full results to CSV", default_out, is_path=True)

if output_path:
    df.to_csv(output_path, index=False)
    print(f"\nSaved {len(df)} predictions to {output_path}")
|
|