Add merged model + processor
Browse files- handler.py +14 -13
handler.py
CHANGED
|
@@ -94,16 +94,17 @@ class EndpointHandler:
|
|
| 94 |
pixel_values = self.transform(image).unsqueeze(0) # [1, C, H, W]
|
| 95 |
|
| 96 |
with torch.no_grad():
|
| 97 |
-
logits = self.model(pixel_values).logits
|
| 98 |
-
probs = logits.softmax(dim=-1)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
"
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
| 94 |
pixel_values = self.transform(image).unsqueeze(0) # [1, C, H, W]
|
| 95 |
|
| 96 |
with torch.no_grad():
|
| 97 |
+
logits = self.model(pixel_values).logits[0] # tensor [num_labels]
|
| 98 |
+
probs = logits.softmax(dim=-1)
|
| 99 |
+
|
| 100 |
+
# convert to the required wire format (top‑k or all classes)
|
| 101 |
+
k = min(5, probs.numel()) # send top‑5
|
| 102 |
+
topk = torch.topk(probs, k)
|
| 103 |
+
|
| 104 |
+
response = [
|
| 105 |
+
{"label": self.id2label[idx.item()], "score": prob.item()}
|
| 106 |
+
for prob, idx in zip(topk.values, topk.indices)
|
| 107 |
+
]
|
| 108 |
+
|
| 109 |
+
return response # <‑‑ must be a *list* of dicts
|
| 110 |
+
|