Addax-Data-Science committed on
Commit
c9d39e2
·
verified ·
1 Parent(s): 4723644

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +252 -0
inference.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for TKM-ADS-v1 (Turkmenistan Species Classifier)
3
+
4
+ This model identifies 14 species or higher-level taxons present in Southern Turkmenistan.
5
+ Trained on ~1 million camera trap images achieving 95% validation accuracy, 93% precision,
6
+ and 94% recall. Note: Accuracy not tested on out-of-sample local dataset as local images
7
+ were not available.
8
+
9
+ Model: Turkmenistan v1
10
+ Input: 640x640 RGB images
11
+ Framework: PyTorch (YOLOv8 classification)
12
+ Classes: 14 species and taxonomic groups
13
+ Developer: Addax Data Science
14
+ Citation: https://joss.theoj.org/papers/10.21105/joss.05581
15
+ License: CC BY-NC-SA 4.0
16
+ Info: https://addaxdatascience.com/
17
+
18
+ Author: Peter van Lunteren
19
+ Created: 2026-01-14
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import pathlib
25
+ import platform
26
+ from pathlib import Path
27
+
28
+ import torch
29
+ from PIL import Image, ImageFile, ImageOps
30
+ from ultralytics import YOLO
31
+
32
# Allow PIL to load partially-written/truncated image files instead of raising
# (camera trap archives commonly contain a few corrupt images).
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Models trained on Windows may pickle WindowsPath objects; alias WindowsPath
# to PosixPath so unpickling works on Unix-like systems.
# NOTE: platform.system() is called inline — the old throwaway `plt` variable
# shadowed the conventional matplotlib alias and leaked a module-level name.
if platform.system() != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
39
+
40
+
41
class ModelInference:
    """YOLOv8 inference implementation for Turkmenistan species classifier."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files
            model_path: Path to tkm_v1.pt file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        # Populated by load_model(); stays None until then.
        self.model: YOLO | None = None

    def check_gpu(self) -> bool:
        """
        Check GPU availability for YOLOv8 inference.

        Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.

        Returns:
            True if GPU available, False otherwise
        """
        # Check Apple MPS (Apple Silicon). Guarded because older torch builds
        # may not expose torch.backends.mps at all.
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass

        # Check CUDA (NVIDIA)
        return torch.cuda.is_available()

    def load_model(self) -> None:
        """
        Load YOLOv8 classification model into memory.

        This function is called once during worker initialization.
        The model is stored in self.model and reused for all subsequent
        classification requests.

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If model_path is invalid
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            self.model = YOLO(str(self.model_path))
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLOv8 model from {self.model_path}: {e}") from e

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using model-specific preprocessing.

        This cropping method was developed by Dan Morris for MegaDetector and is
        designed to:
        1. Square the bounding box (max of width/height)
        2. Add padding to prevent over-enlargement of small animals
        3. Center the detection within the crop
        4. Pad with black (0) to maintain square aspect ratio

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped and padded PIL Image ready for classification

        Raises:
            ValueError: If bbox is invalid (zero size)
        """
        img_w, img_h = image.size

        # Denormalize bbox coordinates to pixels
        xmin = int(bbox[0] * img_w)
        ymin = int(bbox[1] * img_h)
        box_w = int(bbox[2] * img_w)
        box_h = int(bbox[3] * img_h)

        # Square the box (use max dimension)
        box_size = max(box_w, box_h)

        # Add padding (prevents over-enlargement of small animals)
        box_size = self._pad_crop(box_size)

        # Center the detection within the squared crop, clamped so the
        # origin stays inside the image
        xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
        ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))

        # Clip crop extent to image dimensions
        box_w = min(img_w, box_size)
        box_h = min(img_h, box_size)

        if box_w == 0 or box_h == 0:
            raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")

        # Crop (PIL fills out-of-bounds regions with black), then pad to square
        crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
        crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)

        return crop

    def _pad_crop(self, box_size: int) -> int:
        """
        Calculate padded crop size to prevent over-enlargement of small animals.

        YOLOv8 expects 224x224 input. This function ensures small detections aren't
        excessively upscaled while adding consistent padding to larger detections.

        Args:
            box_size: Original bounding box size (max of width/height)

        Returns:
            Padded box size
        """
        input_size_network = 224
        default_padding = 30

        if box_size >= input_size_network:
            # Large detection: add default padding
            return box_size + default_padding

        # Small detection: ensure minimum size without excessive enlargement
        diff_size = input_size_network - box_size
        if diff_size < default_padding:
            return box_size + default_padding
        return input_size_network

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run YOLOv8 classification on cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, in model order.
            Example: [["goitered gazelle", 0.92], ["urial", 0.05], ["wolf", 0.02], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Run YOLOv8 classification (verbose=False suppresses progress bar)
            results = self.model(crop, verbose=False)

            # Class names dict (YOLOv8 uses alphabetical order), e.g.
            # {0: "bird", 1: "goitered gazelle", ..., 13: "wolf"}
            names_dict = results[0].names

            # Per-class probabilities, index-aligned with names_dict
            probs = results[0].probs.data.tolist()

            # Build [class_name, confidence] pairs (as lists, not tuples!) for
            # ALL classes. YOLOv8's class names are returned as-is; mapping to
            # taxonomy IDs happens later. Sorting is handled downstream.
            return [[class_name, probs[idx]] for idx, class_name in names_dict.items()]

        except Exception as e:
            raise RuntimeError(f"YOLOv8 classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names from YOLOv8 model.

        YOLOv8 stores class names in alphabetical order internally. This function
        extracts those names and creates a 1-indexed mapping for the JSON format.

        NOTE: taxonomy.csv is NOT used here - it's only for UI taxonomy tree display.
        The class IDs here are YOLOv8's alphabetical indices (0-based) + 1.

        Returns:
            Dict mapping class ID (1-indexed string) to common name
            Example: {"1": "bird", "2": "goitered gazelle", ..., "14": "wolf"}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # YOLOv8 names dict (alphabetical order): {0: "bird", ...};
            # re-key as 1-indexed strings for JSON compatibility.
            return {str(idx + 1): name for idx, name in self.model.names.items()}

        except Exception as e:
            raise RuntimeError(f"Failed to extract class names from model: {e}") from e