dchen0 commited on
Commit
2846bb6
·
verified ·
1 Parent(s): 69baf1e

Add merged model + processor

Browse files
Files changed (1) hide show
  1. handler.py +14 -13
handler.py CHANGED
@@ -94,16 +94,17 @@ class EndpointHandler:
94
  pixel_values = self.transform(image).unsqueeze(0) # [1, C, H, W]
95
 
96
  with torch.no_grad():
97
- logits = self.model(pixel_values).logits
98
- probs = logits.softmax(dim=-1)[0]
99
-
100
- top_idx = int(probs.argmax())
101
- top_label = self.id2label[top_idx]
102
-
103
- return {
104
- "predicted_label": top_label,
105
- "scores": {
106
- self.id2label[i]: float(p)
107
- for i, p in enumerate(probs)
108
- }
109
- }
 
 
94
  pixel_values = self.transform(image).unsqueeze(0) # [1, C, H, W]
95
 
96
  with torch.no_grad():
97
+ logits = self.model(pixel_values).logits[0] # tensor [num_labels]
98
+ probs = logits.softmax(dim=-1)
99
+
100
+ # convert to the required wire format (top‑k or all classes)
101
+ k = min(5, probs.numel()) # send top‑5
102
+ topk = torch.topk(probs, k)
103
+
104
+ response = [
105
+ {"label": self.id2label[idx.item()], "score": prob.item()}
106
+ for prob, idx in zip(topk.values, topk.indices)
107
+ ]
108
+
109
+ return response # <‑‑ must be a *list* of dicts
110
+