Perth0603 committed on
Commit
6823e29
·
verified ·
1 Parent(s): 3a83600

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -46
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  from typing import List, Optional, Dict
3
  import re
 
4
 
5
  import torch
6
  import nltk
@@ -30,7 +31,7 @@ app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.
30
 
31
 
32
  # ============================================================================
33
- # TEXT PREPROCESSING CLASS (FOR ANALYSIS ONLY, NOT FOR MODEL INPUT)
34
  # ============================================================================
35
  class TextPreprocessor:
36
  """NLP preprocessing for analysis and feature extraction"""
@@ -78,7 +79,7 @@ class TextPreprocessor:
78
  }
79
 
80
  def preprocess(self, text: str) -> Dict:
81
- """Preprocessing for analysis (NOT for model)"""
82
  tokens = self.tokenize(text)
83
  tokens_no_stop = self.remove_stopwords(tokens)
84
  stemmed = self.stem(tokens_no_stop)
@@ -125,29 +126,79 @@ _tokenizer = None
125
  _model = None
126
  _device = "cpu"
127
  _preprocessor = None
128
- _NORM_LABELS_BY_IDX = None
129
 
130
 
131
  # ============================================================================
132
  # HELPER FUNCTIONS
133
  # ============================================================================
134
- def _normalize_label_text_only(txt: str) -> str:
135
- """Normalize model label text"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  t = (str(txt) if txt is not None else "").strip().upper()
137
- if t in ("PHISHING", "PHISH", "SPAM"):
138
  return "PHISH"
139
- if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM"):
140
  return "LEGIT"
141
  return t
142
 
143
 
144
  def _load_model():
145
  """Load model, tokenizer, and preprocessor"""
146
- global _tokenizer, _model, _device, _NORM_LABELS_BY_IDX, _preprocessor
147
 
148
  if _tokenizer is None or _model is None:
149
  _device = "cuda" if torch.cuda.is_available() else "cpu"
 
150
  print(f"Loading model on device: {_device}")
 
 
151
 
152
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
153
  _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
@@ -155,6 +206,9 @@ def _load_model():
155
  _model.eval()
156
  _preprocessor = TextPreprocessor()
157
 
 
 
 
158
  # Warm-up
159
  with torch.no_grad():
160
  _ = _model(
@@ -162,36 +216,28 @@ def _load_model():
162
  .to(_device)
163
  ).logits
164
 
165
- # Read and normalize model labels
166
- id2label = getattr(_model.config, "id2label", {}) or {}
167
  num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
168
- _NORM_LABELS_BY_IDX = [_normalize_label_text_only(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
169
-
170
- print(f"Model loaded successfully")
171
- print(f"ID2Label: {id2label}")
172
- print(f"Normalized labels: {_NORM_LABELS_BY_IDX}")
173
 
174
 
175
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
176
  """
177
- Predict using ORIGINAL text (NO cleaning).
178
- Preprocessing is for ANALYSIS only, not for model input.
179
  """
180
  _load_model()
181
  if not texts:
182
  return []
183
 
184
- # IMPORTANT: Use original text for model, NOT cleaned text!
185
- model_inputs = texts
186
-
187
- # Get preprocessing info for analysis
188
  preprocessing_info = None
189
  if include_preprocessing:
190
  preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
191
 
192
- # Tokenize batch for model
193
  enc = _tokenizer(
194
- model_inputs,
195
  return_tensors="pt",
196
  padding=True,
197
  truncation=True,
@@ -199,36 +245,44 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
199
  )
200
  enc = {k: v.to(_device) for k, v in enc.items()}
201
 
202
- # Get predictions
203
  with torch.no_grad():
204
  logits = _model(**enc).logits
205
  probs = torch.softmax(logits, dim=-1)
206
 
207
- id2label = getattr(_model.config, "id2label", None) or {}
208
- labels_by_idx_raw = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
209
- labels_by_idx_norm = [_normalize_label_text_only(lbl) for lbl in labels_by_idx_raw]
 
 
 
 
 
210
 
211
  outputs: List[Dict] = []
212
  for i in range(probs.shape[0]):
213
  p = probs[i]
214
  idx = int(torch.argmax(p).item())
215
 
216
- raw_label = labels_by_idx_raw[idx]
217
- norm_label = labels_by_idx_norm[idx]
218
 
 
219
  prob_map: Dict[str, float] = {}
220
- for j, lbl_norm in enumerate(labels_by_idx_norm):
221
- key = lbl_norm if lbl_norm in ("PHISH", "LEGIT") else f"CLASS_{j}"
222
- prob_map[key] = float(p[j].item())
223
 
224
  output = {
 
225
  "label": norm_label,
226
  "raw_label": raw_label,
227
- "is_phish": True if norm_label == "PHISH" else False,
228
  "score": round(float(p[idx].item()), 4),
229
  "confidence": round(float(p[idx].item()), 4),
230
- "probs": {k: round(v, 4) for k, v in prob_map.items()},
231
  "predicted_index": idx,
 
 
232
  }
233
 
234
  if include_preprocessing and preprocessing_info:
@@ -251,26 +305,33 @@ def root():
251
  "status": "ok",
252
  "model": MODEL_ID,
253
  "device": _device,
254
- "note": "Model uses ORIGINAL text for predictions. Preprocessing is for analysis only.",
255
  }
256
 
257
 
258
  @app.get("/debug/labels")
259
  def debug_labels():
260
- """View model configuration"""
261
  _load_model()
 
 
 
 
 
262
  return {
263
- "id2label": getattr(_model.config, "id2label", {}),
264
- "label2id": getattr(_model.config, "label2id", {}),
265
- "num_labels": int(getattr(_model.config, "num_labels", 0)),
 
 
266
  "device": _device,
267
- "norm_labels_by_idx": _NORM_LABELS_BY_IDX,
268
  }
269
 
270
 
271
  @app.post("/debug/preprocessing")
272
  def debug_preprocessing(payload: PredictPayload):
273
- """Debug preprocessing output"""
274
  try:
275
  _load_model()
276
  preprocessing = _preprocessor.preprocess(payload.inputs)
@@ -279,7 +340,7 @@ def debug_preprocessing(payload: PredictPayload):
279
  "preprocessing": preprocessing
280
  }
281
  except Exception as e:
282
- raise HTTPException(status_code=500, detail=f"Preprocessing error: {e}")
283
 
284
 
285
  @app.post("/predict")
@@ -289,7 +350,7 @@ def predict(payload: PredictPayload):
289
  res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
290
  return res[0]
291
  except Exception as e:
292
- raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
293
 
294
 
295
  @app.post("/predict-batch")
@@ -298,7 +359,7 @@ def predict_batch(payload: BatchPredictPayload):
298
  try:
299
  return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
300
  except Exception as e:
301
- raise HTTPException(status_code=500, detail=f"Batch prediction error: {e}")
302
 
303
 
304
  @app.post("/evaluate")
@@ -306,7 +367,7 @@ def evaluate(payload: EvalPayload):
306
  """Evaluate on labeled samples"""
307
  try:
308
  texts = [s.text for s in payload.samples]
309
- gts = [(_normalize_label_text_only(s.label) if s.label is not None else None) for s in payload.samples]
310
  preds = _predict_texts(texts, include_preprocessing=False)
311
 
312
  total = len(preds)
@@ -333,7 +394,7 @@ def evaluate(payload: EvalPayload):
333
  "per_class": per_class,
334
  }
335
  except Exception as e:
336
- raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")
337
 
338
 
339
  if __name__ == "__main__":
 
1
  import os
2
  from typing import List, Optional, Dict
3
  import re
4
+ import json
5
 
6
  import torch
7
  import nltk
 
31
 
32
 
33
  # ============================================================================
34
+ # TEXT PREPROCESSING CLASS
35
  # ============================================================================
36
  class TextPreprocessor:
37
  """NLP preprocessing for analysis and feature extraction"""
 
79
  }
80
 
81
  def preprocess(self, text: str) -> Dict:
82
+ """Preprocessing for analysis"""
83
  tokens = self.tokenize(text)
84
  tokens_no_stop = self.remove_stopwords(tokens)
85
  stemmed = self.stem(tokens_no_stop)
 
126
  _model = None
127
  _device = "cpu"
128
  _preprocessor = None
129
+ _LABEL_MAPPING = None
130
 
131
 
132
  # ============================================================================
133
  # HELPER FUNCTIONS
134
  # ============================================================================
135
def _get_label_mapping():
    """
    Return the model's id->label mapping with *string* keys, e.g. {"0": "LEGIT"}.

    transformers exposes ``config.id2label`` with int keys, while the consumer
    in ``_predict_texts`` looks labels up via ``str(i)`` — so keys are
    normalized to strings here (BUG FIX: previously a complete config mapping
    still produced ``LABEL_i`` fallbacks because of the int/str key mismatch).

    If the config mapping is incomplete, a bundled ``labels.json`` is tried
    (best-effort), and finally a hard-coded binary mapping is used.

    Returns:
        dict[str, str] mapping, or None when the model is not loaded yet.
    """
    if _model is None:
        return None

    id2label = getattr(_model.config, "id2label", {}) or {}
    num_labels = int(getattr(_model.config, "num_labels", 0) or 0)

    print(f"DEBUG: num_labels = {num_labels}")
    print(f"DEBUG: id2label from config = {id2label}")

    # If the config declares fewer labels than expected, repair the mapping.
    if len(id2label) < num_labels:
        print(f"WARNING: Incomplete label mapping detected!")
        print(f"Expected {num_labels} labels, got {len(id2label)}")

        # Try to load from labels.json if available (best-effort).
        # NOTE(review): pkg_resources is deprecated; importlib.resources is
        # the modern替代 — TODO migrate when convenient.
        try:
            import pkg_resources
            model_path = pkg_resources.resource_filename(__name__, 'models')
            labels_path = os.path.join(model_path, 'labels.json')
            if os.path.exists(labels_path):
                with open(labels_path, 'r') as f:
                    labels_data = json.load(f)
                id2label = labels_data.get("id2label", {})
                print(f"Loaded labels from labels.json: {id2label}")
        except Exception:
            # Deliberate best-effort: fall through to the hard-coded mapping.
            pass

        # Final fallback mapping for the binary phishing classifier.
        if len(id2label) < 2:
            print("Using fallback label mapping: 0=LEGIT, 1=PHISH")
            id2label = {
                "0": "LEGIT",
                "1": "PHISH"
            }

    # Normalize keys so both int-keyed (config) and str-keyed (json/fallback)
    # sources satisfy the str(i) lookups performed by callers.
    return {str(k): v for k, v in id2label.items()}
180
+
181
+
182
+ def _normalize_label(txt: str) -> str:
183
+ """Normalize label text"""
184
  t = (str(txt) if txt is not None else "").strip().upper()
185
+ if t in ("PHISHING", "PHISH", "SPAM", "1"):
186
  return "PHISH"
187
+ if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM", "0"):
188
  return "LEGIT"
189
  return t
190
 
191
 
192
  def _load_model():
193
  """Load model, tokenizer, and preprocessor"""
194
+ global _tokenizer, _model, _device, _preprocessor, _LABEL_MAPPING
195
 
196
  if _tokenizer is None or _model is None:
197
  _device = "cuda" if torch.cuda.is_available() else "cpu"
198
+ print(f"\n{'='*60}")
199
  print(f"Loading model on device: {_device}")
200
+ print(f"Model ID: {MODEL_ID}")
201
+ print(f"{'='*60}\n")
202
 
203
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
204
  _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
 
206
  _model.eval()
207
  _preprocessor = TextPreprocessor()
208
 
209
+ # Get label mapping
210
+ _LABEL_MAPPING = _get_label_mapping()
211
+
212
  # Warm-up
213
  with torch.no_grad():
214
  _ = _model(
 
216
  .to(_device)
217
  ).logits
218
 
 
 
219
  num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
220
+ print(f"Number of labels: {num_labels}")
221
+ print(f"Label mapping: {_LABEL_MAPPING}")
222
+ print(f"{'='*60}\n")
 
 
223
 
224
 
225
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
226
  """
227
+ Predict with corrected label mapping
 
228
  """
229
  _load_model()
230
  if not texts:
231
  return []
232
 
233
+ # Get preprocessing info
 
 
 
234
  preprocessing_info = None
235
  if include_preprocessing:
236
  preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
237
 
238
+ # Tokenize
239
  enc = _tokenizer(
240
+ texts,
241
  return_tensors="pt",
242
  padding=True,
243
  truncation=True,
 
245
  )
246
  enc = {k: v.to(_device) for k, v in enc.items()}
247
 
248
+ # Predict
249
  with torch.no_grad():
250
  logits = _model(**enc).logits
251
  probs = torch.softmax(logits, dim=-1)
252
 
253
+ # Build label list from mapping
254
+ num_labels = probs.shape[-1]
255
+ labels_by_idx = []
256
+ for i in range(num_labels):
257
+ label = _LABEL_MAPPING.get(str(i), f"LABEL_{i}")
258
+ labels_by_idx.append(label)
259
+
260
+ print(f"DEBUG: Using labels: {labels_by_idx}")
261
 
262
  outputs: List[Dict] = []
263
  for i in range(probs.shape[0]):
264
  p = probs[i]
265
  idx = int(torch.argmax(p).item())
266
 
267
+ raw_label = labels_by_idx[idx]
268
+ norm_label = _normalize_label(raw_label)
269
 
270
+ # Build probability map
271
  prob_map: Dict[str, float] = {}
272
+ for j in range(len(labels_by_idx)):
273
+ label_norm = _normalize_label(labels_by_idx[j])
274
+ prob_map[label_norm] = float(p[j].item())
275
 
276
  output = {
277
+ "text": texts[i][:100] + "..." if len(texts[i]) > 100 else texts[i],
278
  "label": norm_label,
279
  "raw_label": raw_label,
280
+ "is_phish": norm_label == "PHISH",
281
  "score": round(float(p[idx].item()), 4),
282
  "confidence": round(float(p[idx].item()), 4),
 
283
  "predicted_index": idx,
284
+ "probs": {k: round(v, 4) for k, v in prob_map.items()},
285
+ "all_probs_raw": [round(float(p[j].item()), 4) for j in range(len(labels_by_idx))],
286
  }
287
 
288
  if include_preprocessing and preprocessing_info:
 
305
  "status": "ok",
306
  "model": MODEL_ID,
307
  "device": _device,
308
+ "label_mapping": _LABEL_MAPPING,
309
  }
310
 
311
 
312
@app.get("/debug/labels")
def debug_labels():
    """Expose the raw label config from the model alongside the applied mapping."""
    _load_model()

    cfg = _model.config
    raw_id2label = getattr(cfg, "id2label", {}) or {}
    raw_label2id = getattr(cfg, "label2id", {}) or {}
    n_labels = int(getattr(cfg, "num_labels", 0) or 0)

    # Report both what the config says and what the service actually uses,
    # so mapping repairs are visible to the caller.
    payload = {"status": "ok"}
    payload["config_id2label"] = raw_id2label
    payload["config_label2id"] = raw_label2id
    payload["config_num_labels"] = n_labels
    payload["applied_label_mapping"] = _LABEL_MAPPING
    payload["device"] = _device
    payload["note"] = "If config_id2label is incomplete, applied_label_mapping is used"
    return payload
330
 
331
 
332
  @app.post("/debug/preprocessing")
333
  def debug_preprocessing(payload: PredictPayload):
334
+ """Debug preprocessing"""
335
  try:
336
  _load_model()
337
  preprocessing = _preprocessor.preprocess(payload.inputs)
 
340
  "preprocessing": preprocessing
341
  }
342
  except Exception as e:
343
+ raise HTTPException(status_code=500, detail=f"Error: {e}")
344
 
345
 
346
  @app.post("/predict")
 
350
  res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
351
  return res[0]
352
  except Exception as e:
353
+ raise HTTPException(status_code=500, detail=f"Error: {e}")
354
 
355
 
356
  @app.post("/predict-batch")
 
359
  try:
360
  return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
361
  except Exception as e:
362
+ raise HTTPException(status_code=500, detail=f"Error: {e}")
363
 
364
 
365
  @app.post("/evaluate")
 
367
  """Evaluate on labeled samples"""
368
  try:
369
  texts = [s.text for s in payload.samples]
370
+ gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
371
  preds = _predict_texts(texts, include_preprocessing=False)
372
 
373
  total = len(preds)
 
394
  "per_class": per_class,
395
  }
396
  except Exception as e:
397
+ raise HTTPException(status_code=500, detail=f"Error: {e}")
398
 
399
 
400
  if __name__ == "__main__":