Upload inference.py
Browse files- inference.py +18 -0
inference.py
CHANGED
|
@@ -264,3 +264,21 @@ class ModelInference:
|
|
| 264 |
|
| 265 |
except Exception as e:
|
| 266 |
raise RuntimeError(f"Failed to extract class names: {e}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
except Exception as e:
|
| 266 |
raise RuntimeError(f"Failed to extract class names: {e}") from e
|
| 267 |
+
|
| 268 |
+
def get_tensor(self, crop: Image.Image):
|
| 269 |
+
"""Preprocess a crop into a numpy array for batch inference."""
|
| 270 |
+
img = np.array(crop)
|
| 271 |
+
img = cv2.resize(img, (self.img_size, self.img_size))
|
| 272 |
+
return img
|
| 273 |
+
|
| 274 |
+
def classify_batch(self, batch):
|
| 275 |
+
"""Run inference on a batch of preprocessed numpy arrays."""
|
| 276 |
+
probs = self.model.predict(batch, verbose=0)
|
| 277 |
+
results = []
|
| 278 |
+
for p in probs:
|
| 279 |
+
classifications = [
|
| 280 |
+
[self.class_ids[i], float(p[i])]
|
| 281 |
+
for i in range(len(self.class_ids))
|
| 282 |
+
]
|
| 283 |
+
results.append(classifications)
|
| 284 |
+
return results
|