Perth0603 commited on
Commit
b418015
·
verified ·
1 Parent(s): e17ff4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -90
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  from typing import List, Optional, Dict
3
  import re
4
- import json
5
 
6
  import torch
7
  import nltk
@@ -21,11 +20,8 @@ except LookupError:
21
  nltk.download('stopwords')
22
  nltk.download('wordnet')
23
 
24
- MODEL_ID = (
25
- os.environ.get("MODEL_ID")
26
- or os.environ.get("HF_MODEL_ID")
27
- or "Perth0603/phishing-email-mobilebert"
28
- )
29
 
30
  app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
31
 
@@ -126,69 +122,11 @@ _tokenizer = None
126
  _model = None
127
  _device = "cpu"
128
  _preprocessor = None
129
- _LABEL_MAPPING = None
130
 
131
 
132
  # ============================================================================
133
  # HELPER FUNCTIONS
134
  # ============================================================================
135
- def _load_labels_from_hf():
136
- """Try to load labels.json from HuggingFace model repo"""
137
- try:
138
- from huggingface_hub import hf_hub_download
139
- labels_file = hf_hub_download(repo_id=MODEL_ID, filename="labels.json")
140
- with open(labels_file, 'r') as f:
141
- labels_data = json.load(f)
142
- return labels_data.get("id2label", {})
143
- except Exception as e:
144
- print(f"[WARNING] Could not load labels.json from HF: {e}")
145
- return None
146
-
147
-
148
- def _get_label_mapping():
149
- """Get complete label mapping with multiple fallback strategies"""
150
- global _model
151
-
152
- if _model is None:
153
- return None
154
-
155
- # Strategy 1: Try model config
156
- id2label = getattr(_model.config, "id2label", {}) or {}
157
- num_labels = int(getattr(_model.config, "num_labels", 2) or 2)
158
-
159
- print(f"[DEBUG] Model config id2label: {id2label}")
160
- print(f"[DEBUG] Model config num_labels: {num_labels}")
161
-
162
- # Strategy 2: If incomplete, try labels.json from HuggingFace
163
- if len(id2label) < num_labels:
164
- print(f"[WARNING] Incomplete id2label in config! Trying labels.json...")
165
- hf_labels = _load_labels_from_hf()
166
- if hf_labels and len(hf_labels) >= num_labels:
167
- id2label = hf_labels
168
- print(f"[SUCCESS] Loaded labels from labels.json: {id2label}")
169
-
170
- # Strategy 3: Convert string keys to int keys
171
- complete_mapping = {}
172
- for i in range(num_labels):
173
- if str(i) in id2label:
174
- complete_mapping[i] = id2label[str(i)]
175
- elif i in id2label:
176
- complete_mapping[i] = id2label[i]
177
- else:
178
- complete_mapping[i] = f"LABEL_{i}"
179
-
180
- # Strategy 4: Final fallback if still incomplete
181
- if len(complete_mapping) < num_labels or any(v.startswith("LABEL_") for v in complete_mapping.values()):
182
- print(f"[WARNING] Using hardcoded fallback mapping!")
183
- complete_mapping = {
184
- 0: "LEGIT",
185
- 1: "PHISH"
186
- }
187
-
188
- print(f"[FINAL] Applied label mapping: {complete_mapping}")
189
- return complete_mapping
190
-
191
-
192
  def _normalize_label(txt: str) -> str:
193
  """Normalize label text"""
194
  t = (str(txt) if txt is not None else "").strip().upper()
@@ -201,7 +139,7 @@ def _normalize_label(txt: str) -> str:
201
 
202
  def _load_model():
203
  """Load model, tokenizer, and preprocessor"""
204
- global _tokenizer, _model, _device, _preprocessor, _LABEL_MAPPING
205
 
206
  if _tokenizer is None or _model is None:
207
  _device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -216,9 +154,6 @@ def _load_model():
216
  _model.eval()
217
  _preprocessor = TextPreprocessor()
218
 
219
- # Get label mapping with fallbacks
220
- _LABEL_MAPPING = _get_label_mapping()
221
-
222
  # Warm-up
223
  with torch.no_grad():
224
  _ = _model(
@@ -226,7 +161,10 @@ def _load_model():
226
  .to(_device)
227
  ).logits
228
 
229
- print(f"Model loaded successfully!\n{'='*60}\n")
 
 
 
230
 
231
 
232
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
@@ -255,34 +193,33 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
255
  logits = _model(**enc).logits
256
  probs = torch.softmax(logits, dim=-1)
257
 
258
- num_labels = probs.shape[-1]
 
259
 
260
  outputs: List[Dict] = []
261
  for text_idx in range(probs.shape[0]):
262
  p = probs[text_idx]
263
 
264
- # Build probability breakdown
265
- prob_breakdown = {}
266
- for class_idx in range(num_labels):
267
- class_label = _LABEL_MAPPING.get(class_idx, f"CLASS_{class_idx}")
268
- class_prob = float(p[class_idx].item())
269
- prob_breakdown[class_label] = round(class_prob, 4)
270
-
271
  # Get prediction
272
  predicted_idx = int(torch.argmax(p).item())
273
- predicted_label_raw = _LABEL_MAPPING.get(predicted_idx, f"CLASS_{predicted_idx}")
274
  predicted_label_norm = _normalize_label(predicted_label_raw)
275
  predicted_prob = float(p[predicted_idx].item())
276
 
 
 
 
 
 
 
277
  output = {
278
  "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
279
- "predicted_class_index": predicted_idx,
280
  "label": predicted_label_norm,
281
  "raw_label": predicted_label_raw,
282
  "is_phish": predicted_label_norm == "PHISH",
283
- "score": round(predicted_prob, 4),
284
  "confidence": round(predicted_prob * 100, 2),
285
- "probs_by_class": prob_breakdown,
 
286
  }
287
 
288
  if include_preprocessing and preprocessing_info:
@@ -305,7 +242,6 @@ def root():
305
  "status": "ok",
306
  "model": MODEL_ID,
307
  "device": _device,
308
- "label_mapping": _LABEL_MAPPING,
309
  }
310
 
311
 
@@ -314,16 +250,12 @@ def debug_labels():
314
  """View model configuration"""
315
  _load_model()
316
 
317
- id2label_raw = getattr(_model.config, "id2label", {}) or {}
318
- label2id_raw = getattr(_model.config, "label2id", {}) or {}
319
- num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
320
-
321
  return {
322
  "status": "ok",
323
- "model_config_id2label": id2label_raw,
324
- "model_config_label2id": label2id_raw,
325
- "model_config_num_labels": num_labels,
326
- "applied_mapping": _LABEL_MAPPING,
327
  "device": _device,
328
  }
329
 
 
1
  import os
2
  from typing import List, Optional, Dict
3
  import re
 
4
 
5
  import torch
6
  import nltk
 
20
  nltk.download('stopwords')
21
  nltk.download('wordnet')
22
 
23
+ # CHANGE THIS TO POINT TO YOUR MODEL REPOSITORY
24
+ MODEL_ID = "Perth0603/phishing-email-mobilebert" # ← Your model storage repo
 
 
 
25
 
26
  app = FastAPI(title="Phishing Text Classifier with Preprocessing", version="1.0.0")
27
 
 
122
  _model = None
123
  _device = "cpu"
124
  _preprocessor = None
 
125
 
126
 
127
  # ============================================================================
128
  # HELPER FUNCTIONS
129
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  def _normalize_label(txt: str) -> str:
131
  """Normalize label text"""
132
  t = (str(txt) if txt is not None else "").strip().upper()
 
139
 
140
  def _load_model():
141
  """Load model, tokenizer, and preprocessor"""
142
+ global _tokenizer, _model, _device, _preprocessor
143
 
144
  if _tokenizer is None or _model is None:
145
  _device = "cuda" if torch.cuda.is_available() else "cpu"
 
154
  _model.eval()
155
  _preprocessor = TextPreprocessor()
156
 
 
 
 
157
  # Warm-up
158
  with torch.no_grad():
159
  _ = _model(
 
161
  .to(_device)
162
  ).logits
163
 
164
+ # Check label mapping
165
+ id2label = getattr(_model.config, "id2label", {})
166
+ print(f"Model labels: {id2label}")
167
+ print(f"{'='*60}\n")
168
 
169
 
170
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
 
193
  logits = _model(**enc).logits
194
  probs = torch.softmax(logits, dim=-1)
195
 
196
+ # Get labels from model config
197
+ id2label = getattr(_model.config, "id2label", {0: "LEGIT", 1: "PHISH"})
198
 
199
  outputs: List[Dict] = []
200
  for text_idx in range(probs.shape[0]):
201
  p = probs[text_idx]
202
 
 
 
 
 
 
 
 
203
  # Get prediction
204
  predicted_idx = int(torch.argmax(p).item())
205
+ predicted_label_raw = id2label.get(predicted_idx, f"CLASS_{predicted_idx}")
206
  predicted_label_norm = _normalize_label(predicted_label_raw)
207
  predicted_prob = float(p[predicted_idx].item())
208
 
209
+ # Build probability breakdown
210
+ prob_breakdown = {}
211
+ for i in range(len(p)):
212
+ label = _normalize_label(id2label.get(i, f"CLASS_{i}"))
213
+ prob_breakdown[label] = round(float(p[i].item()), 4)
214
+
215
  output = {
216
  "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
 
217
  "label": predicted_label_norm,
218
  "raw_label": predicted_label_raw,
219
  "is_phish": predicted_label_norm == "PHISH",
 
220
  "confidence": round(predicted_prob * 100, 2),
221
+ "score": round(predicted_prob, 4),
222
+ "probs": prob_breakdown,
223
  }
224
 
225
  if include_preprocessing and preprocessing_info:
 
242
  "status": "ok",
243
  "model": MODEL_ID,
244
  "device": _device,
 
245
  }
246
 
247
 
 
250
  """View model configuration"""
251
  _load_model()
252
 
 
 
 
 
253
  return {
254
  "status": "ok",
255
+ "model_id": MODEL_ID,
256
+ "id2label": getattr(_model.config, "id2label", {}),
257
+ "label2id": getattr(_model.config, "label2id", {}),
258
+ "num_labels": int(getattr(_model.config, "num_labels", 0)),
259
  "device": _device,
260
  }
261