Addax-Data-Science commited on
Commit
93eeaa9
·
verified ·
1 Parent(s): d1caf07

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +29 -0
inference.py CHANGED
@@ -215,3 +215,32 @@ class ModelInference:
215
  return {
216
  str(i + 1): name for i, name in enumerate(self.class_names)
217
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
  return {
216
  str(i + 1): name for i, name in enumerate(self.class_names)
217
  }
218
+
219
+ def get_tensor(self, crop: Image.Image):
220
+ """Preprocess a crop into a numpy array for batch inference."""
221
+ if crop.mode != "RGB":
222
+ crop = crop.convert("RGB")
223
+
224
+ img_tensor = TF.pil_to_tensor(crop)
225
+ img_tensor = TF.convert_image_dtype(img_tensor, torch.float32)
226
+ img_tensor = TF.resize(
227
+ img_tensor, [IMG_SIZE, IMG_SIZE], antialias=False
228
+ )
229
+ img_tensor = TF.convert_image_dtype(img_tensor, torch.uint8)
230
+ return img_tensor.permute(1, 2, 0).numpy().astype("float32") / 255.0
231
+
232
+ def classify_batch(self, batch):
233
+ """Run inference on a batch of preprocessed numpy arrays."""
234
+ tensor = torch.from_numpy(batch).to(self.device)
235
+ with torch.no_grad():
236
+ logits = self.model(tensor)
237
+ probs = F.softmax(logits, dim=1).cpu().numpy()
238
+
239
+ results = []
240
+ for p in probs:
241
+ classifications = [
242
+ [self.class_names[i], float(p[i])]
243
+ for i in range(len(self.class_names))
244
+ ]
245
+ results.append(classifications)
246
+ return results