Addax-Data-Science commited on
Commit
4512f1d
·
verified ·
1 Parent(s): 774aed6

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +22 -0
inference.py CHANGED
@@ -28,6 +28,7 @@ import pathlib
28
  import platform
29
  from pathlib import Path
30
 
 
31
  import torch
32
  import torch.nn as nn
33
  import torch.nn.functional as F
@@ -304,3 +305,24 @@ class ModelInference:
304
 
305
  except Exception as e:
306
  raise RuntimeError(f"Failed to extract class names: {e}") from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  import platform
29
  from pathlib import Path
30
 
31
+ import numpy as np
32
  import torch
33
  import torch.nn as nn
34
  import torch.nn.functional as F
 
305
 
306
  except Exception as e:
307
  raise RuntimeError(f"Failed to extract class names: {e}") from e
308
+
309
+ def get_tensor(self, crop: Image.Image):
310
+ """Preprocess a crop into a numpy array for batch inference."""
311
+ tensor = self.preprocess(crop)
312
+ return tensor.numpy()
313
+
314
+ def classify_batch(self, batch):
315
+ """Run inference on a batch of preprocessed numpy arrays."""
316
+ tensor = torch.from_numpy(batch).to(self.device)
317
+ with torch.no_grad():
318
+ output = self.model(tensor)
319
+ probs = F.softmax(output, dim=1).cpu().numpy()
320
+
321
+ results = []
322
+ for p in probs:
323
+ classifications = [
324
+ [self.classes[i], float(p[i])]
325
+ for i in range(len(self.classes))
326
+ ]
327
+ results.append(classifications)
328
+ return results