Perth0603 commited on
Commit
820a438
·
verified ·
1 Parent(s): f1238c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -8
app.py CHANGED
@@ -41,6 +41,11 @@ def _load_model():
41
  _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
42
 
43
 
 
 
 
 
 
44
  @app.get("/")
45
  def root():
46
  _load_model()
@@ -77,18 +82,30 @@ def predict(payload: PredictPayload):
77
  logits = outputs.logits # [1, num_labels]
78
  logits_list = logits[0].tolist()
79
  pred_idx = int(torch.argmax(logits, dim=-1).item())
 
 
 
 
 
80
  except Exception as e:
81
  return JSONResponse(status_code=500, content={"error": str(e)})
82
 
83
- cfg = getattr(_model, "config", None)
84
- id2label = getattr(cfg, "id2label", {}) if cfg else {}
85
- # Try to fetch label from config without any remapping logic
86
- pred_label = id2label.get(pred_idx, id2label.get(str(pred_idx), None))
87
 
 
 
 
 
 
 
88
  return {
89
- "logits": logits_list, # direct model output (authoritative)
90
- "predicted_index": pred_idx, # argmax over logits
91
- "predicted_label": pred_label, # from model config if available
 
 
92
  "id2label": id2label,
93
- "label2id": getattr(cfg, "label2id", {}) if cfg else {},
94
  }
 
41
  _ = _model(**_tokenizer(["warm up"], return_tensors="pt")).logits
42
 
43
 
44
+ def _id2label():
45
+ cfg = getattr(_model, "config", None)
46
+ return getattr(cfg, "id2label", {}) if cfg else {}
47
+
48
+
49
  @app.get("/")
50
  def root():
51
  _load_model()
 
82
  logits = outputs.logits # [1, num_labels]
83
  logits_list = logits[0].tolist()
84
  pred_idx = int(torch.argmax(logits, dim=-1).item())
85
+
86
+ # Keep client-compatible fields but also provide raw outputs
87
+ probs_t = torch.softmax(logits, dim=-1)[0]
88
+ score = float(probs_t[pred_idx])
89
+
90
  except Exception as e:
91
  return JSONResponse(status_code=500, content={"error": str(e)})
92
 
93
+ id2label = _id2label()
94
+ # Resolve label from model config (support int or str keys)
95
+ pred_label = id2label.get(pred_idx, id2label.get(str(pred_idx), str(pred_idx)))
 
96
 
97
+ # Build per-label probabilities for debugging/verification
98
+ probs = {}
99
+ for i, p in enumerate(probs_t.tolist()):
100
+ probs[id2label.get(i, id2label.get(str(i), str(i)))] = float(p)
101
+
102
+ # Backward-compatible keys: "label" and "score"
103
  return {
104
+ "label": pred_label, # expected by your client
105
+ "score": score, # probability of predicted class (softmax)
106
+ "predicted_index": pred_idx, # raw argmax index from logits
107
+ "logits": logits_list, # raw model output
108
+ "probs": probs, # per-label probabilities
109
  "id2label": id2label,
110
+ "label2id": getattr(getattr(_model, "config", None), "label2id", {}),
111
  }