tuklu committed on
Commit
46da6a8
·
verified ·
1 Parent(s): 3fb7122

Add README, tokenizer, results

Browse files
Files changed (2) hide show
  1. predict.py +32 -11
  2. tokenizer.json +0 -0
predict.py CHANGED
@@ -14,6 +14,16 @@ import sys
14
  import argparse
15
  import json
16
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # ── Argument parsing ────────────────────────────────────────────────────────
19
  parser = argparse.ArgumentParser(description="SASC Hate Speech Detector")
@@ -28,10 +38,14 @@ args = parser.parse_args()
28
 
29
 
30
  # ── Interactive prompts if args not provided ─────────────────────────────────
31
- def ask(prompt, default=None):
32
  suffix = f" [{default}]" if default else ""
33
- val = input(f"{prompt}{suffix}: ").strip()
34
- return val if val else default
 
 
 
 
35
 
36
 
37
  print("\n=== SASC Hate Speech Detector ===\n")
@@ -39,7 +53,7 @@ print("\n=== SASC Hate Speech Detector ===\n")
39
  # Model path
40
  model_path = args.model
41
  if not model_path:
42
- model_path = ask("Model path (.h5)", "model.h5")
43
 
44
  if not os.path.exists(model_path):
45
  print(f"Model not found: {model_path}")
@@ -50,7 +64,7 @@ tokenizer_path = args.tokenizer
50
  if not tokenizer_path:
51
  # look next to model file first
52
  candidate = os.path.join(os.path.dirname(model_path), "tokenizer.json")
53
- tokenizer_path = ask("Tokenizer path", candidate if os.path.exists(candidate) else "tokenizer.json")
54
 
55
  if not os.path.exists(tokenizer_path):
56
  print(f"Tokenizer not found: {tokenizer_path}")
@@ -65,16 +79,23 @@ if not args.threshold and not args.text and not args.input:
65
  except ValueError:
66
  threshold = 0.5
67
 
68
- print(f"\nLoading model from {model_path}...")
 
 
 
69
  import tensorflow as tf
70
- model = tf.keras.models.load_model(model_path)
 
 
 
71
 
72
- print(f"Loading tokenizer from {tokenizer_path}...")
73
  from tensorflow.keras.preprocessing.text import tokenizer_from_json
74
  from tensorflow.keras.preprocessing.sequence import pad_sequences
75
  with open(tokenizer_path) as f:
76
  tokenizer = tokenizer_from_json(f.read())
77
 
 
 
78
  MAX_LEN = 100
79
 
80
  def predict(texts):
@@ -128,14 +149,14 @@ if not input_path:
128
  print(results.to_string(index=False))
129
  print("="*60)
130
 
131
- out = args.output or ask("Save results to CSV? (leave blank to skip)", "")
132
  if out:
133
  results.to_csv(out, index=False)
134
  print(f"Saved to {out}")
135
  sys.exit(0)
136
 
137
  else:
138
- input_path = ask("CSV file path")
139
 
140
  if not os.path.exists(input_path):
141
  print(f"File not found: {input_path}")
@@ -178,7 +199,7 @@ print(df[[text_col, "predicted_label", "confidence"]].head(10).to_string(index=F
178
  output_path = args.output
179
  if not output_path:
180
  default_out = input_path.replace(".csv", "_predictions.csv")
181
- output_path = ask(f"\nSave full results to CSV", default_out)
182
 
183
  if output_path:
184
  df.to_csv(output_path, index=False)
 
14
  import argparse
15
  import json
16
 
17
+ # suppress TF logs
18
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
19
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
20
+
21
+ from prompt_toolkit import prompt
22
+ from prompt_toolkit.completion import PathCompleter
23
+ from prompt_toolkit.shortcuts import prompt as pt_prompt
24
+
25
+ path_completer = PathCompleter(expanduser=True)
26
+
27
 
28
  # ── Argument parsing ────────────────────────────────────────────────────────
29
  parser = argparse.ArgumentParser(description="SASC Hate Speech Detector")
 
38
 
39
 
40
  # ── Interactive prompts if args not provided ─────────────────────────────────
41
+ def ask(message, default=None, is_path=False):
42
  suffix = f" [{default}]" if default else ""
43
+ if is_path:
44
+ val = pt_prompt(f"{message}{suffix}: ", completer=path_completer).strip()
45
+ else:
46
+ val = input(f"{message}{suffix}: ").strip()
47
+ val = val if val else default
48
+ return os.path.expanduser(val) if val else val
49
 
50
 
51
  print("\n=== SASC Hate Speech Detector ===\n")
 
53
  # Model path
54
  model_path = args.model
55
  if not model_path:
56
+ model_path = ask("Model path (.h5)", "model.h5", is_path=True)
57
 
58
  if not os.path.exists(model_path):
59
  print(f"Model not found: {model_path}")
 
64
  if not tokenizer_path:
65
  # look next to model file first
66
  candidate = os.path.join(os.path.dirname(model_path), "tokenizer.json")
67
+ tokenizer_path = ask("Tokenizer path", candidate if os.path.exists(candidate) else "tokenizer.json", is_path=True)
68
 
69
  if not os.path.exists(tokenizer_path):
70
  print(f"Tokenizer not found: {tokenizer_path}")
 
79
  except ValueError:
80
  threshold = 0.5
81
 
82
+ print(f"\nLoading model from {model_path}")
83
+ print(f"Loading tokenizer from {tokenizer_path}")
84
+ import warnings
85
+ warnings.filterwarnings("ignore")
86
  import tensorflow as tf
87
+ import logging
88
+ tf.get_logger().setLevel(logging.ERROR)
89
+
90
+ model = tf.keras.models.load_model(model_path, compile=False)
91
 
 
92
  from tensorflow.keras.preprocessing.text import tokenizer_from_json
93
  from tensorflow.keras.preprocessing.sequence import pad_sequences
94
  with open(tokenizer_path) as f:
95
  tokenizer = tokenizer_from_json(f.read())
96
 
97
+ print(f"Model loaded — vocab size: {len(tokenizer.word_index)}")
98
+
99
  MAX_LEN = 100
100
 
101
  def predict(texts):
 
149
  print(results.to_string(index=False))
150
  print("="*60)
151
 
152
+ out = args.output or ask("Save results to CSV? (leave blank to skip)", "", is_path=True)
153
  if out:
154
  results.to_csv(out, index=False)
155
  print(f"Saved to {out}")
156
  sys.exit(0)
157
 
158
  else:
159
+ input_path = ask("CSV file path", is_path=True)
160
 
161
  if not os.path.exists(input_path):
162
  print(f"File not found: {input_path}")
 
199
  output_path = args.output
200
  if not output_path:
201
  default_out = input_path.replace(".csv", "_predictions.csv")
202
+ output_path = ask(f"\nSave full results to CSV", default_out, is_path=True)
203
 
204
  if output_path:
205
  df.to_csv(output_path, index=False)
tokenizer.json CHANGED
The diff for this file is too large to render. See raw diff