Upload inference.py
Browse files- inference.py +18 -0
inference.py
CHANGED
|
@@ -230,3 +230,21 @@ class ModelInference:
|
|
| 230 |
|
| 231 |
except Exception as e:
|
| 232 |
raise RuntimeError(f"Failed to extract class names: {e}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
except Exception as e:
|
| 232 |
raise RuntimeError(f"Failed to extract class names: {e}") from e
|
| 233 |
+
|
| 234 |
+
def get_tensor(self, crop: Image.Image):
|
| 235 |
+
"""Preprocess a crop into a numpy array for batch inference."""
|
| 236 |
+
img = np.array(crop)
|
| 237 |
+
img = cv2.resize(img, (self.img_size, self.img_size))
|
| 238 |
+
return img
|
| 239 |
+
|
| 240 |
+
def classify_batch(self, batch):
|
| 241 |
+
"""Run inference on a batch of preprocessed numpy arrays."""
|
| 242 |
+
probs = self.model.predict(batch, verbose=0)
|
| 243 |
+
results = []
|
| 244 |
+
for p in probs:
|
| 245 |
+
classifications = [
|
| 246 |
+
[self.class_ids_sorted[i], float(p[i])]
|
| 247 |
+
for i in range(len(self.class_ids_sorted))
|
| 248 |
+
]
|
| 249 |
+
results.append(classifications)
|
| 250 |
+
return results
|