Addax-Data-Science committed on
Commit
343862b
·
verified ·
1 Parent(s): 4211f6b

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +262 -0
inference.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
NAM-ADS-v1 YOLOv8 Classification Model - Custom Inference Script

This script provides model-specific inference code for the Namibian Desert
species classifier (30 classes). It follows the AddaxAI-WebUI interface contract
for custom classification models.

Model: Namibian Desert v1 (namib_desert_v1.pt)
Framework: YOLOv8 (Ultralytics)
Classes: 30 species from Skeleton Coast National Park, Namibia
Training data: 850,000+ images

Original source: streamlit-AddaxAI/classification/model_types/addax-yolov8/classify_detections.py
Adapted by: Claude Code on 2026-01-11
"""

from __future__ import annotations

import pathlib
import platform
from pathlib import Path

import torch
from PIL import Image, ImageFile, ImageOps
from ultralytics import YOLO

# Module-level variables (injected by AddaxAI framework)
MODEL_DIR: Path | None = None  # Set by CustomInferenceLoader
MODEL_PATH: Path | None = None  # Set by CustomInferenceLoader

# Module-level model instance (loaded once at startup by load_model(),
# then reused for every classification request)
animal_model: YOLO | None = None

# Don't freak out over truncated images: with this flag PIL loads what it
# can from a partially-written file instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Make sure Windows-trained models work on Unix.
# NOTE(review): checkpoints pickled on Windows can embed WindowsPath
# objects, which fail to unpickle on POSIX; aliasing WindowsPath to
# PosixPath is the common workaround for torch/ultralytics weights.
plt = platform.system()
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
41
+
42
+
43
def check_gpu() -> bool:
    """
    Report whether a GPU backend is available for YOLOv8 inference.

    Considers Apple Metal Performance Shaders (MPS, Apple Silicon) as well
    as NVIDIA CUDA.

    Returns:
        True if either MPS or CUDA can be used, False otherwise.
    """
    # Probe Apple MPS first; guarded because older torch builds may not
    # expose the mps backend attributes at all.
    mps_usable = False
    try:
        mps_backend = torch.backends.mps
        mps_usable = mps_backend.is_built() and mps_backend.is_available()
    except Exception:
        mps_usable = False

    # Otherwise fall back to NVIDIA CUDA.
    return mps_usable or torch.cuda.is_available()
61
+
62
+
63
def load_model() -> None:
    """
    Load the YOLOv8 classification model into the module-level cache.

    Called once during worker initialization. The loaded model is stored
    in the global ``animal_model`` and reused for all subsequent
    classification requests.

    Raises:
        RuntimeError: If MODEL_PATH was never injected, or loading fails.
        FileNotFoundError: If MODEL_PATH points at a missing file.
    """
    global animal_model

    # The framework must inject MODEL_PATH before we can do anything.
    if MODEL_PATH is None:
        raise RuntimeError("MODEL_PATH not set - must be injected by framework")

    if not MODEL_PATH.exists():
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

    try:
        animal_model = YOLO(str(MODEL_PATH))
    except Exception as e:
        raise RuntimeError(f"Failed to load YOLOv8 model from {MODEL_PATH}: {e}") from e
87
+
88
+
89
def get_crop(image: Image.Image, bbox: tuple[float, float, float, float]) -> Image.Image:
    """
    Crop image using model-specific preprocessing.

    This cropping method was developed by Dan Morris for MegaDetector and is
    designed to:
    1. Square the bounding box (max of width/height)
    2. Add padding to prevent over-enlargement of small animals
    3. Center the detection within the crop
    4. Pad with black (0) to maintain square aspect ratio

    Args:
        image: PIL Image (full resolution)
        bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

    Returns:
        Cropped and padded PIL Image ready for classification

    Raises:
        ValueError: If bbox is invalid (zero size)
    """
    img_w, img_h = image.size

    # Denormalize bbox coordinates (normalized [0, 1] -> pixel units)
    xmin = int(bbox[0] * img_w)
    ymin = int(bbox[1] * img_h)
    box_w = int(bbox[2] * img_w)
    box_h = int(bbox[3] * img_h)

    # Square the box (use max dimension)
    box_size = max(box_w, box_h)

    # Add padding (prevents over-enlargement of small animals)
    box_size = _pad_crop(box_size)

    # Center the detection within the squared crop.
    # NOTE(review): the clamp below uses the pre-clip box_w/box_h, so for
    # detections near the image border the crop window can extend past the
    # image; PIL's crop() fills such out-of-bounds pixels with zeros
    # (black), which matches the intended padding color.
    xmin = max(0, min(
        xmin - int((box_size - box_w) / 2),
        img_w - box_w
    ))
    ymin = max(0, min(
        ymin - int((box_size - box_h) / 2),
        img_h - box_h
    ))

    # Clip the crop window itself to the image dimensions
    box_w = min(img_w, box_size)
    box_h = min(img_h, box_size)

    # Zero-area boxes would produce an empty crop; fail loudly instead
    if box_w == 0 or box_h == 0:
        raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")

    # Crop, then letterbox-pad with black up to a box_size square
    crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
    crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)

    return crop
146
+
147
+
148
+ def _pad_crop(box_size: int) -> int:
149
+ """
150
+ Calculate padded crop size to prevent over-enlargement of small animals.
151
+
152
+ YOLOv8 expects 224x224 input. This function ensures small detections aren't
153
+ excessively upscaled while adding consistent padding to larger detections.
154
+
155
+ Args:
156
+ box_size: Original bounding box size (max of width/height)
157
+
158
+ Returns:
159
+ Padded box size
160
+ """
161
+ input_size_network = 224
162
+ default_padding = 30
163
+
164
+ if box_size >= input_size_network:
165
+ # Large detection: add default padding
166
+ return box_size + default_padding
167
+ else:
168
+ # Small detection: ensure minimum size without excessive enlargement
169
+ diff_size = input_size_network - box_size
170
+ if diff_size < default_padding:
171
+ return box_size + default_padding
172
+ else:
173
+ return input_size_network
174
+
175
+
176
def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
    """
    Run YOLOv8 classification on a cropped detection.

    Args:
        crop: Cropped and preprocessed PIL Image

    Returns:
        (class_id, confidence) pairs for ALL classes, highest confidence
        first. Class ids are 1-indexed strings,
        e.g. [("14", 0.99985), ("7", 0.00003), ...]

    Raises:
        RuntimeError: If model not loaded or inference fails
    """
    global animal_model

    if animal_model is None:
        raise RuntimeError("Model not loaded - call load_model() first")

    try:
        # verbose=False keeps ultralytics from printing per-image progress
        prediction = animal_model(crop, verbose=False)[0]

        # names: {0: "porcupine", 1: "elephant", ...}; probs line up by index
        label_map = prediction.names
        scores = prediction.probs.data.tolist()

        # Emit 1-indexed string ids for compatibility with taxonomy.csv and
        # the expected JSON output; sort by descending confidence (YOLOv8
        # output is already ordered, but we don't rely on that).
        pairs = [(str(class_idx + 1), scores[class_idx]) for class_idx in label_map]
        return sorted(pairs, key=lambda pair: pair[1], reverse=True)

    except Exception as e:
        raise RuntimeError(f"YOLOv8 classification failed: {e}") from e
222
+
223
+
224
def get_class_names() -> dict[str, str]:
    """
    Get mapping of class IDs to species names.

    The names are read from the class-name dict embedded in the loaded
    YOLOv8 model (``animal_model.names``), not from taxonomy.csv; the
    embedded names are assumed to match taxonomy.csv in the model
    directory.

    Returns:
        Dict mapping class ID (1-indexed string) to common name.
        Example: {"1": "porcupine", "2": "elephant", ...}

    Raises:
        RuntimeError: If MODEL_DIR is not set, the model is not loaded,
            or name extraction fails.
    """
    global animal_model, MODEL_DIR

    if MODEL_DIR is None:
        raise RuntimeError("MODEL_DIR not set - must be injected by framework")

    if animal_model is None:
        raise RuntimeError("Model not loaded - call load_model() first")

    try:
        # YOLOv8 names dict is 0-indexed: {0: "porcupine", 1: "elephant", ...}.
        # Shift to 1-indexed string keys for compatibility with taxonomy.csv
        # and the JSON output (same convention as get_classification).
        return {str(idx + 1): name for idx, name in animal_model.names.items()}
    except Exception as e:
        raise RuntimeError(f"Failed to extract class names: {e}") from e