Perth0603 committed on
Commit
ad3f1d2
·
verified ·
1 Parent(s): b418015

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -194
app.py CHANGED
@@ -1,184 +1,98 @@
1
  import os
2
  from typing import List, Optional, Dict
3
- import re
4
 
5
  import torch
6
- import nltk
7
  from fastapi import FastAPI, HTTPException
8
  from pydantic import BaseModel
9
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
- from nltk.corpus import stopwords
11
- from nltk.stem import PorterStemmer, WordNetLemmatizer
12
- from nltk.tokenize import word_tokenize
13
- from textblob import TextBlob
14
-
15
- # Download NLTK data
16
- try:
17
- nltk.data.find('tokenizers/punkt')
18
- except LookupError:
19
- nltk.download('punkt')
20
- nltk.download('stopwords')
21
- nltk.download('wordnet')
22
-
23
- # ✅ CHANGE THIS TO POINT TO YOUR MODEL REPOSITORY
24
- MODEL_ID = "Perth0603/phishing-email-mobilebert" # ← Your model storage repo
25
-
26
- app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
27
-
28
-
29
- # ============================================================================
30
- # TEXT PREPROCESSING CLASS
31
- # ============================================================================
32
- class TextPreprocessor:
33
- """NLP preprocessing for analysis and feature extraction"""
34
-
35
- def __init__(self):
36
- self.stemmer = PorterStemmer()
37
- self.lemmatizer = WordNetLemmatizer()
38
- self.stop_words = set(stopwords.words('english'))
39
-
40
- def tokenize(self, text: str) -> List[str]:
41
- """Break text into tokens"""
42
- return word_tokenize(text.lower())
43
-
44
- def remove_stopwords(self, tokens: List[str]) -> List[str]:
45
- """Remove common stop words"""
46
- return [token for token in tokens if token.isalnum() and token not in self.stop_words]
47
-
48
- def stem(self, tokens: List[str]) -> List[str]:
49
- """Reduce tokens to stems"""
50
- return [self.stemmer.stem(token) for token in tokens]
51
-
52
- def lemmatize(self, tokens: List[str]) -> List[str]:
53
- """Reduce tokens to lemmas"""
54
- return [self.lemmatizer.lemmatize(token) for token in tokens]
55
-
56
- def sentiment_analysis(self, text: str) -> Dict:
57
- """Analyze sentiment and phishing indicators"""
58
- blob = TextBlob(text)
59
- polarity = blob.sentiment.polarity
60
- subjectivity = blob.sentiment.subjectivity
61
-
62
- phishing_indicators = {
63
- "urgent_words": bool(re.search(r'\b(urgent|immediate|act now|verify|confirm|update|click|verify account)\b', text, re.IGNORECASE)),
64
- "threat_words": bool(re.search(r'\b(suspend|limited|expire|locked|disabled|restricted)\b', text, re.IGNORECASE)),
65
- "suspicious_urls": bool(re.search(r'http\S+|www\S+', text)),
66
- "urgency_level": "HIGH" if re.search(r'\b(urgent|immediate|act now)\b', text, re.IGNORECASE) else "LOW"
67
- }
68
-
69
- return {
70
- "polarity": round(polarity, 4),
71
- "subjectivity": round(subjectivity, 4),
72
- "sentiment": "positive" if polarity > 0.1 else "negative" if polarity < -0.1 else "neutral",
73
- "is_persuasive": subjectivity > 0.5,
74
- "phishing_indicators": phishing_indicators
75
- }
76
-
77
- def preprocess(self, text: str) -> Dict:
78
- """Preprocessing for analysis"""
79
- tokens = self.tokenize(text)
80
- tokens_no_stop = self.remove_stopwords(tokens)
81
- stemmed = self.stem(tokens_no_stop)
82
- lemmatized = self.lemmatize(tokens_no_stop)
83
- sentiment = self.sentiment_analysis(text)
84
-
85
- return {
86
- "original_text": text,
87
- "tokens": tokens,
88
- "tokens_without_stopwords": tokens_no_stop,
89
- "stemmed_tokens": stemmed,
90
- "lemmatized_tokens": lemmatized,
91
- "sentiment": sentiment,
92
- "token_count": len(tokens_no_stop)
93
- }
94
 
95
 
96
- # ============================================================================
97
- # PYDANTIC MODELS
98
- # ============================================================================
99
  class PredictPayload(BaseModel):
100
  inputs: str
101
- include_preprocessing: bool = True
102
 
103
 
104
  class BatchPredictPayload(BaseModel):
105
  inputs: List[str]
106
- include_preprocessing: bool = True
107
 
108
 
109
  class LabeledText(BaseModel):
110
  text: str
111
- label: Optional[str] = None
112
 
113
 
114
  class EvalPayload(BaseModel):
115
  samples: List[LabeledText]
116
 
117
 
118
- # ============================================================================
119
- # GLOBAL VARIABLES
120
- # ============================================================================
121
  _tokenizer = None
122
  _model = None
123
  _device = "cpu"
124
- _preprocessor = None
 
 
125
 
126
 
127
- # ============================================================================
128
- # HELPER FUNCTIONS
129
- # ============================================================================
130
- def _normalize_label(txt: str) -> str:
131
- """Normalize label text"""
132
  t = (str(txt) if txt is not None else "").strip().upper()
133
- if t in ("PHISHING", "PHISH", "SPAM", "1"):
134
  return "PHISH"
135
- if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM", "0"):
136
  return "LEGIT"
 
137
  return t
138
 
139
 
140
  def _load_model():
141
- """Load model, tokenizer, and preprocessor"""
142
- global _tokenizer, _model, _device, _preprocessor
143
 
144
  if _tokenizer is None or _model is None:
145
  _device = "cuda" if torch.cuda.is_available() else "cpu"
146
- print(f"\n{'='*60}")
147
- print(f"Loading model: {MODEL_ID}")
148
- print(f"Device: {_device}")
149
- print(f"{'='*60}\n")
150
-
151
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
152
  _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
153
  _model.to(_device)
154
- _model.eval()
155
- _preprocessor = TextPreprocessor()
156
 
157
- # Warm-up
158
  with torch.no_grad():
159
  _ = _model(
160
  **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
161
  .to(_device)
162
  ).logits
163
 
164
- # Check label mapping
165
- id2label = getattr(_model.config, "id2label", {})
166
- print(f"Model labels: {id2label}")
167
- print(f"{'='*60}\n")
168
-
169
-
170
- def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
171
- """Predict with correct label mapping"""
 
 
 
 
 
 
 
 
172
  _load_model()
173
  if not texts:
174
  return []
175
 
176
- # Get preprocessing info
177
- preprocessing_info = None
178
- if include_preprocessing:
179
- preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
180
-
181
- # Tokenize
182
  enc = _tokenizer(
183
  texts,
184
  return_tensors="pt",
@@ -188,115 +102,95 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
188
  )
189
  enc = {k: v.to(_device) for k, v in enc.items()}
190
 
191
- # Predict
192
  with torch.no_grad():
193
  logits = _model(**enc).logits
194
- probs = torch.softmax(logits, dim=-1)
195
 
196
- # Get labels from model config
197
- id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
 
 
 
198
 
199
  outputs: List[Dict] = []
200
- for text_idx in range(probs.shape[0]):
201
- p = probs[text_idx]
202
-
203
- # Get prediction
204
- predicted_idx = int(torch.argmax(p).item())
205
- predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
206
- predicted_label_norm = _normalize_label(predicted_label_raw)
207
- predicted_prob = float(p[predicted_idx].item())
208
-
209
- # Build probability breakdown
210
- prob_breakdown = {}
211
- for i in range(len(p)):
212
- label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
213
- prob_breakdown[label] = round(float(p[i].item()), 4)
214
-
215
- output = {
216
- "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
217
- "label": predicted_label_norm,
218
- "raw_label": predicted_label_raw,
219
- "is_phish": predicted_label_norm == "PHISH",
220
- "confidence": round(predicted_prob * 100, 2),
221
- "score": round(predicted_prob, 4),
222
- "probs": prob_breakdown,
223
- }
224
-
225
- if include_preprocessing and preprocessing_info:
226
- output["preprocessing"] = preprocessing_info[text_idx]
227
-
228
- outputs.append(output)
229
 
230
  return outputs
231
 
232
 
233
- # ============================================================================
234
- # API ENDPOINTS
235
- # ============================================================================
236
-
237
  @app.get("/")
238
  def root():
239
- """Root endpoint"""
240
  _load_model()
241
  return {
242
  "status": "ok",
243
  "model": MODEL_ID,
244
- "device": _device,
245
  }
246
 
247
 
248
  @app.get("/debug/labels")
249
  def debug_labels():
250
- """View model configuration"""
251
  _load_model()
252
-
253
  return {
254
- "status": "ok",
255
- "model_id": MODEL_ID,
256
  "id2label": getattr(_model.config, "id2label", {}),
257
  "label2id": getattr(_model.config, "label2id", {}),
258
  "num_labels": int(getattr(_model.config, "num_labels", 0)),
259
  "device": _device,
 
260
  }
261
 
262
 
263
- @app.post("/debug/preprocessing")
264
- def debug_preprocessing(payload: PredictPayload):
265
- """Debug preprocessing"""
266
- try:
267
- _load_model()
268
- preprocessing = _preprocessor.preprocess(payload.inputs)
269
- return preprocessing
270
- except Exception as e:
271
- raise HTTPException(status_code=500, detail=str(e))
272
-
273
-
274
  @app.post("/predict")
275
  def predict(payload: PredictPayload):
276
- """Single prediction"""
277
  try:
278
- res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
279
  return res[0]
280
  except Exception as e:
281
- raise HTTPException(status_code=500, detail=str(e))
282
 
283
 
284
  @app.post("/predict-batch")
285
  def predict_batch(payload: BatchPredictPayload):
286
- """Batch predictions"""
287
  try:
288
- return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
289
  except Exception as e:
290
- raise HTTPException(status_code=500, detail=str(e))
291
 
292
 
293
  @app.post("/evaluate")
294
  def evaluate(payload: EvalPayload):
295
- """Evaluate on labeled samples"""
 
 
 
 
296
  try:
297
  texts = [s.text for s in payload.samples]
298
- gts = [(_normalize_label(s.label) if s.label is not None else None) for s in payload.samples]
299
- preds = _predict_texts(texts, include_preprocessing=False)
300
 
301
  total = len(preds)
302
  correct = 0
@@ -315,16 +209,16 @@ def evaluate(payload: EvalPayload):
315
  acc = (correct / sum(1 for gt in gts if gt is not None)) if has_gts else None
316
 
317
  return {
318
- "accuracy": round(acc, 4) if acc else None,
319
  "total": total,
320
- "correct": correct,
321
  "predictions": preds,
322
  "per_class": per_class,
323
  }
324
  except Exception as e:
325
- raise HTTPException(status_code=500, detail=str(e))
326
 
327
 
328
  if __name__ == "__main__":
 
329
  import uvicorn
330
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
import os
from typing import List, Optional, Dict

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Model repository resolution: the MODEL_ID env var wins, then HF_MODEL_ID,
# then the hard-coded default repo.
_DEFAULT_MODEL_ID = "Perth0603/phishing-email-mobilebert"
MODEL_ID = os.environ.get("MODEL_ID") or os.environ.get("HF_MODEL_ID") or _DEFAULT_MODEL_ID

app = FastAPI(title="Phishing Text Classifier (Model-Authoritative)", version="1.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
 
 
 
19
# ---- Request schemas -----------------------------------------------------

class PredictPayload(BaseModel):
    """Body for /predict: a single text to classify."""
    inputs: str


class BatchPredictPayload(BaseModel):
    """Body for /predict-batch: several texts classified in one batch."""
    inputs: List[str]


class LabeledText(BaseModel):
    """One evaluation sample: free text plus an optional ground-truth label."""
    text: str
    label: Optional[str] = None  # optional ground truth for quick eval (accepts text)


class EvalPayload(BaseModel):
    """Body for /evaluate: a list of (text, optional label) samples."""
    samples: List[LabeledText]
34
 
35
 
 
 
 
36
# ---- Lazily-initialized global state (populated by _load_model) ----------
_tokenizer = None  # tokenizer loaded from MODEL_ID
_model = None      # sequence-classification model loaded from MODEL_ID
_device = "cpu"    # resolved at load time ("cuda" when available)

# Cached normalized mapping/meta
_NORM_LABELS_BY_IDX = None  # normalized labels ordered by model indices
42
 
43
 
44
+ def _normalize_label_text_only(txt: str) -> str:
45
+ """
46
+ Normalize model label text to PHISH/LEGIT when possible.
47
+ If unfamiliar, return the uppercased original token.
48
+ """
49
  t = (str(txt) if txt is not None else "").strip().upper()
50
+ if t in ("PHISHING", "PHISH", "SPAM"):
51
  return "PHISH"
52
+ if t in ("LEGIT", "LEGITIMATE", "SAFE", "HAM"):
53
  return "LEGIT"
54
+ # keep other label names as-is (uppercased) so we don't force an incorrect mapping
55
  return t
56
 
57
 
58
def _load_model():
    """
    Lazily load tokenizer + model once per process.

    Moves the model to the best available device, switches it to eval mode,
    runs a silent warm-up forward pass, and caches the normalized per-index
    label names in _NORM_LABELS_BY_IDX for later use.
    """
    global _tokenizer, _model, _device, _NORM_LABELS_BY_IDX

    if _tokenizer is None or _model is None:
        _device = "cuda" if torch.cuda.is_available() else "cpu"
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
        _model.to(_device)
        _model.eval()  # important: disable dropout etc.

        # Warm-up (silent) so the first real request doesn't pay first-call cost
        with torch.no_grad():
            _ = _model(
                **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
                .to(_device)
            ).logits

        # Read and normalize model labels (by index).
        id2label = getattr(_model.config, "id2label", {}) or {}
        # Fall back to len(id2label) when num_labels is missing/0 in the config,
        # so the cached list is never silently empty while labels exist.
        num_labels = int(getattr(_model.config, "num_labels", 0) or 0) or len(id2label)
        _NORM_LABELS_BY_IDX = [
            _normalize_label_text_only(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)
        ]
79
+
80
+
81
+ def _predict_texts(texts: List[str]) -> List[Dict]:
82
+ """
83
+ Predict and return strictly model-authoritative outputs:
84
+ - label: normalized model label (PHISH/LEGIT or other model label uppercased)
85
+ - raw_label: original id2label string from model.config
86
+ - is_phish: boolean derived from normalized label (True if normalized == "PHISH")
87
+ - score: probability of predicted class
88
+ - probs: dict of normalized label -> probability (or CLASS_i keys if unknown)
89
+ - predicted_index: argmax index
90
+ """
91
  _load_model()
92
  if not texts:
93
  return []
94
 
95
+ # Tokenize batch
 
 
 
 
 
96
  enc = _tokenizer(
97
  texts,
98
  return_tensors="pt",
 
102
  )
103
  enc = {k: v.to(_device) for k, v in enc.items()}
104
 
 
105
  with torch.no_grad():
106
  logits = _model(**enc).logits
107
+ probs = torch.softmax(logits, dim=-1) # [batch, num_labels]
108
 
109
+ # Use the model’s own mapping
110
+ id2label = getattr(_model.config, "id2label", None) or {}
111
+ labels_by_idx_raw = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
112
+ # normalized labels where possible
113
+ labels_by_idx_norm = [_normalize_label_text_only(lbl) for lbl in labels_by_idx_raw]
114
 
115
  outputs: List[Dict] = []
116
+ for i in range(probs.shape[0]):
117
+ p = probs[i]
118
+ idx = int(torch.argmax(p).item())
119
+
120
+ raw_label = labels_by_idx_raw[idx]
121
+ norm_label = labels_by_idx_norm[idx] # normalized where possible
122
+
123
+ # Build probability map keyed by normalized labels when available,
124
+ # otherwise fallback to CLASS_i keys to avoid collision
125
+ prob_map: Dict[str, float] = {}
126
+ for j, lbl_norm in enumerate(labels_by_idx_norm):
127
+ key = lbl_norm if lbl_norm in ("PHISH", "LEGIT") else f"CLASS_{j}"
128
+ prob_map[key] = float(p[j].item())
129
+
130
+ outputs.append(
131
+ {
132
+ "label": norm_label, # authoritative label (model-driven, normalized)
133
+ "raw_label": raw_label, # original model id2label value
134
+ "is_phish": True if norm_label == "PHISH" else False,
135
+ "score": float(p[idx].item()), # probability of predicted class
136
+ "probs": prob_map, # per-class probabilities (keys normalized or CLASS_i)
137
+ "predicted_index": idx,
138
+ }
139
+ )
 
 
 
 
 
140
 
141
  return outputs
142
 
143
 
 
 
 
 
144
@app.get("/")
def root():
    """Health/info endpoint; also triggers lazy model loading."""
    _load_model()
    info = {
        "status": "ok",
        "model": MODEL_ID,
        "note": "This service returns predictions exactly as the model decides (label derived from model.config.id2label). Frontend should use `label` or `is_phish` as authority."
    }
    return info
152
 
153
 
154
@app.get("/debug/labels")
def debug_labels():
    """Expose the model's label configuration so mapping issues can be diagnosed."""
    _load_model()
    cfg = _model.config
    return {
        "id2label": getattr(cfg, "id2label", {}),
        "label2id": getattr(cfg, "label2id", {}),
        "num_labels": int(getattr(cfg, "num_labels", 0)),
        "device": _device,
        "norm_labels_by_idx": _NORM_LABELS_BY_IDX,
    }
164
 
165
 
 
 
 
 
 
 
 
 
 
 
 
166
@app.post("/predict")
def predict(payload: PredictPayload):
    """
    Classify a single text.

    Returns the single prediction dict for `payload.inputs`. Failures are
    surfaced as HTTP 500 with the original error message; the exception is
    chained via `from e` so the root cause is preserved in tracebacks/logs.
    """
    try:
        res = _predict_texts([payload.inputs])
        return res[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}") from e
173
 
174
 
175
@app.post("/predict-batch")
def predict_batch(payload: BatchPredictPayload):
    """
    Classify a list of texts in one batched forward pass.

    An empty `inputs` list yields an empty list. Failures are surfaced as
    HTTP 500 with the original error message; the exception is chained via
    `from e` so the root cause is preserved in tracebacks/logs.
    """
    try:
        return _predict_texts(payload.inputs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction error: {e}") from e
181
 
182
 
183
  @app.post("/evaluate")
184
  def evaluate(payload: EvalPayload):
185
+ """
186
+ Quick on-the-spot test with provided labeled samples.
187
+ The provided labels are interpreted as text labels (PHISH/LEGIT/etc.) — evaluation is done
188
+ by comparing normalized GT text to model's normalized prediction (no 0/1 dataset mapping applied).
189
+ """
190
  try:
191
  texts = [s.text for s in payload.samples]
192
+ gts = [(_normalize_label_text_only(s.label) if s.label is not None else None) for s in payload.samples]
193
+ preds = _predict_texts(texts)
194
 
195
  total = len(preds)
196
  correct = 0
 
209
  acc = (correct / sum(1 for gt in gts if gt is not None)) if has_gts else None
210
 
211
  return {
212
+ "accuracy": acc,
213
  "total": total,
 
214
  "predictions": preds,
215
  "per_class": per_class,
216
  }
217
  except Exception as e:
218
+ raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")
219
 
220
 
221
if __name__ == "__main__":
    # Run: uvicorn app:app --host 0.0.0.0 --port 8000 --reload
    # Dev convenience entry point: binds all interfaces on port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)