Perth0603 committed on
Commit
72eb3f5
·
verified ·
1 Parent(s): 113b42d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -1
app.py CHANGED
@@ -123,4 +123,122 @@ def _predict_texts(texts: List[str]) -> List[Dict]:
123
  norm_label = _normalize_label(raw_label)
124
 
125
  # Also expose per-label probabilities (normalized names where possible)
126
- prob_map = {_normalize_label(labels_by
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  norm_label = _normalize_label(raw_label)
124
 
125
  # Also expose per-label probabilities (normalized names where possible)
126
+ prob_map = {_normalize_label(labels_by_idx[j]): float(p[j].item()) for j in range(len(labels_by_idx))}
127
+
128
+ # Map to your dataset convention: PHISH=1, LEGIT=0
129
+ ds_label = None
130
+ if _IDX_PHISH is not None and _IDX_LEGIT is not None:
131
+ if idx == _IDX_PHISH:
132
+ ds_label = 1
133
+ elif idx == _IDX_LEGIT:
134
+ ds_label = 0
135
+
136
+ # Per-dataset-label probabilities when both indices are known
137
+ probs_by_dataset = None
138
+ if _IDX_PHISH is not None and _IDX_LEGIT is not None:
139
+ probs_by_dataset = {
140
+ "1": float(p[_IDX_PHISH].item()), # PHISH
141
+ "0": float(p[_IDX_LEGIT].item()), # LEGIT
142
+ }
143
+
144
+ outputs.append(
145
+ {
146
+ "label": norm_label, # normalized (e.g., PHISH/LEGIT)
147
+ "raw_label": raw_label, # from model.config.id2label
148
+ "score": float(p[idx].item()), # max class probability
149
+ "probs": prob_map, # dict of normalized label -> probability
150
+ "predicted_index": idx, # model argmax index
151
+ "predicted_dataset_label": ds_label, # 1 for PHISH, 0 for LEGIT (your convention)
152
+ "probs_by_dataset_label": probs_by_dataset,
153
+ }
154
+ )
155
+
156
+ return outputs
157
+
158
+
159
@app.get("/")
def root():
    """Health-check endpoint: report that the service is up and which model it serves."""
    return {"status": "ok", "model": MODEL_ID}
162
+
163
+
164
@app.get("/debug/labels")
def debug_labels():
    """Expose the loaded model's label configuration and module-level index state.

    Handy for verifying that the PHISH/LEGIT index detection agrees with the
    checkpoint's ``id2label``/``label2id`` mapping.
    """
    _load_model()
    # Hoist the config lookup; getattr guards cover checkpoints missing fields.
    cfg = _model.config
    info = {
        "id2label": getattr(cfg, "id2label", {}),
        "label2id": getattr(cfg, "label2id", {}),
        "num_labels": int(getattr(cfg, "num_labels", 0)),
    }
    info["device"] = _device
    info["norm_labels_by_idx"] = _NORM_LABELS_BY_IDX
    info["idx_phish"] = _IDX_PHISH
    info["idx_legit"] = _IDX_LEGIT
    return info
176
+
177
+
178
@app.post("/predict")
def predict(payload: PredictPayload):
    """Classify a single text and return its prediction record.

    Any internal failure (model load, inference, empty result) is surfaced
    as an HTTP 500 with the error detail.
    """
    try:
        # Index inside the try so an unexpectedly empty result also maps to 500.
        return _predict_texts([payload.inputs])[0]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Prediction error: {e}")
185
+
186
+
187
@app.post("/predict-batch")
def predict_batch(payload: BatchPredictPayload):
    """Classify every text in the batch; failures become an HTTP 500 with detail."""
    try:
        predictions = _predict_texts(payload.inputs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch prediction error: {e}")
    return predictions
193
+
194
+
195
@app.post("/evaluate")
def evaluate(payload: EvalPayload):
    """
    Quick on-the-spot test with provided labeled samples.

    Request body:
    {
      "samples": [
        {"text": "Your parcel is held...", "label": "PHISH"},  # or "1"
        {"text": "Lunch at 12?", "label": "LEGIT"}             # or "0"
      ]
    }

    Returns accuracy and per-class counts. Samples without a label are
    predicted but excluded from accuracy; accuracy is None when no ground
    truths were supplied at all.
    """
    try:
        texts = [sample.text for sample in payload.samples]
        truths = [
            _normalize_label(sample.label) if sample.label is not None else None
            for sample in payload.samples
        ]
        predictions = _predict_texts(texts)

        labeled = 0
        correct = 0
        per_class: Dict[str, Dict[str, int]] = {}

        for truth, prediction in zip(truths, predictions):
            # Guard clause: unlabeled samples contribute nothing to the stats.
            if truth is None:
                continue
            labeled += 1
            stats = per_class.setdefault(truth, {"tp": 0, "count": 0})
            stats["count"] += 1
            if truth == prediction["label"]:
                correct += 1
                stats["tp"] += 1

        return {
            "accuracy": (correct / labeled) if labeled else None,  # None if no ground truths
            "total": len(predictions),
            "predictions": predictions,
            "per_class": per_class,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Evaluation error: {e}")
239
+
240
+
241
if __name__ == "__main__":
    # Local development entry point; equivalent to:
    #   uvicorn app:app --host 0.0.0.0 --port 8000 --reload
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)