Addax-Data-Science commited on
Commit
100c0ce
·
verified ·
1 Parent(s): f1e39ec

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +22 -0
inference.py CHANGED
@@ -24,6 +24,7 @@ import pathlib
24
  import platform
25
  from pathlib import Path
26
 
 
27
  import pandas as pd
28
  import torch
29
  import torch.nn as nn
@@ -311,3 +312,24 @@ class ModelInference:
311
 
312
  except Exception as e:
313
  raise RuntimeError(f"Failed to extract class names: {e}") from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  import platform
25
  from pathlib import Path
26
 
27
+ import numpy as np
28
  import pandas as pd
29
  import torch
30
  import torch.nn as nn
 
312
 
313
  except Exception as e:
314
  raise RuntimeError(f"Failed to extract class names: {e}") from e
315
+
316
+ def get_tensor(self, crop: Image.Image):
317
+ """Preprocess a crop into a numpy array for batch inference."""
318
+ tensor = self.preprocess(crop)
319
+ return tensor.numpy()
320
+
321
+ def classify_batch(self, batch):
322
+ """Run inference on a batch of preprocessed numpy arrays."""
323
+ tensor = torch.from_numpy(batch).to(self.device)
324
+ with torch.no_grad():
325
+ output = self.model(tensor)
326
+ probs = F.softmax(output, dim=1).cpu().numpy()
327
+
328
+ results = []
329
+ for p in probs:
330
+ classifications = []
331
+ for i in range(len(p)):
332
+ pred_class = self.classes.iloc[i].values[1] # 'Code' column
333
+ classifications.append([pred_class, float(p[i])])
334
+ results.append(classifications)
335
+ return results