File size: 7,218 Bytes
7e5f759
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#!/usr/bin/env python3
"""
predict.py — Interactive inference script for the SASC hate speech detection model.

Usage:
    python predict.py                          # fully interactive
    python predict.py --model model.h5         # specify model path
    python predict.py --input texts.csv        # specify input CSV
    python predict.py --text "some text here"  # single text prediction
"""

import os
import sys
import argparse
import json
# NOTE(review): `json` and the bare `prompt` import below appear unused
# (pt_prompt is the alias actually called) — candidates for removal.

# suppress TF logs — these env vars must be set BEFORE tensorflow is
# imported (the import happens later, after paths are resolved)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

from prompt_toolkit import prompt
from prompt_toolkit.completion import PathCompleter
from prompt_toolkit.shortcuts import prompt as pt_prompt

# Shared tab-completing path completer for every file-path prompt (~ expands).
path_completer = PathCompleter(expanduser=True)


# ── Argument parsing ────────────────────────────────────────────────────────
# Every flag is optional; anything missing is gathered interactively below.
parser = argparse.ArgumentParser(description="SASC Hate Speech Detector")
parser.add_argument("--model", type=str, help="Path to .h5 model file")
parser.add_argument("--tokenizer", type=str, help="Path to tokenizer.json")
parser.add_argument("--input", type=str, help="Path to input CSV file")
parser.add_argument("--text", type=str, help="Single text to classify")
parser.add_argument("--output", type=str, help="Path to save results CSV")
parser.add_argument(
    "--threshold", type=float, default=0.5,
    help="Decision threshold (default: 0.5)",
)
parser.add_argument(
    "--col", type=str, default="text",
    help="Column name in CSV containing text (default: text)",
)
args = parser.parse_args()


# ── Interactive prompts if args not provided ─────────────────────────────────
def ask(message, default=None, is_path=False):
    """Prompt the user for a value, falling back to *default* on empty input.

    When is_path is True, a prompt_toolkit prompt with path tab-completion is
    used; otherwise plain input(). A non-empty result is ~-expanded.
    """
    hint = f" [{default}]" if default else ""
    if is_path:
        answer = pt_prompt(f"{message}{hint}: ", completer=path_completer).strip()
    else:
        answer = input(f"{message}{hint}: ").strip()
    if not answer:
        answer = default
    return os.path.expanduser(answer) if answer else answer


print("\n=== SASC Hate Speech Detector ===\n")

# Resolve the model path: CLI flag first, interactive prompt as fallback.
model_path = args.model or ask("Model path (.h5)", "model.h5", is_path=True)

if not os.path.exists(model_path):
    print(f"Model not found: {model_path}")
    sys.exit(1)

# Resolve the tokenizer path; prefer a tokenizer.json next to the model file.
tokenizer_path = args.tokenizer
if not tokenizer_path:
    sibling = os.path.join(os.path.dirname(model_path), "tokenizer.json")
    suggested = sibling if os.path.exists(sibling) else "tokenizer.json"
    tokenizer_path = ask("Tokenizer path", suggested, is_path=True)

if not os.path.exists(tokenizer_path):
    print(f"Tokenizer not found: {tokenizer_path}")
    sys.exit(1)

# Threshold
# BUG FIX: the original guard `if not args.threshold` could never fire for the
# argparse default of 0.5 (truthy), so the interactive prompt was dead code —
# and it *wrongly* fired for an explicit `--threshold 0.0`. Detect whether the
# flag was actually supplied on the command line instead.
threshold = args.threshold
_threshold_given = any(
    a == "--threshold" or a.startswith("--threshold=") for a in sys.argv[1:]
)
if not _threshold_given and not args.text and not args.input:
    t = ask("Decision threshold (0.0-1.0)", "0.5")
    try:
        threshold = float(t)
    except ValueError:
        # Unparsable input: fall back to the documented default.
        threshold = 0.5

print(f"\nLoading model from   {model_path}")
print(f"Loading tokenizer from {tokenizer_path}")
# Silence Python warnings and TF logging BEFORE importing tensorflow so the
# import itself doesn't spam the console (order of these lines matters).
import warnings
warnings.filterwarnings("ignore")
import tensorflow as tf
import logging
tf.get_logger().setLevel(logging.ERROR)

# compile=False: inference only — no optimizer/loss state is needed.
model = tf.keras.models.load_model(model_path, compile=False)

from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences
# tokenizer.json is presumably the Keras Tokenizer serialized at training
# time via tokenizer.to_json() — TODO confirm against the training script.
with open(tokenizer_path) as f:
    tokenizer = tokenizer_from_json(f.read())

print(f"Model loaded — vocab size: {len(tokenizer.word_index)}")

# Fixed sequence length for padding/truncation.
# NOTE(review): assumed to match the length used at training time — confirm.
MAX_LEN = 100

def predict(texts):
    """Tokenize, pad and classify *texts*.

    Returns a (probabilities, labels) pair: a flat array of sigmoid scores and
    the corresponding string labels under the current decision threshold.
    """
    sequences = tokenizer.texts_to_sequences(texts)
    batch = pad_sequences(sequences, maxlen=MAX_LEN)
    probs = model.predict(batch, verbose=0).flatten()
    labels = []
    for score in probs:
        labels.append("Hate Speech" if score > threshold else "Non-Hate")
    return probs, labels


# ── Single text mode ──────────────────────────────────────────────────────────
# A --text argument short-circuits everything else: classify it and exit.
if args.text:
    scores, verdicts = predict([args.text])
    print(f"\nText    : {args.text}")
    print(f"Label   : {verdicts[0]}")
    print(f"Confidence: {scores[0]:.4f}")
    sys.exit(0)


# ── CSV mode ──────────────────────────────────────────────────────────────────
import pandas as pd

input_path = args.input
if not input_path:
    mode = ask("Input mode — (1) CSV file  (2) Type text manually", "1")

    if mode == "2":
        # Manual text entry: collect non-empty lines until the user types 'done'.
        print("\nEnter texts one per line. Type 'done' when finished.\n")
        texts = []
        while True:
            t = input("  Text: ").strip()
            if t.lower() == "done":
                break
            if t:
                texts.append(t)

        if not texts:
            print("No texts entered.")
            sys.exit(0)

        probs, labels = predict(texts)
        # FIX: removed the redundant `import pandas as pd` that shadowed the
        # module-level import above.
        results = pd.DataFrame({
            "text":       texts,
            "label":      labels,
            "confidence": [round(float(p), 4) for p in probs]
        })

        print("\n" + "="*60)
        print(results.to_string(index=False))
        print("="*60)

        out = args.output or ask("Save results to CSV? (leave blank to skip)", "", is_path=True)
        if out:
            results.to_csv(out, index=False)
            print(f"Saved to {out}")
        sys.exit(0)

    else:
        # Any answer other than "2" falls through to CSV-file mode.
        input_path = ask("CSV file path", is_path=True)

if not os.path.exists(input_path):
    print(f"File not found: {input_path}")
    sys.exit(1)

df = pd.read_csv(input_path)
print(f"\nLoaded {len(df)} rows from {input_path}")
print(f"Columns: {list(df.columns)}")

# FIX: an empty CSV previously crashed with ZeroDivisionError in the
# percentage summary below; exit cleanly instead.
if df.empty:
    print("Input file contains no rows — nothing to do.")
    sys.exit(0)

text_col = args.col
if text_col not in df.columns:
    print(f"\nColumn '{text_col}' not found.")
    text_col = ask(f"Which column contains the text?", df.columns[0])
    # FIX: a second bad answer previously raised a raw KeyError at the
    # df[text_col] lookup; fail with a clear message instead.
    if text_col not in df.columns:
        print(f"Column '{text_col}' not found. Available: {list(df.columns)}")
        sys.exit(1)

print(f"\nRunning inference on column '{text_col}' with threshold={threshold}...")

# NaN cells become empty strings so tokenization never sees floats.
texts = df[text_col].fillna("").astype(str).tolist()
probs, labels = predict(texts)

df["predicted_label"] = labels
df["confidence"]      = [round(float(p), 4) for p in probs]

# Summary
hate_count    = labels.count("Hate Speech")
nonhate_count = labels.count("Non-Hate")
print(f"\n{'='*60}")
print(f"Results Summary")
print(f"{'='*60}")
print(f"  Total samples : {len(texts)}")
print(f"  Hate Speech   : {hate_count} ({hate_count/len(texts)*100:.1f}%)")
print(f"  Non-Hate      : {nonhate_count} ({nonhate_count/len(texts)*100:.1f}%)")
print(f"  Threshold     : {threshold}")
print(f"{'='*60}")

# Show sample
print(f"\nSample predictions (first 10):")
print(df[[text_col, "predicted_label", "confidence"]].head(10).to_string(index=False))

# Save
output_path = args.output
if not output_path:
    # FIX: str.replace(".csv", ...) substituted the FIRST ".csv" anywhere in
    # the path (corrupting e.g. "a.csv.bak/b.csv") and, for inputs without a
    # .csv extension, left the path unchanged — defaulting to overwriting the
    # input file. Split the extension off properly instead.
    root, _ext = os.path.splitext(input_path)
    default_out = f"{root}_predictions.csv"
    output_path = ask(f"\nSave full results to CSV", default_out, is_path=True)

if output_path:
    df.to_csv(output_path, index=False)
    print(f"\nSaved {len(df)} predictions to {output_path}")