Addax-Data-Science committed
Commit 774aed6 · verified · 1 parent: 6eeb8a2

Upload inference.py

Files changed (1): inference.py (+306, -0)
inference.py ADDED
"""
Inference script for SAH-DRY-ADS-v1 (Sub-Saharan Drylands Species Classifier)

This model classifies 328 categories across eastern and southern African ecosystems,
with taxonomic fallback for uncertain species-level predictions. Trained on 2.8+ million
camera trap images from savannas, dry forests, arid shrublands, and semi-desert habitats
across 9 countries. All training data is open-source via LILA BC (https://lila.science/).

Model: Sub-Saharan Drylands v1
Input: Variable size (extracted from checkpoint, typically 480x480)
Framework: PyTorch (EfficientNet V2 Medium architecture)
Classes: 328 species and higher-level taxa with taxonomic fallback
Developer: Addax Data Science
Citation: https://joss.theoj.org/papers/10.21105/joss.05581
License: CC BY-NC-SA 4.0
Info: https://addaxdatascience.com/

Training regions: South Africa, Tanzania, Kenya, Mozambique, Botswana, Namibia,
Rwanda, Madagascar, Uganda

Author: Peter van Lunteren
Created: 2026-01-14
"""

from __future__ import annotations

import pathlib
import platform
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageFile, ImageOps
from torchvision import transforms
from torchvision.models import efficientnet

# Don't freak out over truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Make sure Windows-trained models work on Unix
plt = platform.system()
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath


class EfficientNetV2M(nn.Module):
    """EfficientNet V2 Medium architecture for wildlife classification."""

    def __init__(self, num_classes: int, tune: bool = True):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.model = efficientnet.efficientnet_v2_m(
            weights=efficientnet.EfficientNet_V2_M_Weights.DEFAULT
        )
        if tune:
            for params in self.model.parameters():
                params.requires_grad = True
        num_ftrs = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features=num_ftrs, out_features=num_classes)

    def forward(self, x):
        x = self.model.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        prediction = self.model.classifier(x)
        return prediction
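
# A note on the head swap above (based on torchvision's efficientnet_v2_m, so
# verify against your installed version): model.classifier is
# Sequential(Dropout(p=0.3), Linear(1280, 1000)), so replacing classifier[1]
# keeps the dropout layer and swaps only the final 1000-way ImageNet head for
# a num_classes-way one.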

class ModelInference:
    """PyTorch inference implementation for Sub-Saharan Drylands species classifier."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files
            model_path: Path to sub_saharan_drylands_v1.pt checkpoint file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None
        self.device = None
        self.image_size = None
        self.classes = []
        self.preprocess = None

    def check_gpu(self) -> bool:
        """
        Check GPU availability for PyTorch inference.

        Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.

        Returns:
            True if GPU available, False otherwise
        """
        # Check Apple MPS (Apple Silicon)
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass

        # Check CUDA (NVIDIA)
        return torch.cuda.is_available()
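
    # Device-selection sketch (illustrative, not part of the original API):
    # since check_gpu() is True for either CUDA or MPS, a caller could do:
    #
    #     if torch.cuda.is_available():
    #         inference.load_model('cuda')
    #     elif inference.check_gpu():
    #         inference.load_model('mps')   # GPU present but not CUDA -> Apple MPS
    #     else:
    #         inference.load_model('cpu')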

    def load_model(self, device_str: str = 'cpu') -> None:
        """
        Load PyTorch model from checkpoint.

        The checkpoint contains:
        - model: State dict with trained weights
        - categories: Dict mapping class names to indices
        - image_size: Tuple with input dimensions

        Args:
            device_str: Device to load model on ('cpu', 'cuda', or 'mps')

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If model_path is invalid
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            # Set device
            self.device = torch.device(device_str)

            # Load checkpoint. The checkpoint holds non-tensor metadata, so we
            # need the full pickle; PyTorch >= 2.6 defaults weights_only to
            # True, which would reject it.
            checkpoint = torch.load(
                str(self.model_path), map_location=self.device, weights_only=False
            )

            # Extract metadata
            self.image_size = tuple(checkpoint['image_size'])
            categories = checkpoint['categories']
            self.classes = list(categories.keys())

            # Initialize EfficientNet V2 Medium architecture
            num_classes = len(self.classes)
            self.model = EfficientNetV2M(num_classes, tune=False)

            # Load weights
            self.model.load_state_dict(checkpoint['model'])
            self.model.to(self.device)
            self.model.eval()

            # Setup preprocessing
            self.preprocess = transforms.Compose([
                transforms.Resize(self.image_size),
                transforms.ToTensor(),
            ])

        except Exception as e:
            raise RuntimeError(f"Failed to load PyTorch model from {self.model_path}: {e}") from e
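
    # For reference, a checkpoint compatible with load_model() could be written
    # like this (a hedged sketch; the training-side save code is not included
    # in this upload):
    #
    #     torch.save({
    #         'model': net.state_dict(),                   # trained weights
    #         'categories': {'aardvark': 0, 'impala': 1},  # class name -> index
    #         'image_size': (480, 480),                    # input resolution
    #     }, 'sub_saharan_drylands_v1.pt')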

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using model-specific preprocessing.

        This cropping method was developed by Dan Morris for MegaDetector and is
        designed to:
        1. Square the bounding box (max of width/height)
        2. Add padding to prevent over-enlargement of small animals
        3. Center the detection within the crop
        4. Pad with black (0) to maintain square aspect ratio

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped and padded PIL Image ready for classification

        Raises:
            ValueError: If bbox is invalid (zero size)
        """
        img_w, img_h = image.size

        # Denormalize bbox coordinates
        xmin = int(bbox[0] * img_w)
        ymin = int(bbox[1] * img_h)
        box_w = int(bbox[2] * img_w)
        box_h = int(bbox[3] * img_h)

        # Square the box (use max dimension)
        box_size = max(box_w, box_h)

        # Add padding (prevents over-enlargement of small animals)
        box_size = self._pad_crop(box_size)

        # Center the detection within the squared crop
        xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
        ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))

        # Clip to image boundaries
        box_w = min(img_w, box_size)
        box_h = min(img_h, box_size)

        if box_w == 0 or box_h == 0:
            raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")

        # Crop and pad to square
        crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
        crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)

        return crop
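
    # Worked example (illustrative numbers): for a 1000x750 image with bbox
    # (0.40, 0.40, 0.10, 0.20), the denormalized box is 100x150 px, so
    # box_size = max(100, 150) = 150, which _pad_crop() grows to 224. The crop
    # is then re-centered, clipped to the image, and padded with black if the
    # clipped region is not square.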

    def _pad_crop(self, box_size: int) -> int:
        """
        Calculate padded crop size to prevent over-enlargement of small animals.

        Standard network input is 224x224. This function ensures small detections
        aren't excessively upscaled while adding consistent padding to larger detections.

        Args:
            box_size: Original bounding box size (max of width/height)

        Returns:
            Padded box size
        """
        input_size_network = 224
        default_padding = 30

        if box_size >= input_size_network:
            # Large detection: add default padding
            return box_size + default_padding
        else:
            # Small detection: ensure minimum size without excessive enlargement
            diff_size = input_size_network - box_size
            if diff_size < default_padding:
                return box_size + default_padding
            else:
                return input_size_network
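
    # _pad_crop() examples (illustrative): 300 -> 330 (large box, +30 padding);
    # 210 -> 240 (gap to 224 is 14 < 30, so pad by 30 instead); 100 -> 224
    # (small box grows only to the network input size, never beyond it).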

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run PyTorch classification on cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image

        Returns:
            List of [class_name, confidence] pairs for ALL classes, in model order.
            Example: [["lion", 0.85], ["leopard", 0.10], ["cheetah", 0.02], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Preprocess image (resize and convert to tensor)
            input_tensor = self.preprocess(crop)
            input_batch = input_tensor.unsqueeze(0)  # Add batch dimension
            input_batch = input_batch.to(self.device)

            # Run inference
            with torch.no_grad():
                output = self.model(input_batch)

            # Apply softmax to get probabilities
            probabilities = F.softmax(output, dim=1)
            probabilities_np = probabilities.cpu().detach().numpy()
            confidence_scores = probabilities_np[0]

            # Build list of [class_name, confidence] pairs
            classifications = []
            for i in range(len(confidence_scores)):
                pred_class = self.classes[i]
                pred_conf = float(confidence_scores[i])
                classifications.append([pred_class, pred_conf])

            return classifications

        except Exception as e:
            raise RuntimeError(f"PyTorch classification failed: {e}") from e
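
    # Sorting is deliberately left to the caller (classification_worker.py).
    # A caller-side top-5, for example, could be:
    #
    #     top5 = sorted(results, key=lambda pair: pair[1], reverse=True)[:5]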

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        Returns:
            Dict mapping class ID (1-indexed string) to species/taxon name
            Example: {"1": "aardvark", "2": "african wild cat", ..., "328": "zebra"}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Create 1-indexed mapping of class IDs to names
            class_names = {}
            for i, class_name in enumerate(self.classes):
                class_id_str = str(i + 1)  # 1-indexed
                class_names[class_id_str] = class_name

            return class_names

        except Exception as e:
            raise RuntimeError(f"Failed to extract class names: {e}") from e
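

# Minimal usage sketch (hypothetical paths and image; this harness is not part
# of the original upload). It classifies a whole frame by passing a bbox that
# covers the full image.
if __name__ == '__main__':
    model_dir = Path('.')                                   # assumed layout
    model_path = model_dir / 'sub_saharan_drylands_v1.pt'   # assumed filename

    inference = ModelInference(model_dir, model_path)
    inference.load_model('cuda' if torch.cuda.is_available() else 'cpu')

    image = Image.open('example.jpg').convert('RGB')        # any camera trap image
    crop = inference.get_crop(image, (0.0, 0.0, 1.0, 1.0))  # whole-frame bbox
    results = inference.get_classification(crop)

    for name, conf in sorted(results, key=lambda pair: pair[1], reverse=True)[:5]:
        print(f'{name}: {conf:.3f}')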