Addax-Data-Science committed on
Commit
aba2956
·
verified ·
1 Parent(s): d7437ff

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +211 -219
inference.py CHANGED
@@ -12,6 +12,7 @@ Training data: 850,000+ images
12
 
13
  Original source: streamlit-AddaxAI/classification/model_types/addax-yolov8/classify_detections.py
14
  Adapted by: Claude Code on 2026-01-11
 
15
  """
16
 
17
  from __future__ import annotations
@@ -24,13 +25,6 @@ import torch
24
  from PIL import Image, ImageFile, ImageOps
25
  from ultralytics import YOLO
26
 
27
- # Module-level variables (injected by AddaxAI framework)
28
- MODEL_DIR: Path | None = None # Set by CustomInferenceLoader
29
- MODEL_PATH: Path | None = None # Set by CustomInferenceLoader
30
-
31
- # Module-level model instance (loaded once at startup)
32
- animal_model: YOLO | None = None
33
-
34
  # Don't freak out over truncated images
35
  ImageFile.LOAD_TRUNCATED_IMAGES = True
36
 
@@ -40,217 +34,215 @@ if plt != 'Windows':
40
  pathlib.WindowsPath = pathlib.PosixPath
41
 
42
 
43
- def check_gpu() -> bool:
44
- """
45
- Check GPU availability for YOLOv8 inference.
46
-
47
- Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.
48
-
49
- Returns:
50
- True if GPU available, False otherwise
51
- """
52
- # Check Apple MPS (Apple Silicon)
53
- try:
54
- if torch.backends.mps.is_built() and torch.backends.mps.is_available():
55
- return True
56
- except Exception:
57
- pass
58
-
59
- # Check CUDA (NVIDIA)
60
- return torch.cuda.is_available()
61
-
62
-
63
- def load_model() -> None:
64
- """
65
- Load YOLOv8 classification model into memory.
66
-
67
- This function is called once during worker initialization.
68
- The model is stored in the global `animal_model` variable and reused
69
- for all subsequent classification requests.
70
-
71
- Raises:
72
- RuntimeError: If model loading fails
73
- FileNotFoundError: If MODEL_PATH is invalid
74
- """
75
- global animal_model, MODEL_PATH
76
-
77
- # MODEL_PATH is injected by framework before this function is called
78
- # Check that the path exists (framework guarantees it's not None)
79
- if not MODEL_PATH.exists():
80
- raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
81
-
82
- try:
83
- animal_model = YOLO(str(MODEL_PATH))
84
- except Exception as e:
85
- raise RuntimeError(f"Failed to load YOLOv8 model from {MODEL_PATH}: {e}") from e
86
-
87
-
88
- def get_crop(image: Image.Image, bbox: tuple[float, float, float, float]) -> Image.Image:
89
- """
90
- Crop image using model-specific preprocessing.
91
-
92
- This cropping method was developed by Dan Morris for MegaDetector and is
93
- designed to:
94
- 1. Square the bounding box (max of width/height)
95
- 2. Add padding to prevent over-enlargement of small animals
96
- 3. Center the detection within the crop
97
- 4. Pad with black (0) to maintain square aspect ratio
98
-
99
- Args:
100
- image: PIL Image (full resolution)
101
- bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]
102
-
103
- Returns:
104
- Cropped and padded PIL Image ready for classification
105
-
106
- Raises:
107
- ValueError: If bbox is invalid (zero size)
108
- """
109
- img_w, img_h = image.size
110
-
111
- # Denormalize bbox coordinates
112
- xmin = int(bbox[0] * img_w)
113
- ymin = int(bbox[1] * img_h)
114
- box_w = int(bbox[2] * img_w)
115
- box_h = int(bbox[3] * img_h)
116
-
117
- # Square the box (use max dimension)
118
- box_size = max(box_w, box_h)
119
-
120
- # Add padding (prevents over-enlargement of small animals)
121
- box_size = _pad_crop(box_size)
122
-
123
- # Center the detection within the squared crop
124
- xmin = max(0, min(
125
- xmin - int((box_size - box_w) / 2),
126
- img_w - box_w
127
- ))
128
- ymin = max(0, min(
129
- ymin - int((box_size - box_h) / 2),
130
- img_h - box_h
131
- ))
132
-
133
- # Clip to image boundaries
134
- box_w = min(img_w, box_size)
135
- box_h = min(img_h, box_size)
136
-
137
- if box_w == 0 or box_h == 0:
138
- raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")
139
-
140
- # Crop and pad to square
141
- crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
142
- crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)
143
-
144
- return crop
145
-
146
-
147
- def _pad_crop(box_size: int) -> int:
148
- """
149
- Calculate padded crop size to prevent over-enlargement of small animals.
150
-
151
- YOLOv8 expects 224x224 input. This function ensures small detections aren't
152
- excessively upscaled while adding consistent padding to larger detections.
153
-
154
- Args:
155
- box_size: Original bounding box size (max of width/height)
156
-
157
- Returns:
158
- Padded box size
159
- """
160
- input_size_network = 224
161
- default_padding = 30
162
-
163
- if box_size >= input_size_network:
164
- # Large detection: add default padding
165
- return box_size + default_padding
166
- else:
167
- # Small detection: ensure minimum size without excessive enlargement
168
- diff_size = input_size_network - box_size
169
- if diff_size < default_padding:
170
  return box_size + default_padding
171
  else:
172
- return input_size_network
173
-
174
-
175
- def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
176
- """
177
- Run YOLOv8 classification on cropped image.
178
-
179
- Args:
180
- crop: Cropped and preprocessed PIL Image
181
-
182
- Returns:
183
- List of (class_name, confidence) tuples for ALL classes, sorted by confidence.
184
- Example: [("giraffe", 0.99985), ("cattle", 0.00003), ...]
185
-
186
- Raises:
187
- RuntimeError: If model not loaded or inference fails
188
- """
189
- global animal_model
190
-
191
- if animal_model is None:
192
- raise RuntimeError("Model not loaded - call load_model() first")
193
-
194
- try:
195
- # Run YOLOv8 classification (verbose=False suppresses progress bar)
196
- results = animal_model(crop, verbose=False)
197
-
198
- # Extract class names dict (YOLOv8 uses alphabetical order)
199
- # Example: {0: "aardwolf", 1: "african wild cat", ..., 13: "giraffe", ...}
200
- names_dict = results[0].names
201
-
202
- # Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
203
- probs = results[0].probs.data.tolist()
204
-
205
- # Build list of (class_name, confidence) tuples
206
- # Return YOLOv8's class names (which will be mapped to taxonomy IDs later)
207
- classifications = []
208
- for idx, class_name in names_dict.items():
209
- confidence = probs[idx]
210
- classifications.append((class_name, confidence))
211
-
212
- # Sort by confidence descending (already sorted by YOLOv8, but ensure it)
213
- classifications.sort(key=lambda x: x[1], reverse=True)
214
-
215
- return classifications
216
-
217
- except Exception as e:
218
- raise RuntimeError(f"YOLOv8 classification failed: {e}") from e
219
-
220
-
221
- def get_class_names() -> dict[str, str]:
222
- """
223
- Get mapping of class IDs to species names from YOLOv8 model.
224
-
225
- YOLOv8 stores class names in alphabetical order internally. This function
226
- extracts those names and creates a 1-indexed mapping for the JSON format.
227
-
228
- NOTE: taxonomy.csv is NOT used here - it's only for UI taxonomy tree display.
229
- The class IDs here are YOLOv8's alphabetical indices (0-based) + 1.
230
-
231
- Returns:
232
- Dict mapping class ID (1-indexed string) to common name
233
- Example: {"1": "aardwolf", "2": "african wild cat", ..., "14": "giraffe", ...}
234
-
235
- Raises:
236
- RuntimeError: If model not loaded
237
- """
238
- global animal_model
239
-
240
- if animal_model is None:
241
- raise RuntimeError("Model not loaded - call load_model() first")
242
-
243
- try:
244
- # YOLOv8 names dict (alphabetical order): {0: "aardwolf", 1: "african wild cat", ...}
245
- yolo_names = animal_model.names
246
-
247
- # Convert to 1-indexed dict for JSON compatibility
248
- class_names = {}
249
- for idx, name in yolo_names.items():
250
- class_id_str = str(idx + 1) # 1-indexed
251
- class_names[class_id_str] = name
252
-
253
- return class_names
254
-
255
- except Exception as e:
256
- raise RuntimeError(f"Failed to extract class names from model: {e}") from e
 
12
 
13
  Original source: streamlit-AddaxAI/classification/model_types/addax-yolov8/classify_detections.py
14
  Adapted by: Claude Code on 2026-01-11
15
+ Updated: 2026-01-13 - Migrated to class-based interface
16
  """
17
 
18
  from __future__ import annotations
 
25
  from PIL import Image, ImageFile, ImageOps
26
  from ultralytics import YOLO
27
 
 
 
 
 
 
 
 
28
  # Don't freak out over truncated images
29
  ImageFile.LOAD_TRUNCATED_IMAGES = True
30
 
 
34
  pathlib.WindowsPath = pathlib.PosixPath
35
 
36
 
37
+ class ModelInference:
38
+ """YOLOv8 inference implementation for Namibian Desert species classifier."""
39
+
40
+ def __init__(self, model_dir: Path, model_path: Path):
41
+ """
42
+ Initialize with model paths.
43
+
44
+ Args:
45
+ model_dir: Directory containing model files
46
+ model_path: Path to namib_desert_v1.pt file
47
+ """
48
+ self.model_dir = model_dir
49
+ self.model_path = model_path
50
+ self.model: YOLO | None = None
51
+
52
+ def check_gpu(self) -> bool:
53
+ """
54
+ Check GPU availability for YOLOv8 inference.
55
+
56
+ Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.
57
+
58
+ Returns:
59
+ True if GPU available, False otherwise
60
+ """
61
+ # Check Apple MPS (Apple Silicon)
62
+ try:
63
+ if torch.backends.mps.is_built() and torch.backends.mps.is_available():
64
+ return True
65
+ except Exception:
66
+ pass
67
+
68
+ # Check CUDA (NVIDIA)
69
+ return torch.cuda.is_available()
70
+
71
+ def load_model(self) -> None:
72
+ """
73
+ Load YOLOv8 classification model into memory.
74
+
75
+ This function is called once during worker initialization.
76
+ The model is stored in self.model and reused for all subsequent
77
+ classification requests.
78
+
79
+ Raises:
80
+ RuntimeError: If model loading fails
81
+ FileNotFoundError: If model_path is invalid
82
+ """
83
+ if not self.model_path.exists():
84
+ raise FileNotFoundError(f"Model file not found: {self.model_path}")
85
+
86
+ try:
87
+ self.model = YOLO(str(self.model_path))
88
+ except Exception as e:
89
+ raise RuntimeError(f"Failed to load YOLOv8 model from {self.model_path}: {e}") from e
90
+
91
+ def get_crop(
92
+ self, image: Image.Image, bbox: tuple[float, float, float, float]
93
+ ) -> Image.Image:
94
+ """
95
+ Crop image using model-specific preprocessing.
96
+
97
+ This cropping method was developed by Dan Morris for MegaDetector and is
98
+ designed to:
99
+ 1. Square the bounding box (max of width/height)
100
+ 2. Add padding to prevent over-enlargement of small animals
101
+ 3. Center the detection within the crop
102
+ 4. Pad with black (0) to maintain square aspect ratio
103
+
104
+ Args:
105
+ image: PIL Image (full resolution)
106
+ bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]
107
+
108
+ Returns:
109
+ Cropped and padded PIL Image ready for classification
110
+
111
+ Raises:
112
+ ValueError: If bbox is invalid (zero size)
113
+ """
114
+ img_w, img_h = image.size
115
+
116
+ # Denormalize bbox coordinates
117
+ xmin = int(bbox[0] * img_w)
118
+ ymin = int(bbox[1] * img_h)
119
+ box_w = int(bbox[2] * img_w)
120
+ box_h = int(bbox[3] * img_h)
121
+
122
+ # Square the box (use max dimension)
123
+ box_size = max(box_w, box_h)
124
+
125
+ # Add padding (prevents over-enlargement of small animals)
126
+ box_size = self._pad_crop(box_size)
127
+
128
+ # Center the detection within the squared crop
129
+ xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
130
+ ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))
131
+
132
+ # Clip to image boundaries
133
+ box_w = min(img_w, box_size)
134
+ box_h = min(img_h, box_size)
135
+
136
+ if box_w == 0 or box_h == 0:
137
+ raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")
138
+
139
+ # Crop and pad to square
140
+ crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
141
+ crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)
142
+
143
+ return crop
144
+
145
+ def _pad_crop(self, box_size: int) -> int:
146
+ """
147
+ Calculate padded crop size to prevent over-enlargement of small animals.
148
+
149
+ YOLOv8 expects 224x224 input. This function ensures small detections aren't
150
+ excessively upscaled while adding consistent padding to larger detections.
151
+
152
+ Args:
153
+ box_size: Original bounding box size (max of width/height)
154
+
155
+ Returns:
156
+ Padded box size
157
+ """
158
+ input_size_network = 224
159
+ default_padding = 30
160
+
161
+ if box_size >= input_size_network:
162
+ # Large detection: add default padding
 
163
  return box_size + default_padding
164
  else:
165
+ # Small detection: ensure minimum size without excessive enlargement
166
+ diff_size = input_size_network - box_size
167
+ if diff_size < default_padding:
168
+ return box_size + default_padding
169
+ else:
170
+ return input_size_network
171
+
172
+ def get_classification(self, crop: Image.Image) -> list[tuple[str, float]]:
173
+ """
174
+ Run YOLOv8 classification on cropped image.
175
+
176
+ Args:
177
+ crop: Cropped and preprocessed PIL Image
178
+
179
+ Returns:
180
+ List of (class_name, confidence) tuples for ALL classes, sorted by confidence.
181
+ Example: [("giraffe", 0.99985), ("cattle", 0.00003), ...]
182
+
183
+ Raises:
184
+ RuntimeError: If model not loaded or inference fails
185
+ """
186
+ if self.model is None:
187
+ raise RuntimeError("Model not loaded - call load_model() first")
188
+
189
+ try:
190
+ # Run YOLOv8 classification (verbose=False suppresses progress bar)
191
+ results = self.model(crop, verbose=False)
192
+
193
+ # Extract class names dict (YOLOv8 uses alphabetical order)
194
+ # Example: {0: "aardwolf", 1: "african wild cat", ..., 13: "giraffe", ...}
195
+ names_dict = results[0].names
196
+
197
+ # Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
198
+ probs = results[0].probs.data.tolist()
199
+
200
+ # Build list of (class_name, confidence) tuples
201
+ # Return YOLOv8's class names (which will be mapped to taxonomy IDs later)
202
+ classifications = []
203
+ for idx, class_name in names_dict.items():
204
+ confidence = probs[idx]
205
+ classifications.append((class_name, confidence))
206
+
207
+ # Sort by confidence descending (already sorted by YOLOv8, but ensure it)
208
+ classifications.sort(key=lambda x: x[1], reverse=True)
209
+
210
+ return classifications
211
+
212
+ except Exception as e:
213
+ raise RuntimeError(f"YOLOv8 classification failed: {e}") from e
214
+
215
+ def get_class_names(self) -> dict[str, str]:
216
+ """
217
+ Get mapping of class IDs to species names from YOLOv8 model.
218
+
219
+ YOLOv8 stores class names in alphabetical order internally. This function
220
+ extracts those names and creates a 1-indexed mapping for the JSON format.
221
+
222
+ NOTE: taxonomy.csv is NOT used here - it's only for UI taxonomy tree display.
223
+ The class IDs here are YOLOv8's alphabetical indices (0-based) + 1.
224
+
225
+ Returns:
226
+ Dict mapping class ID (1-indexed string) to common name
227
+ Example: {"1": "aardwolf", "2": "african wild cat", ..., "14": "giraffe", ...}
228
+
229
+ Raises:
230
+ RuntimeError: If model not loaded
231
+ """
232
+ if self.model is None:
233
+ raise RuntimeError("Model not loaded - call load_model() first")
234
+
235
+ try:
236
+ # YOLOv8 names dict (alphabetical order): {0: "aardwolf", 1: "african wild cat", ...}
237
+ yolo_names = self.model.names
238
+
239
+ # Convert to 1-indexed dict for JSON compatibility
240
+ class_names = {}
241
+ for idx, name in yolo_names.items():
242
+ class_id_str = str(idx + 1) # 1-indexed
243
+ class_names[class_id_str] = name
244
+
245
+ return class_names
246
+
247
+ except Exception as e:
248
+ raise RuntimeError(f"Failed to extract class names from model: {e}") from e