File size: 8,826 Bytes
c9d39e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 |
"""
Inference script for TKM-ADS-v1 (Turkmenistan Species Classifier)
This model identifies 14 species or higher-level taxons present in Southern Turkmenistan.
Trained on ~1 million camera trap images achieving 95% validation accuracy, 93% precision,
and 94% recall. Note: Accuracy not tested on out-of-sample local dataset as local images
were not available.
Model: Turkmenistan v1
Input: 640x640 RGB images
Framework: PyTorch (YOLOv8 classification)
Classes: 14 species and taxonomic groups
Developer: Addax Data Science
Citation: https://joss.theoj.org/papers/10.21105/joss.05581
License: CC BY-NC-SA 4.0
Info: https://addaxdatascience.com/
Author: Peter van Lunteren
Created: 2026-01-14
"""
from __future__ import annotations
import pathlib
import platform
from pathlib import Path
import torch
from PIL import Image, ImageFile, ImageOps
from ultralytics import YOLO
# Allow PIL to load images whose file data is truncated instead of raising
# (camera trap images are sometimes partially written or corrupted).
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Checkpoints saved on Windows can embed pathlib.WindowsPath objects; alias
# WindowsPath to PosixPath so such models unpickle on Unix-like systems.
# NOTE(review): `plt` here holds platform.system() — it is not matplotlib;
# consider renaming if matplotlib is ever imported in this module.
plt = platform.system()
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
class ModelInference:
    """YOLOv8 classification inference for the Turkmenistan species classifier.

    Wraps the TKM-ADS-v1 model: loading the weights, MegaDetector-style
    cropping of detections, and per-crop classification over the model's
    14 species / higher-level taxon classes.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths. No I/O happens here; call load_model().

        Args:
            model_dir: Directory containing model files.
            model_path: Path to tkm_v1.pt file.
        """
        self.model_dir = model_dir
        self.model_path = model_path
        # Set by load_model(); stays None until then.
        self.model: YOLO | None = None

    def check_gpu(self) -> bool:
        """
        Check GPU availability for YOLOv8 inference.

        Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.

        Returns:
            True if GPU available, False otherwise.
        """
        # Check Apple MPS (Apple Silicon). Guarded because some torch builds
        # do not expose the mps backend attributes at all.
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass
        # Check CUDA (NVIDIA)
        return torch.cuda.is_available()

    def load_model(self) -> None:
        """
        Load YOLOv8 classification model into memory.

        This function is called once during worker initialization. The model
        is stored in self.model and reused for all subsequent classification
        requests.

        Raises:
            FileNotFoundError: If model_path does not exist.
            RuntimeError: If model loading fails.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        try:
            self.model = YOLO(str(self.model_path))
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLOv8 model from {self.model_path}: {e}") from e

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using model-specific preprocessing.

        This cropping method was developed by Dan Morris for MegaDetector and
        is designed to:
        1. Square the bounding box (max of width/height)
        2. Add padding to prevent over-enlargement of small animals
        3. Center the detection within the crop
        4. Pad with black (0) to maintain square aspect ratio

        Args:
            image: PIL Image (full resolution).
            bbox: Normalized bounding box (x, y, width, height) in [0.0, 1.0].

        Returns:
            Cropped and padded PIL Image ready for classification.

        Raises:
            ValueError: If bbox is invalid (zero size).
        """
        img_w, img_h = image.size
        # Denormalize bbox coordinates to pixels
        xmin = int(bbox[0] * img_w)
        ymin = int(bbox[1] * img_h)
        box_w = int(bbox[2] * img_w)
        box_h = int(bbox[3] * img_h)
        # Square the box (use max dimension)
        box_size = max(box_w, box_h)
        # Add padding (prevents over-enlargement of small animals)
        box_size = self._pad_crop(box_size)
        # Center the detection within the squared crop.
        # NOTE(review): the clamps deliberately use the pre-padding box_w/box_h
        # — this matches the reference MegaDetector crop implementation; any
        # crop region outside the image is filled (presumably with black) by
        # PIL, then ImageOps.pad squares the result.
        xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
        ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))
        # Clip crop extent to image boundaries
        box_w = min(img_w, box_size)
        box_h = min(img_h, box_size)
        if box_w == 0 or box_h == 0:
            raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")
        # Crop and pad to square
        crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
        crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)
        return crop

    def _pad_crop(self, box_size: int) -> int:
        """
        Calculate padded crop size to prevent over-enlargement of small animals.

        YOLOv8 expects 224x224 input. This function ensures small detections
        aren't excessively upscaled while adding consistent padding to larger
        detections.

        Args:
            box_size: Original bounding box size (max of width/height).

        Returns:
            Padded box size.
        """
        input_size_network = 224
        default_padding = 30
        if box_size >= input_size_network:
            # Large detection: add default padding
            return box_size + default_padding
        # Small detection: ensure minimum size without excessive enlargement
        diff_size = input_size_network - box_size
        if diff_size < default_padding:
            return box_size + default_padding
        return input_size_network

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run YOLOv8 classification on cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image.

        Returns:
            List of [class_name, confidence] lists for ALL classes, in model
            order. Example:
            [["goitered gazelle", 0.92], ["urial", 0.05], ["wolf", 0.02], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py.

        Raises:
            RuntimeError: If model not loaded or inference fails.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        try:
            # Run YOLOv8 classification (verbose=False suppresses progress bar)
            results = self.model(crop, verbose=False)
            # Class names dict (YOLOv8 uses alphabetical order), e.g.
            # {0: "bird", 1: "goitered gazelle", ..., 13: "wolf"}
            names_dict = results[0].names
            # Per-class probabilities aligned with names_dict indices
            probs = results[0].probs.data.tolist()
            # Build [class_name, confidence] pairs (as lists, not tuples!).
            # YOLOv8's class names are mapped to taxonomy IDs downstream;
            # sorting by confidence is handled by classification_worker.py.
            return [
                [class_name, probs[idx]]
                for idx, class_name in names_dict.items()
            ]
        except Exception as e:
            raise RuntimeError(f"YOLOv8 classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names from YOLOv8 model.

        YOLOv8 stores class names in alphabetical order internally. This
        function extracts those names and creates a 1-indexed mapping for the
        JSON format.

        NOTE: taxonomy.csv is NOT used here - it's only for UI taxonomy tree
        display. The class IDs here are YOLOv8's alphabetical indices
        (0-based) + 1.

        Returns:
            Dict mapping class ID (1-indexed string) to common name.
            Example: {"1": "bird", "2": "goitered gazelle", ..., "14": "wolf"}

        Raises:
            RuntimeError: If model not loaded.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        try:
            # YOLOv8 names dict (alphabetical): {0: "bird", 1: "goitered gazelle", ...}
            # converted to a 1-indexed dict for JSON compatibility.
            return {str(idx + 1): name for idx, name in self.model.names.items()}
        except Exception as e:
            raise RuntimeError(f"Failed to extract class names from model: {e}") from e
|