Upload inference.py
Browse files- inference.py +29 -0
inference.py
CHANGED
|
@@ -215,3 +215,32 @@ class ModelInference:
|
|
| 215 |
return {
|
| 216 |
str(i + 1): name for i, name in enumerate(self.class_names)
|
| 217 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
return {
|
| 216 |
str(i + 1): name for i, name in enumerate(self.class_names)
|
| 217 |
}
|
| 218 |
+
|
| 219 |
+
def get_tensor(self, crop: Image.Image):
    """Convert a PIL crop into a float32 HWC numpy array in [0, 1].

    Steps: force RGB, resize to IMG_SIZE x IMG_SIZE without
    antialiasing, quantize through uint8, then scale back to [0, 1]
    floats with channels last — ready for batch inference.
    """
    # Normalize non-RGB inputs (e.g. RGBA, grayscale) first.
    rgb = crop if crop.mode == "RGB" else crop.convert("RGB")

    chw = TF.pil_to_tensor(rgb)
    chw = TF.convert_image_dtype(chw, torch.float32)
    chw = TF.resize(chw, [IMG_SIZE, IMG_SIZE], antialias=False)
    # NOTE(review): the uint8 round-trip quantizes values to 1/255
    # steps; kept as-is since it presumably matches the training
    # preprocessing — confirm before simplifying.
    chw = TF.convert_image_dtype(chw, torch.uint8)

    hwc = chw.permute(1, 2, 0)
    return hwc.numpy().astype("float32") / 255.0
+
|
| 232 |
+
def classify_batch(self, batch):
    """Run inference on a batch of preprocessed numpy arrays.

    Args:
        batch: numpy array of shape (N, ...) matching the model's
            expected input layout (e.g. as produced by get_tensor).

    Returns:
        A list of length N; each element is a list of
        [class_name, probability] pairs, one per entry in
        self.class_names, in class-index order.
    """
    tensor = torch.from_numpy(batch).to(self.device)
    # Inference only — no autograd graph needed.
    with torch.no_grad():
        logits = self.model(tensor)
    probs = F.softmax(logits, dim=1).cpu().numpy()

    # enumerate over class_names instead of range(len(...)): same
    # pairing and the same IndexError if the model emits fewer
    # classes than self.class_names lists.
    return [
        [[name, float(p[i])] for i, name in enumerate(self.class_names)]
        for p in probs
    ]