Perth0603 committed on
Commit
1c170d1
·
verified ·
1 Parent(s): 4e51678

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -41
app.py CHANGED
@@ -132,20 +132,42 @@ _LABEL_MAPPING = None
132
  # ============================================================================
133
  # HELPER FUNCTIONS
134
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def _get_label_mapping():
136
- """Get complete label mapping from model config"""
137
  global _model
138
 
139
  if _model is None:
140
  return None
141
 
 
142
  id2label = getattr(_model.config, "id2label", {}) or {}
143
- num_labels = int(getattr(_model.config, "num_labels", 0) or 0)
 
 
 
144
 
145
- print(f"[DEBUG] Raw id2label from config: {id2label}")
146
- print(f"[DEBUG] num_labels: {num_labels}")
 
 
 
 
 
147
 
148
- # Build complete mapping by index
149
  complete_mapping = {}
150
  for i in range(num_labels):
151
  if str(i) in id2label:
@@ -155,15 +177,15 @@ def _get_label_mapping():
155
  else:
156
  complete_mapping[i] = f"LABEL_{i}"
157
 
158
- # If incomplete, use fallback
159
- if len(complete_mapping) < num_labels:
160
- print(f"[WARNING] Incomplete mapping! Using fallback.")
161
  complete_mapping = {
162
  0: "LEGIT",
163
  1: "PHISH"
164
  }
165
 
166
- print(f"[DEBUG] Complete mapping applied: {complete_mapping}")
167
  return complete_mapping
168
 
169
 
@@ -184,8 +206,8 @@ def _load_model():
184
  if _tokenizer is None or _model is None:
185
  _device = "cuda" if torch.cuda.is_available() else "cpu"
186
  print(f"\n{'='*60}")
187
- print(f"Loading model on device: {_device}")
188
- print(f"Model ID: {MODEL_ID}")
189
  print(f"{'='*60}\n")
190
 
191
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@@ -194,7 +216,7 @@ def _load_model():
194
  _model.eval()
195
  _preprocessor = TextPreprocessor()
196
 
197
- # Get label mapping
198
  _LABEL_MAPPING = _get_label_mapping()
199
 
200
  # Warm-up
@@ -204,14 +226,11 @@ def _load_model():
204
  .to(_device)
205
  ).logits
206
 
207
- print(f"{'='*60}\n")
208
 
209
 
210
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
211
- """
212
- Predict with correct label index mapping
213
- CRITICAL: probs[i][j] where j is the CLASS INDEX, not probability value
214
- """
215
  _load_model()
216
  if not texts:
217
  return []
@@ -237,31 +256,23 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
237
  probs = torch.softmax(logits, dim=-1)
238
 
239
  num_labels = probs.shape[-1]
240
- print(f"\n[DEBUG] num_labels from probs shape: {num_labels}")
241
 
242
  outputs: List[Dict] = []
243
  for text_idx in range(probs.shape[0]):
244
- p = probs[text_idx] # Get probabilities for this text: shape [num_labels]
245
 
246
- # Create probability breakdown for ALL classes
247
  prob_breakdown = {}
248
- all_probs_list = []
249
-
250
  for class_idx in range(num_labels):
251
- class_prob = float(p[class_idx].item())
252
  class_label = _LABEL_MAPPING.get(class_idx, f"CLASS_{class_idx}")
 
253
  prob_breakdown[class_label] = round(class_prob, 4)
254
- all_probs_list.append(class_prob)
255
- print(f"[DEBUG] Class {class_idx} ({class_label}): {round(class_prob, 4)}")
256
 
257
- # Get argmax index
258
  predicted_idx = int(torch.argmax(p).item())
259
  predicted_label_raw = _LABEL_MAPPING.get(predicted_idx, f"CLASS_{predicted_idx}")
260
  predicted_label_norm = _normalize_label(predicted_label_raw)
261
  predicted_prob = float(p[predicted_idx].item())
262
-
263
- print(f"[DEBUG] ARGMAX: index={predicted_idx}, label={predicted_label_raw}, prob={round(predicted_prob, 4)}")
264
- print(f"[DEBUG] Normalized label: {predicted_label_norm}")
265
 
266
  output = {
267
  "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
@@ -272,14 +283,12 @@ def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List
272
  "score": round(predicted_prob, 4),
273
  "confidence": round(predicted_prob * 100, 2),
274
  "probs_by_class": prob_breakdown,
275
- "all_probs_raw": [round(p_val, 4) for p_val in all_probs_list],
276
  }
277
 
278
  if include_preprocessing and preprocessing_info:
279
  output["preprocessing"] = preprocessing_info[text_idx]
280
 
281
  outputs.append(output)
282
- print(f"\n")
283
 
284
  return outputs
285
 
@@ -302,7 +311,7 @@ def root():
302
 
303
  @app.get("/debug/labels")
304
  def debug_labels():
305
- """View complete model configuration"""
306
  _load_model()
307
 
308
  id2label_raw = getattr(_model.config, "id2label", {}) or {}
@@ -316,7 +325,6 @@ def debug_labels():
316
  "model_config_num_labels": num_labels,
317
  "applied_mapping": _LABEL_MAPPING,
318
  "device": _device,
319
- "note": "applied_mapping is what gets used for predictions"
320
  }
321
 
322
 
@@ -326,12 +334,9 @@ def debug_preprocessing(payload: PredictPayload):
326
  try:
327
  _load_model()
328
  preprocessing = _preprocessor.preprocess(payload.inputs)
329
- return {
330
- "status": "ok",
331
- "preprocessing": preprocessing
332
- }
333
  except Exception as e:
334
- raise HTTPException(status_code=500, detail=f"Error: {e}")
335
 
336
 
337
  @app.post("/predict")
@@ -341,7 +346,7 @@ def predict(payload: PredictPayload):
341
  res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
342
  return res[0]
343
  except Exception as e:
344
- raise HTTPException(status_code=500, detail=f"Error: {e}")
345
 
346
 
347
  @app.post("/predict-batch")
@@ -350,7 +355,7 @@ def predict_batch(payload: BatchPredictPayload):
350
  try:
351
  return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
352
  except Exception as e:
353
- raise HTTPException(status_code=500, detail=f"Error: {e}")
354
 
355
 
356
  @app.post("/evaluate")
@@ -385,7 +390,7 @@ def evaluate(payload: EvalPayload):
385
  "per_class": per_class,
386
  }
387
  except Exception as e:
388
- raise HTTPException(status_code=500, detail=f"Error: {e}")
389
 
390
 
391
  if __name__ == "__main__":
 
132
  # ============================================================================
133
  # HELPER FUNCTIONS
134
  # ============================================================================
135
def _load_labels_from_hf():
    """Try to load labels.json from the HuggingFace model repo.

    Downloads ``labels.json`` from the repo identified by the module-level
    ``MODEL_ID`` and extracts its ``id2label`` mapping.

    Returns:
        dict: the ``id2label`` mapping from labels.json (``{}`` if the file
        exists but has no ``id2label`` key), or ``None`` if the download or
        parse fails for any reason (best-effort: errors are logged, not raised).
    """
    try:
        # Imported lazily so the app still starts if huggingface_hub is absent;
        # the broad except below then falls through to the caller's fallbacks.
        from huggingface_hub import hf_hub_download
        labels_file = hf_hub_download(repo_id=MODEL_ID, filename="labels.json")
        # labels.json is UTF-8 JSON; be explicit so the platform-default
        # encoding (e.g. cp1252 on Windows) can't corrupt the read.
        with open(labels_file, encoding="utf-8") as f:
            labels_data = json.load(f)
        return labels_data.get("id2label", {})
    except Exception as e:
        # Deliberate best-effort: callers treat None as "no labels available".
        print(f"[WARNING] Could not load labels.json from HF: {e}")
        return None
+
147
+
148
  def _get_label_mapping():
149
+ """Get complete label mapping with multiple fallback strategies"""
150
  global _model
151
 
152
  if _model is None:
153
  return None
154
 
155
+ # Strategy 1: Try model config
156
  id2label = getattr(_model.config, "id2label", {}) or {}
157
+ num_labels = int(getattr(_model.config, "num_labels", 2) or 2)
158
+
159
+ print(f"[DEBUG] Model config id2label: {id2label}")
160
+ print(f"[DEBUG] Model config num_labels: {num_labels}")
161
 
162
+ # Strategy 2: If incomplete, try labels.json from HuggingFace
163
+ if len(id2label) < num_labels:
164
+ print(f"[WARNING] Incomplete id2label in config! Trying labels.json...")
165
+ hf_labels = _load_labels_from_hf()
166
+ if hf_labels and len(hf_labels) >= num_labels:
167
+ id2label = hf_labels
168
+ print(f"[SUCCESS] Loaded labels from labels.json: {id2label}")
169
 
170
+ # Strategy 3: Convert string keys to int keys
171
  complete_mapping = {}
172
  for i in range(num_labels):
173
  if str(i) in id2label:
 
177
  else:
178
  complete_mapping[i] = f"LABEL_{i}"
179
 
180
+ # Strategy 4: Final fallback if still incomplete
181
+ if len(complete_mapping) < num_labels or any(v.startswith("LABEL_") for v in complete_mapping.values()):
182
+ print(f"[WARNING] Using hardcoded fallback mapping!")
183
  complete_mapping = {
184
  0: "LEGIT",
185
  1: "PHISH"
186
  }
187
 
188
+ print(f"[FINAL] Applied label mapping: {complete_mapping}")
189
  return complete_mapping
190
 
191
 
 
206
  if _tokenizer is None or _model is None:
207
  _device = "cuda" if torch.cuda.is_available() else "cpu"
208
  print(f"\n{'='*60}")
209
+ print(f"Loading model: {MODEL_ID}")
210
+ print(f"Device: {_device}")
211
  print(f"{'='*60}\n")
212
 
213
  _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
216
  _model.eval()
217
  _preprocessor = TextPreprocessor()
218
 
219
+ # Get label mapping with fallbacks
220
  _LABEL_MAPPING = _get_label_mapping()
221
 
222
  # Warm-up
 
226
  .to(_device)
227
  ).logits
228
 
229
+ print(f"Model loaded successfully!\n{'='*60}\n")
230
 
231
 
232
  def _predict_texts(texts: List[str], include_preprocessing: bool = True) -> List[Dict]:
233
+ """Predict with correct label mapping"""
 
 
 
234
  _load_model()
235
  if not texts:
236
  return []
 
256
  probs = torch.softmax(logits, dim=-1)
257
 
258
  num_labels = probs.shape[-1]
 
259
 
260
  outputs: List[Dict] = []
261
  for text_idx in range(probs.shape[0]):
262
+ p = probs[text_idx]
263
 
264
+ # Build probability breakdown
265
  prob_breakdown = {}
 
 
266
  for class_idx in range(num_labels):
 
267
  class_label = _LABEL_MAPPING.get(class_idx, f"CLASS_{class_idx}")
268
+ class_prob = float(p[class_idx].item())
269
  prob_breakdown[class_label] = round(class_prob, 4)
 
 
270
 
271
+ # Get prediction
272
  predicted_idx = int(torch.argmax(p).item())
273
  predicted_label_raw = _LABEL_MAPPING.get(predicted_idx, f"CLASS_{predicted_idx}")
274
  predicted_label_norm = _normalize_label(predicted_label_raw)
275
  predicted_prob = float(p[predicted_idx].item())
 
 
 
276
 
277
  output = {
278
  "text": texts[text_idx][:100] + "..." if len(texts[text_idx]) > 100 else texts[text_idx],
 
283
  "score": round(predicted_prob, 4),
284
  "confidence": round(predicted_prob * 100, 2),
285
  "probs_by_class": prob_breakdown,
 
286
  }
287
 
288
  if include_preprocessing and preprocessing_info:
289
  output["preprocessing"] = preprocessing_info[text_idx]
290
 
291
  outputs.append(output)
 
292
 
293
  return outputs
294
 
 
311
 
312
  @app.get("/debug/labels")
313
  def debug_labels():
314
+ """View model configuration"""
315
  _load_model()
316
 
317
  id2label_raw = getattr(_model.config, "id2label", {}) or {}
 
325
  "model_config_num_labels": num_labels,
326
  "applied_mapping": _LABEL_MAPPING,
327
  "device": _device,
 
328
  }
329
 
330
 
 
334
  try:
335
  _load_model()
336
  preprocessing = _preprocessor.preprocess(payload.inputs)
337
+ return preprocessing
 
 
 
338
  except Exception as e:
339
+ raise HTTPException(status_code=500, detail=str(e))
340
 
341
 
342
  @app.post("/predict")
 
346
  res = _predict_texts([payload.inputs], include_preprocessing=payload.include_preprocessing)
347
  return res[0]
348
  except Exception as e:
349
+ raise HTTPException(status_code=500, detail=str(e))
350
 
351
 
352
  @app.post("/predict-batch")
 
355
  try:
356
  return _predict_texts(payload.inputs, include_preprocessing=payload.include_preprocessing)
357
  except Exception as e:
358
+ raise HTTPException(status_code=500, detail=str(e))
359
 
360
 
361
  @app.post("/evaluate")
 
390
  "per_class": per_class,
391
  }
392
  except Exception as e:
393
+ raise HTTPException(status_code=500, detail=str(e))
394
 
395
 
396
  if __name__ == "__main__":