Addax-Data-Science commited on
Commit
8b37e79
·
verified ·
1 Parent(s): 44cc414

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +22 -0
inference.py CHANGED
@@ -20,6 +20,7 @@ import pickle
20
  import platform
21
  import pathlib
22
 
 
23
  import torch
24
  from torchvision import transforms
25
  from PIL import Image, ImageFile
@@ -186,3 +187,24 @@ class ModelInference:
186
  class_names[class_id_str] = class_label
187
 
188
  return class_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  import platform
21
  import pathlib
22
 
23
+ import numpy as np
24
  import torch
25
  from torchvision import transforms
26
  from PIL import Image, ImageFile
 
187
  class_names[class_id_str] = class_label
188
 
189
  return class_names
190
+
191
+ def get_tensor(self, crop: Image.Image):
192
+ """Preprocess a crop into a numpy array for batch inference."""
193
+ tensor = self.transform(crop)
194
+ return tensor.numpy()
195
+
196
+ def classify_batch(self, batch):
197
+ """Run inference on a batch of preprocessed numpy arrays."""
198
+ tensor = torch.from_numpy(batch).to(self.device)
199
+ with torch.no_grad():
200
+ output = self.model(tensor)
201
+ probs = torch.nn.functional.softmax(output, dim=1).cpu().numpy()
202
+
203
+ results = []
204
+ for p in probs:
205
+ classifications = [
206
+ [self.class_labels[i], float(p[i])]
207
+ for i in range(len(self.class_labels))
208
+ ]
209
+ results.append(classifications)
210
+ return results