Perth0603 committed on
Commit
847316c
·
verified ·
1 Parent(s): 7e1eb79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -54
app.py CHANGED
@@ -126,11 +126,47 @@ _tokenizer = None
126
  _model = None
127
  _device = "cpu"
128
  _preprocessor = None
 
129
 
130
 
131
  # ============================================================================
132
  # HELPER FUNCTIONS
133
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def _normalize_label(txt: str) -> str:
135
  """Normalize label text"""
136
  t = (str(txt) if txt is not None else "").strip().upper()
@@ -143,7 +179,7 @@ def _normalize_label(txt: str) -> str:
143
 
144
  def _load_model():
145
  """Load model, tokenizer, and preprocessor"""
146
- global _tokenizer, _model, _device, _preprocessor
147
 
148
  if _tokenizer is None or _model is None:
149
  _device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -158,6 +194,9 @@ def _load_model():
158
  _model.eval()
159
  _preprocessor = TextPreprocessor()
160
 
 
 
 
161
  # Warm-up
162
  with torch.no_grad():
163
  _ = _model(
@@ -165,18 +204,13 @@ def _load_model():
165
  .to(_device)
166
  ).logits
167
 
168
- num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
169
- id2label = getattr(_model.config, "id2label", {}) or {}
170
-
171
- print(f"Number of labels: {num_labels}")
172
- print(f"Label mapping: {id2label}")
173
  print(f"{'='*60}\n")
174
 
175
 
176
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
177
  """
178
- Predict with CORRECT label indexing.
179
- Index 0 = LEGIT, Index 1 = PHISH
180
  """
181
  _load_model()
182
  if not texts:
@@ -202,49 +236,50 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
202
  logits = _model(**enc).logits
203
  probs = torch.softmax(logits, dim=-1)
204
 
205
- # CORRECT LABEL MAPPING
206
- # Index 0 = LEGIT (probs[i][0])
207
- # Index 1 = PHISH (probs[i][1])
208
- labels_by_idx = ["LEGIT", "PHISH"]
209
 
210
  outputs: List[Dict] = []
211
- for i in range(probs.shape[0]):
212
- p = probs[i]
213
 
214
- # Get probabilities for each class
215
- prob_legit = float(p[0].item())
216
- prob_phish = float(p[1].item())
217
 
218
- # Determine prediction based on which is higher
219
- if prob_phish > prob_legit:
220
- predicted_label = "PHISH"
221
- predicted_idx = 1
222
- confidence = prob_phish
223
- else:
224
- predicted_label = "LEGIT"
225
- predicted_idx = 0
226
- confidence = prob_legit
 
 
 
 
 
 
227
 
228
  output = {
229
- "text": texts[i][:100] + "..." if len(texts[i]) > 100 else texts[i],
230
- "label": predicted_label,
231
- "is_phish": predicted_label == "PHISH",
232
- "confidence": round(confidence * 100, 2), # Convert to percentage
233
- "predicted_index": predicted_idx,
234
- "probs": {
235
- "LEGIT": round(prob_legit * 100, 2),
236
- "PHISH": round(prob_phish * 100, 2),
237
- },
238
- "raw_probs": {
239
- "LEGIT (index 0)": round(prob_legit, 4),
240
- "PHISH (index 1)": round(prob_phish, 4),
241
- }
242
  }
243
 
244
  if include_preprocessing and preprocessing_info:
245
- output["preprocessing"] = preprocessing_info[i]
246
 
247
  outputs.append(output)
 
248
 
249
  return outputs
250
 
@@ -261,17 +296,13 @@ def root():
261
  "status": "ok",
262
  "model": MODEL_ID,
263
  "device": _device,
264
- "label_mapping": {
265
- "0": "LEGIT",
266
- "1": "PHISH"
267
- },
268
- "note": "Index 0 = LEGIT (probability%), Index 1 = PHISH (probability%)"
269
  }
270
 
271
 
272
  @app.get("/debug/labels")
273
  def debug_labels():
274
- """View model configuration"""
275
  _load_model()
276
 
277
  id2label_raw = getattr(_model.config, "id2label", {}) or {}
@@ -280,14 +311,12 @@ def debug_labels():
280
 
281
  return {
282
  "status": "ok",
283
- "config_id2label": id2label_raw,
284
- "config_label2id": label2id_raw,
285
- "config_num_labels": num_labels,
286
- "applied_mapping": {
287
- "0": "LEGIT",
288
- "1": "PHISH"
289
- },
290
- "device": _device
291
  }
292
 
293
 
 
126
  _model = None
127
  _device = "cpu"
128
  _preprocessor = None
129
+ _LABEL_MAPPING = None
130
 
131
 
132
  # ============================================================================
133
  # HELPER FUNCTIONS
134
  # ============================================================================
135
+ def _get_label_mapping():
136
+ """Get complete label mapping from model config"""
137
+ global _model
138
+
139
+ if _model is None:
140
+ return None
141
+
142
+ id2label = getattr(_model.config, "id2label", {}) or {}
143
+ num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
144
+
145
+ print(f"[DEBUG] Raw id2label from config: {id2label}")
146
+ print(f"[DEBUG] num_labels: {num_labels}")
147
+
148
+ # Build complete mapping by index
149
+ complete_mapping = {}
150
+ for i in range(num_labels):
151
+ if str(i) in id2label:
152
+ complete_mapping[i] = id2label[str(i)]
153
+ elif i in id2label:
154
+ complete_mapping[i] = id2label[i]
155
+ else:
156
+ complete_mapping[i] = f"LABEL_{i}"
157
+
158
+ # If incomplete, use fallback
159
+ if len(complete_mapping) < num_labels:
160
+ print(f"[WARNING] Incomplete mapping! Using fallback.")
161
+ complete_mapping = {
162
+ 0: "LEGIT",
163
+ 1: "PHISH"
164
+ }
165
+
166
+ print(f"[DEBUG] Complete mapping applied: {complete_mapping}")
167
+ return complete_mapping
168
+
169
+
170
  def _normalize_label(txt: str) -> str:
171
  """Normalize label text"""
172
  t = (str(txt) if txt is not None else "").strip().upper()
 
179
 
180
  def _load_model():
181
  """Load model, tokenizer, and preprocessor"""
182
+ global _tokenizer, _model, _device, _preprocessor, _LABEL_MAPPING
183
 
184
  if _tokenizer is None or _model is None:
185
  _device = "cuda" if torch.cuda.is_available() else "cpu"
 
194
  _model.eval()
195
  _preprocessor = TextPreprocessor()
196
 
197
+ # Get label mapping
198
+ _LABEL_MAPPING = _get_label_mapping()
199
+
200
  # Warm-up
201
  with torch.no_grad():
202
  _ = _model(
 
204
  .to(_device)
205
  ).logits
206
 
 
 
 
 
 
207
  print(f"{'='*60}\n")
208
 
209
 
210
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
211
  """
212
+ Predict with correct label index mapping
213
+ CRITICAL: probs[i][j] where j is the CLASS INDEX, not probability value
214
  """
215
  _load_model()
216
  if not texts:
 
236
  logits = _model(**enc).logits
237
  probs = torch.softmax(logits, dim=-1)
238
 
239
+ num_labels = probs.shape[-1]
240
+ print(f"\n[DEBUG] num_labels from probs shape: {num_labels}")
 
 
241
 
242
  outputs: List[Dict] = []
243
+ for text_idx in range(probs.shape[0]):
244
+ p = probs[text_idx] # Get probabilities for this text: shape [num_labels]
245
 
246
+ # Create probability breakdown for ALL classes
247
+ prob_breakdown = {}
248
+ all_probs_list = []
249
 
250
+ for class_idx in range(num_labels):
251
+ class_prob = float(p[class_idx].item())
252
+ class_label = _LABEL_MAPPING.get(class_idx, f"CLASS_{class_idx}")
253
+ prob_breakdown[class_label] = round(class_prob, 4)
254
+ all_probs_list.append(class_prob)
255
+ print(f"[DEBUG] Class {class_idx} ({class_label}): {round(class_prob, 4)}")
256
+
257
+ # Get argmax index
258
+ predicted_idx = int(torch.argmax(p).item())
259
+ predicted_label_raw = _LABEL_MAPPING.get(predicted_idx, f"CLASS_{predicted_idx}")
260
+ predicted_label_norm = _normalize_label(predicted_label_raw)
261
+ predicted_prob = float(p[predicted_idx].item())
262
+
263
+ print(f"[DEBUG] ARGMAX: index={predicted_idx}, label={predicted_label_raw}, prob={round(predicted_prob, 4)}")
264
+ print(f"[DEBUG] Normalized label: {predicted_label_norm}")
265
 
266
  output = {
267
+ "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
268
+ "predicted_class_index": predicted_idx,
269
+ "label": predicted_label_norm,
270
+ "raw_label": predicted_label_raw,
271
+ "is_phish": predicted_label_norm == "PHISH",
272
+ "score": round(predicted_prob, 4),
273
+ "confidence": round(predicted_prob * 100, 2),
274
+ "probs_by_class": prob_breakdown,
275
+ "all_probs_raw": [round(p_val, 4) for p_val in all_probs_list],
 
 
 
 
276
  }
277
 
278
  if include_preprocessing and preprocessing_info:
279
+ output["preprocessing"] = preprocessing_info[text_idx]
280
 
281
  outputs.append(output)
282
+ print(f"\n")
283
 
284
  return outputs
285
 
 
296
  "status": "ok",
297
  "model": MODEL_ID,
298
  "device": _device,
299
+ "label_mapping": _LABEL_MAPPING,
 
 
 
 
300
  }
301
 
302
 
303
  @app.get("/debug/labels")
304
  def debug_labels():
305
+ """View complete model configuration"""
306
  _load_model()
307
 
308
  id2label_raw = getattr(_model.config, "id2label", {}) or {}
 
311
 
312
  return {
313
  "status": "ok",
314
+ "model_config_id2label": id2label_raw,
315
+ "model_config_label2id": label2id_raw,
316
+ "model_config_num_labels": num_labels,
317
+ "applied_mapping": _LABEL_MAPPING,
318
+ "device": _device,
319
+ "note": "applied_mapping is what gets used for predictions"
 
 
320
  }
321
 
322