Perth0603 committed on
Commit
7e1eb79
·
verified ·
1 Parent(s): 6823e29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -85
app.py CHANGED
@@ -126,59 +126,11 @@ _tokenizer = None
126
  _model = None
127
  _device = "cpu"
128
  _preprocessor = None
129
- _LABEL_MAPPING = None
130
 
131
 
132
  # ============================================================================
133
  # HELPER FUNCTIONS
134
  # ============================================================================
135
- def _get_label_mapping():
136
- """
137
- Get complete label mapping.
138
- If model config is incomplete, use fallback mapping.
139
- """
140
- global _model, _LABEL_MAPPING
141
-
142
- if _model is None:
143
- return None
144
-
145
- id2label = getattr(_model.config, "id2label", {}) or {}
146
-
147
- # Check if mapping is incomplete (missing label 0)
148
- num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
149
-
150
- print(f"DEBUG: num_labels = {num_labels}")
151
- print(f"DEBUG: id2label from config = {id2label}")
152
-
153
- # If incomplete, use fallback
154
- if len(id2label) < num_labels:
155
- print(f"WARNING: Incomplete label mapping detected!")
156
- print(f"Expected {num_labels} labels, got {len(id2label)}")
157
-
158
- # Try to load from labels.json if available
159
- try:
160
- import pkg_resources
161
- model_path = pkg_resources.resource_filename(__name__, 'models')
162
- labels_path = os.path.join(model_path, 'labels.json')
163
- if os.path.exists(labels_path):
164
- with open(labels_path, 'r') as f:
165
- labels_data = json.load(f)
166
- id2label = labels_data.get("id2label", {})
167
- print(f"Loaded labels from labels.json: {id2label}")
168
- except:
169
- pass
170
-
171
- # Final fallback mapping
172
- if len(id2label) < 2:
173
- print("Using fallback label mapping: 0=LEGIT, 1=PHISH")
174
- id2label = {
175
- "0": "LEGIT",
176
- "1": "PHISH"
177
- }
178
-
179
- return id2label
180
-
181
-
182
  def _normalize_label(txt: str) -> str:
183
  """Normalize label text"""
184
  t = (str(txt) if txt is not None else "").strip().upper()
@@ -191,7 +143,7 @@ def _normalize_label(txt: str) -> str:
191
 
192
  def _load_model():
193
  """Load model, tokenizer, and preprocessor"""
194
- global _tokenizer, _model, _device, _preprocessor, _LABEL_MAPPING
195
 
196
  if _tokenizer is None or _model is None:
197
  _device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -206,9 +158,6 @@ def _load_model():
206
  _model.eval()
207
  _preprocessor = TextPreprocessor()
208
 
209
- # Get label mapping
210
- _LABEL_MAPPING = _get_label_mapping()
211
-
212
  # Warm-up
213
  with torch.no_grad():
214
  _ = _model(
@@ -217,14 +166,17 @@ def _load_model():
217
  ).logits
218
 
219
  num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
 
 
220
  print(f"Number of labels: {num_labels}")
221
- print(f"Label mapping: {_LABEL_MAPPING}")
222
  print(f"{'='*60}\n")
223
 
224
 
225
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
226
  """
227
- Predict with corrected label mapping
 
228
  """
229
  _load_model()
230
  if not texts:
@@ -250,39 +202,43 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
250
  logits = _model(**enc).logits
251
  probs = torch.softmax(logits, dim=-1)
252
 
253
- # Build label list from mapping
254
- num_labels = probs.shape[-1]
255
- labels_by_idx = []
256
- for i in range(num_labels):
257
- label = _LABEL_MAPPING.get(str(i), f"LABEL_{i}")
258
- labels_by_idx.append(label)
259
-
260
- print(f"DEBUG: Using labels: {labels_by_idx}")
261
 
262
  outputs: List[Dict] = []
263
  for i in range(probs.shape[0]):
264
  p = probs[i]
265
- idx = int(torch.argmax(p).item())
266
-
267
- raw_label = labels_by_idx[idx]
268
- norm_label = _normalize_label(raw_label)
269
-
270
- # Build probability map
271
- prob_map: Dict[str, float] = {}
272
- for j in range(len(labels_by_idx)):
273
- label_norm = _normalize_label(labels_by_idx[j])
274
- prob_map[label_norm] = float(p[j].item())
 
 
 
 
275
 
276
  output = {
277
  "text": texts[i][:100] + "..." if len(texts[i]) > 100 else texts[i],
278
- "label": norm_label,
279
- "raw_label": raw_label,
280
- "is_phish": norm_label == "PHISH",
281
- "score": round(float(p[idx].item()), 4),
282
- "confidence": round(float(p[idx].item()), 4),
283
- "predicted_index": idx,
284
- "probs": {k: round(v, 4) for k, v in prob_map.items()},
285
- "all_probs_raw": [round(float(p[j].item()), 4) for j in range(len(labels_by_idx))],
 
 
 
 
286
  }
287
 
288
  if include_preprocessing and preprocessing_info:
@@ -305,13 +261,17 @@ def root():
305
  "status": "ok",
306
  "model": MODEL_ID,
307
  "device": _device,
308
- "label_mapping": _LABEL_MAPPING,
 
 
 
 
309
  }
310
 
311
 
312
  @app.get("/debug/labels")
313
  def debug_labels():
314
- """View complete model configuration"""
315
  _load_model()
316
 
317
  id2label_raw = getattr(_model.config, "id2label", {}) or {}
@@ -323,9 +283,11 @@ def debug_labels():
323
  "config_id2label": id2label_raw,
324
  "config_label2id": label2id_raw,
325
  "config_num_labels": num_labels,
326
- "applied_label_mapping": _LABEL_MAPPING,
327
- "device": _device,
328
- "note": "If config_id2label is incomplete, applied_label_mapping is used"
 
 
329
  }
330
 
331
 
 
126
  _model = None
127
  _device = "cpu"
128
  _preprocessor = None
 
129
 
130
 
131
  # ============================================================================
132
  # HELPER FUNCTIONS
133
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def _normalize_label(txt: str) -> str:
135
  """Normalize label text"""
136
  t = (str(txt) if txt is not None else "").strip().upper()
 
143
 
144
  def _load_model():
145
  """Load model, tokenizer, and preprocessor"""
146
+ global _tokenizer, _model, _device, _preprocessor
147
 
148
  if _tokenizer is None or _model is None:
149
  _device = "cuda" if torch.cuda.is_available() else "cpu"
 
158
  _model.eval()
159
  _preprocessor = TextPreprocessor()
160
 
 
 
 
161
  # Warm-up
162
  with torch.no_grad():
163
  _ = _model(
 
166
  ).logits
167
 
168
  num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
169
+ id2label = getattr(_model.config, "id2label", {}) or {}
170
+
171
  print(f"Number of labels: {num_labels}")
172
+ print(f"Label mapping: {id2label}")
173
  print(f"{'='*60}\n")
174
 
175
 
176
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
177
  """
178
+ Predict with CORRECT label indexing.
179
+ Index 0 = LEGIT, Index 1 = PHISH
180
  """
181
  _load_model()
182
  if not texts:
 
202
  logits = _model(**enc).logits
203
  probs = torch.softmax(logits, dim=-1)
204
 
205
+ # CORRECT LABEL MAPPING
206
+ # Index 0 = LEGIT (probs[i][0])
207
+ # Index 1 = PHISH (probs[i][1])
208
+ labels_by_idx = ["LEGIT", "PHISH"]
 
 
 
 
209
 
210
  outputs: List[Dict] = []
211
  for i in range(probs.shape[0]):
212
  p = probs[i]
213
+
214
+ # Get probabilities for each class
215
+ prob_legit = float(p[0].item())
216
+ prob_phish = float(p[1].item())
217
+
218
+ # Determine prediction based on which is higher
219
+ if prob_phish > prob_legit:
220
+ predicted_label = "PHISH"
221
+ predicted_idx = 1
222
+ confidence = prob_phish
223
+ else:
224
+ predicted_label = "LEGIT"
225
+ predicted_idx = 0
226
+ confidence = prob_legit
227
 
228
  output = {
229
  "text": texts[i][:100] + "..." if len(texts[i]) > 100 else texts[i],
230
+ "label": predicted_label,
231
+ "is_phish": predicted_label == "PHISH",
232
+ "confidence": round(confidence * 100, 2), # Convert to percentage
233
+ "predicted_index": predicted_idx,
234
+ "probs": {
235
+ "LEGIT": round(prob_legit * 100, 2),
236
+ "PHISH": round(prob_phish * 100, 2),
237
+ },
238
+ "raw_probs": {
239
+ "LEGIT (index 0)": round(prob_legit, 4),
240
+ "PHISH (index 1)": round(prob_phish, 4),
241
+ }
242
  }
243
 
244
  if include_preprocessing and preprocessing_info:
 
261
  "status": "ok",
262
  "model": MODEL_ID,
263
  "device": _device,
264
+ "label_mapping": {
265
+ "0": "LEGIT",
266
+ "1": "PHISH"
267
+ },
268
+ "note": "Index 0 = LEGIT (probability%), Index 1 = PHISH (probability%)"
269
  }
270
 
271
 
272
  @app.get("/debug/labels")
273
  def debug_labels():
274
+ """View model configuration"""
275
  _load_model()
276
 
277
  id2label_raw = getattr(_model.config, "id2label", {}) or {}
 
283
  "config_id2label": id2label_raw,
284
  "config_label2id": label2id_raw,
285
  "config_num_labels": num_labels,
286
+ "applied_mapping": {
287
+ "0": "LEGIT",
288
+ "1": "PHISH"
289
+ },
290
+ "device": _device
291
  }
292
 
293