Perth0603 committed on
Commit
3a83600
·
verified ·
1 Parent(s): 60cd459

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -82
app.py CHANGED
@@ -12,16 +12,14 @@ from nltk.stem import PorterStemmer, WordNetLemmatizer
12
  from nltk.tokenize import word_tokenize
13
  from textblob import TextBlob
14
 
15
- # Download NLTK data (runs once)
16
  try:
17
  nltk.data.find('tokenizers/punkt')
18
  except LookupError:
19
  nltk.download('punkt')
20
  nltk.download('stopwords')
21
  nltk.download('wordnet')
22
- nltk.download('averaged_perceptron_tagger')
23
 
24
- # Prefer MODEL_ID, fall back to HF_MODEL_ID, then default
25
  MODEL_ID = (
26
  os.environ.get("MODEL_ID")
27
  or os.environ.get("HF_MODEL_ID")
@@ -32,10 +30,10 @@ app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.
32
 
33
 
34
  # ============================================================================
35
- # TEXT PREPROCESSING CLASS
36
  # ============================================================================
37
  class TextPreprocessor:
38
- """Complete NLP preprocessing pipeline"""
39
 
40
  def __init__(self):
41
  self.stemmer = PorterStemmer()
@@ -44,15 +42,14 @@ class TextPreprocessor:
44
 
45
  def tokenize(self, text: str) -> List[str]:
46
  """Break text into tokens"""
47
- text_lower = text.lower()
48
- return word_tokenize(text_lower)
49
 
50
  def remove_stopwords(self, tokens: List[str]) -> List[str]:
51
  """Remove common stop words"""
52
  return [token for token in tokens if token.isalnum() and token not in self.stop_words]
53
 
54
  def stem(self, tokens: List[str]) -> List[str]:
55
- """Reduce tokens to stems (Porter Stemmer)"""
56
  return [self.stemmer.stem(token) for token in tokens]
57
 
58
  def lemmatize(self, tokens: List[str]) -> List[str]:
@@ -60,15 +57,15 @@ class TextPreprocessor:
60
  return [self.lemmatizer.lemmatize(token) for token in tokens]
61
 
62
  def sentiment_analysis(self, text: str) -> Dict:
63
- """Analyze sentiment polarity, subjectivity, and detect phishing indicators"""
64
  blob = TextBlob(text)
65
- polarity = blob.sentiment.polarity # -1 (negative) to 1 (positive)
66
- subjectivity = blob.sentiment.subjectivity # 0 (objective) to 1 (subjective)
67
 
68
- # Detect persuasive/emotional language (common in phishing)
69
  phishing_indicators = {
70
  "urgent_words": bool(re.search(r'\b(urgent|immediate|act now|verify|confirm|update|click|verify account)\b', text, re.IGNORECASE)),
71
  "threat_words": bool(re.search(r'\b(suspend|limited|expire|locked|disabled|restricted)\b', text, re.IGNORECASE)),
 
72
  "urgency_level": "HIGH" if re.search(r'\b(urgent|immediate|act now)\b', text, re.IGNORECASE) else "LOW"
73
  }
74
 
@@ -80,46 +77,20 @@ class TextPreprocessor:
80
  "phishing_indicators": phishing_indicators
81
  }
82
 
83
- def clean_text(self, text: str) -> str:
84
- """Clean URLs, special characters, extra spaces"""
85
- # Remove URLs
86
- text = re.sub(r'http\S+|www\S+', '', text)
87
- # Remove email addresses
88
- text = re.sub(r'\S+@\S+', '', text)
89
- # Remove special characters but keep spaces
90
- text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
91
- # Remove extra whitespace
92
- text = re.sub(r'\s+', ' ', text).strip()
93
- return text
94
-
95
  def preprocess(self, text: str) -> Dict:
96
- """Complete preprocessing pipeline"""
97
- # Step 1: Clean
98
- cleaned_text = self.clean_text(text)
99
-
100
- # Step 2: Tokenize
101
- tokens = self.tokenize(cleaned_text)
102
-
103
- # Step 3: Remove stopwords
104
  tokens_no_stop = self.remove_stopwords(tokens)
105
-
106
- # Step 4: Stem
107
  stemmed = self.stem(tokens_no_stop)
108
-
109
- # Step 5: Lemmatize
110
  lemmatized = self.lemmatize(tokens_no_stop)
111
-
112
- # Step 6: Sentiment analysis
113
  sentiment = self.sentiment_analysis(text)
114
 
115
  return {
116
  "original_text": text,
117
- "cleaned_text": cleaned_text,
118
  "tokens": tokens,
119
  "tokens_without_stopwords": tokens_no_stop,
120
  "stemmed_tokens": stemmed,
121
  "lemmatized_tokens": lemmatized,
122
- "processed_text": " ".join(lemmatized), # Use lemmatized for model input
123
  "sentiment": sentiment,
124
  "token_count": len(tokens_no_stop)
125
  }
@@ -161,10 +132,7 @@ _NORM_LABELS_BY_IDX = None
161
  # HELPER FUNCTIONS
162
  # ============================================================================
163
  def _normalize_label_text_only(txt: str) -> str:
164
- """
165
- Normalize model label text to PHISH/LEGIT when possible.
166
- If unfamiliar, return the uppercased original token.
167
- """
168
  t = (str(txt) if txt is not None else "").strip().upper()
169
  if t in ("PHISHING", "PHISH", "SPAM"):
170
  return "PHISH"
@@ -174,7 +142,7 @@ def _normalize_label_text_only(txt: str) -> str:
174
 
175
 
176
  def _load_model():
177
- """Load model, tokenizer, and preprocessor on first use"""
178
  global _tokenizer, _model, _device, _NORM_LABELS_BY_IDX, _preprocessor
179
 
180
  if _tokenizer is None or _model is None:
@@ -184,44 +152,44 @@ def _load_model():
184
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
185
  _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
186
  _model.to(_device)
187
- _model.eval() # important: disable dropout etc.
188
  _preprocessor = TextPreprocessor()
189
 
190
- # Warm-up (silent)
191
  with torch.no_grad():
192
  _ = _model(
193
  **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
194
  .to(_device)
195
  ).logits
196
 
197
- # Read and normalize model labels (by index)
198
  id2label = getattr(_model.config, "id2label", {}) or {}
199
  num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
200
  _NORM_LABELS_BY_IDX = [_normalize_label_text_only(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
201
 
202
- print(f"Model loaded successfully. Number of labels: {num_labels}")
203
- print(f"Label mapping: {id2label}")
204
  print(f"Normalized labels: {_NORM_LABELS_BY_IDX}")
205
 
206
 
207
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
208
  """
209
- Predict and return strictly model-authoritative outputs with enhanced debugging.
 
210
  """
211
  _load_model()
212
  if not texts:
213
  return []
214
 
215
- # Step 1: Preprocess texts if requested
 
 
 
216
  preprocessing_info = None
217
  if include_preprocessing:
218
  preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
219
- # Use lemmatized text for model input
220
- model_inputs = [prep["processed_text"] for prep in preprocessing_info]
221
- else:
222
- model_inputs = texts
223
 
224
- # Step 2: Tokenize batch for model
225
  enc = _tokenizer(
226
  model_inputs,
227
  return_tensors="pt",
@@ -231,12 +199,11 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
231
  )
232
  enc = {k: v.to(_device) for k, v in enc.items()}
233
 
234
- # Step 3: Get predictions
235
  with torch.no_grad():
236
  logits = _model(**enc).logits
237
- probs = torch.softmax(logits, dim=-1) # [batch, num_labels]
238
 
239
- # Step 4: Build probability maps
240
  id2label = getattr(_model.config, "id2label", None) or {}
241
  labels_by_idx_raw = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
242
  labels_by_idx_norm = [_normalize_label_text_only(lbl) for lbl in labels_by_idx_raw]
@@ -249,7 +216,6 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
249
  raw_label = labels_by_idx_raw[idx]
250
  norm_label = labels_by_idx_norm[idx]
251
 
252
- # Build probability map keyed by normalized labels
253
  prob_map: Dict[str, float] = {}
254
  for j, lbl_norm in enumerate(labels_by_idx_norm):
255
  key = lbl_norm if lbl_norm in ("PHISH", "LEGIT") else f"CLASS_{j}"
@@ -260,13 +226,11 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
260
  "raw_label": raw_label,
261
  "is_phish": True if norm_label == "PHISH" else False,
262
  "score": round(float(p[idx].item()), 4),
 
263
  "probs": {k: round(v, 4) for k, v in prob_map.items()},
264
  "predicted_index": idx,
265
- "all_logits": [round(float(logits[i][j].item()), 4) for j in range(logits.shape[1])], # DEBUG
266
- "raw_probs": [round(float(p[j].item()), 4) for j in range(len(p))], # DEBUG
267
  }
268
 
269
- # Add preprocessing info if requested
270
  if include_preprocessing and preprocessing_info:
271
  output["preprocessing"] = preprocessing_info[i]
272
 
@@ -281,26 +245,19 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
281
 
282
  @app.get("/")
283
  def root():
284
- """Root endpoint - shows API status"""
285
  _load_model()
286
  return {
287
  "status": "ok",
288
  "model": MODEL_ID,
289
  "device": _device,
290
- "note": "Text preprocessing with stemming, lemmatization, stopword removal, and sentiment analysis enabled by default",
291
- "endpoints": {
292
- "/predict": "POST - Single text prediction",
293
- "/predict-batch": "POST - Batch predictions",
294
- "/evaluate": "POST - Evaluate with labeled samples",
295
- "/debug/labels": "GET - View model label configuration",
296
- "/debug/preprocessing": "POST - Debug preprocessing output only"
297
- }
298
  }
299
 
300
 
301
  @app.get("/debug/labels")
302
  def debug_labels():
303
- """Debug endpoint - view model label configuration"""
304
  _load_model()
305
  return {
306
  "id2label": getattr(_model.config, "id2label", {}),
@@ -313,7 +270,7 @@ def debug_labels():
313
 
314
  @app.post("/debug/preprocessing")
315
  def debug_preprocessing(payload: PredictPayload):
316
- """Debug endpoint - view preprocessing output only (no model prediction)"""
317
  try:
318
  _load_model()
319
  preprocessing = _preprocessor.preprocess(payload.inputs)
@@ -327,7 +284,7 @@ def debug_preprocessing(payload: PredictPayload):
327
 
328
  @app.post("/predict")
329
  def predict(payload: PredictPayload):
330
- """Single text prediction with optional preprocessing details"""
331
  try:
332
  res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
333
  return res[0]
@@ -337,7 +294,7 @@ def predict(payload: PredictPayload):
337
 
338
  @app.post("/predict-batch")
339
  def predict_batch(payload: BatchPredictPayload):
340
- """Batch text predictions with optional preprocessing details"""
341
  try:
342
  return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
343
  except Exception as e:
@@ -346,10 +303,7 @@ def predict_batch(payload: BatchPredictPayload):
346
 
347
  @app.post("/evaluate")
348
  def evaluate(payload: EvalPayload):
349
- """
350
- Evaluate model on labeled samples.
351
- Compares model predictions against provided ground truth labels.
352
- """
353
  try:
354
  texts = [s.text for s in payload.samples]
355
  gts = [(_normalize_label_text_only(s.label) if s.label is not None else None) for s in payload.samples]
@@ -383,6 +337,5 @@ def evaluate(payload: EvalPayload):
383
 
384
 
385
  if __name__ == "__main__":
386
- # Run: uvicorn app:app --host 0.0.0.0 --port 8000 --reload
387
  import uvicorn
388
  uvicorn.run(app, host="0.0.0.0", port=8000)
 
12
  from nltk.tokenize import word_tokenize
13
  from textblob import TextBlob
14
 
15
+ # Download NLTK data
16
  try:
17
  nltk.data.find('tokenizers/punkt')
18
  except LookupError:
19
  nltk.download('punkt')
20
  nltk.download('stopwords')
21
  nltk.download('wordnet')
 
22
 
 
23
  MODEL_ID = (
24
  os.environ.get("MODEL_ID")
25
  or os.environ.get("HF_MODEL_ID")
 
30
 
31
 
32
  # ============================================================================
33
+ # TEXT PREPROCESSING CLASS (FOR ANALYSIS ONLY, NOT FOR MODEL INPUT)
34
  # ============================================================================
35
  class TextPreprocessor:
36
+ """NLP preprocessing for analysis and feature extraction"""
37
 
38
  def __init__(self):
39
  self.stemmer = PorterStemmer()
 
42
 
43
  def tokenize(self, text: str) -> List[str]:
44
  """Break text into tokens"""
45
+ return word_tokenize(text.lower())
 
46
 
47
  def remove_stopwords(self, tokens: List[str]) -> List[str]:
48
  """Remove common stop words"""
49
  return [token for token in tokens if token.isalnum() and token not in self.stop_words]
50
 
51
  def stem(self, tokens: List[str]) -> List[str]:
52
+ """Reduce tokens to stems"""
53
  return [self.stemmer.stem(token) for token in tokens]
54
 
55
  def lemmatize(self, tokens: List[str]) -> List[str]:
 
57
  return [self.lemmatizer.lemmatize(token) for token in tokens]
58
 
59
  def sentiment_analysis(self, text: str) -> Dict:
60
+ """Analyze sentiment and phishing indicators"""
61
  blob = TextBlob(text)
62
+ polarity = blob.sentiment.polarity
63
+ subjectivity = blob.sentiment.subjectivity
64
 
 
65
  phishing_indicators = {
66
  "urgent_words": bool(re.search(r'\b(urgent|immediate|act now|verify|confirm|update|click|verify account)\b', text, re.IGNORECASE)),
67
  "threat_words": bool(re.search(r'\b(suspend|limited|expire|locked|disabled|restricted)\b', text, re.IGNORECASE)),
68
+ "suspicious_urls": bool(re.search(r'http\S+|www\S+', text)),
69
  "urgency_level": "HIGH" if re.search(r'\b(urgent|immediate|act now)\b', text, re.IGNORECASE) else "LOW"
70
  }
71
 
 
77
  "phishing_indicators": phishing_indicators
78
  }
79
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def preprocess(self, text: str) -> Dict:
81
+ """Preprocessing for analysis (NOT for model)"""
82
+ tokens = self.tokenize(text)
 
 
 
 
 
 
83
  tokens_no_stop = self.remove_stopwords(tokens)
 
 
84
  stemmed = self.stem(tokens_no_stop)
 
 
85
  lemmatized = self.lemmatize(tokens_no_stop)
 
 
86
  sentiment = self.sentiment_analysis(text)
87
 
88
  return {
89
  "original_text": text,
 
90
  "tokens": tokens,
91
  "tokens_without_stopwords": tokens_no_stop,
92
  "stemmed_tokens": stemmed,
93
  "lemmatized_tokens": lemmatized,
 
94
  "sentiment": sentiment,
95
  "token_count": len(tokens_no_stop)
96
  }
 
132
  # HELPER FUNCTIONS
133
  # ============================================================================
134
  def _normalize_label_text_only(txt: str) -> str:
135
+ """Normalize model label text"""
 
 
 
136
  t = (str(txt) if txt is not None else "").strip().upper()
137
  if t in ("PHISHING", "PHISH", "SPAM"):
138
  return "PHISH"
 
142
 
143
 
144
  def _load_model():
145
+ """Load model, tokenizer, and preprocessor"""
146
  global _tokenizer, _model, _device, _NORM_LABELS_BY_IDX, _preprocessor
147
 
148
  if _tokenizer is None or _model is None:
 
152
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
153
  _model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
154
  _model.to(_device)
155
+ _model.eval()
156
  _preprocessor = TextPreprocessor()
157
 
158
+ # Warm-up
159
  with torch.no_grad():
160
  _ = _model(
161
  **_tokenizer(["warm up"], return_tensors="pt", padding=True, truncation=True, max_length=512)
162
  .to(_device)
163
  ).logits
164
 
165
+ # Read and normalize model labels
166
  id2label = getattr(_model.config, "id2label", {}) or {}
167
  num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
168
  _NORM_LABELS_BY_IDX = [_normalize_label_text_only(id2label.get(i, f"LABEL_{i}")) for i in range(num_labels)]
169
 
170
+ print(f"Model loaded successfully")
171
+ print(f"ID2Label: {id2label}")
172
  print(f"Normalized labels: {_NORM_LABELS_BY_IDX}")
173
 
174
 
175
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
176
  """
177
+ Predict using ORIGINAL text (NO cleaning).
178
+ Preprocessing is for ANALYSIS only, not for model input.
179
  """
180
  _load_model()
181
  if not texts:
182
  return []
183
 
184
+ # IMPORTANT: Use original text for model, NOT cleaned text!
185
+ model_inputs = texts
186
+
187
+ # Get preprocessing info for analysis
188
  preprocessing_info = None
189
  if include_preprocessing:
190
  preprocessing_info = [_preprocessor.preprocess(text) for text in texts]
 
 
 
 
191
 
192
+ # Tokenize batch for model
193
  enc = _tokenizer(
194
  model_inputs,
195
  return_tensors="pt",
 
199
  )
200
  enc = {k: v.to(_device) for k, v in enc.items()}
201
 
202
+ # Get predictions
203
  with torch.no_grad():
204
  logits = _model(**enc).logits
205
+ probs = torch.softmax(logits, dim=-1)
206
 
 
207
  id2label = getattr(_model.config, "id2label", None) or {}
208
  labels_by_idx_raw = [id2label.get(i, f"LABEL_{i}") for i in range(probs.shape[-1])]
209
  labels_by_idx_norm = [_normalize_label_text_only(lbl) for lbl in labels_by_idx_raw]
 
216
  raw_label = labels_by_idx_raw[idx]
217
  norm_label = labels_by_idx_norm[idx]
218
 
 
219
  prob_map: Dict[str, float] = {}
220
  for j, lbl_norm in enumerate(labels_by_idx_norm):
221
  key = lbl_norm if lbl_norm in ("PHISH", "LEGIT") else f"CLASS_{j}"
 
226
  "raw_label": raw_label,
227
  "is_phish": True if norm_label == "PHISH" else False,
228
  "score": round(float(p[idx].item()), 4),
229
+ "confidence": round(float(p[idx].item()), 4),
230
  "probs": {k: round(v, 4) for k, v in prob_map.items()},
231
  "predicted_index": idx,
 
 
232
  }
233
 
 
234
  if include_preprocessing and preprocessing_info:
235
  output["preprocessing"] = preprocessing_info[i]
236
 
 
245
 
246
  @app.get("/")
247
  def root():
248
+ """Root endpoint"""
249
  _load_model()
250
  return {
251
  "status": "ok",
252
  "model": MODEL_ID,
253
  "device": _device,
254
+ "note": "Model uses ORIGINAL text for predictions. Preprocessing is for analysis only.",
 
 
 
 
 
 
 
255
  }
256
 
257
 
258
  @app.get("/debug/labels")
259
  def debug_labels():
260
+ """View model configuration"""
261
  _load_model()
262
  return {
263
  "id2label": getattr(_model.config, "id2label", {}),
 
270
 
271
  @app.post("/debug/preprocessing")
272
  def debug_preprocessing(payload: PredictPayload):
273
+ """Debug preprocessing output"""
274
  try:
275
  _load_model()
276
  preprocessing = _preprocessor.preprocess(payload.inputs)
 
284
 
285
  @app.post("/predict")
286
  def predict(payload: PredictPayload):
287
+ """Single prediction"""
288
  try:
289
  res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
290
  return res[0]
 
294
 
295
  @app.post("/predict-batch")
296
  def predict_batch(payload: BatchPredictPayload):
297
+ """Batch predictions"""
298
  try:
299
  return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
300
  except Exception as e:
 
303
 
304
  @app.post("/evaluate")
305
  def evaluate(payload: EvalPayload):
306
+ """Evaluate on labeled samples"""
 
 
 
307
  try:
308
  texts = [s.text for s in payload.samples]
309
  gts = [(_normalize_label_text_only(s.label) if s.label is not None else None) for s in payload.samples]
 
337
 
338
 
339
  if __name__ == "__main__":
 
340
  import uvicorn
341
  uvicorn.run(app, host="0.0.0.0", port=8000)