# NAM-ADS-v1 / inference.py
# Addax-Data-Science's picture
# Upload inference.py
# 2243498 verified
"""
Inference script for NAM-ADS-v1 (Namibian Desert Species Classifier)
This model identifies 30 species or higher-level taxa present in the desert biome
of the Skeleton Coast National Park, North Namibia. Trained on 850,000+ camera trap images.
Model: Namibian Desert v1
Input: 640x640 RGB images
Framework: PyTorch (YOLOv8 classification)
Classes: 30 desert-adapted species and taxonomic groups
Developer: Addax Data Science
Citation: https://joss.theoj.org/papers/10.21105/joss.05581
License: CC BY-NC-SA 4.0
Info: https://addaxdatascience.com/projects/2023-01-dlc/
Author: Peter van Lunteren
Created: 2026-01-14
"""
from __future__ import annotations
import pathlib
import platform
from pathlib import Path
import torch
from PIL import Image, ImageFile, ImageOps
from ultralytics import YOLO
# Models pickled on Windows store WindowsPath objects; alias that class to
# PosixPath on Unix-like hosts so torch/ultralytics can unpickle them here.
plt = platform.system()
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath

# Allow PIL to open partially written (truncated) camera-trap images
# instead of raising an OSError mid-batch.
ImageFile.LOAD_TRUNCATED_IMAGES = True
class ModelInference:
    """YOLOv8 inference implementation for Namibian Desert species classifier."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files
            model_path: Path to namib_desert_v1.pt file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        # Populated by load_model(); kept None so construction stays cheap
        # and workers can defer the (slow) weight load until needed.
        self.model: YOLO | None = None

    def check_gpu(self) -> bool:
        """
        Check GPU availability for YOLOv8 inference.

        Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.

        Returns:
            True if GPU available, False otherwise
        """
        # Check Apple MPS (Apple Silicon). Guarded because older torch builds
        # may lack the mps backend attribute entirely.
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass
        # Check CUDA (NVIDIA)
        return torch.cuda.is_available()

    def load_model(self) -> None:
        """
        Load YOLOv8 classification model into memory.

        This function is called once during worker initialization. The model is
        stored in self.model and reused for all subsequent classification requests.

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If model_path is invalid
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        try:
            self.model = YOLO(str(self.model_path))
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLOv8 model from {self.model_path}: {e}") from e

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using model-specific preprocessing.

        This cropping method was developed by Dan Morris for MegaDetector and is
        designed to:
        1. Square the bounding box (max of width/height)
        2. Add padding to prevent over-enlargement of small animals
        3. Center the detection within the crop
        4. Pad with black (0) to maintain square aspect ratio

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped and padded PIL Image ready for classification

        Raises:
            ValueError: If bbox is invalid (zero size)
        """
        img_w, img_h = image.size
        # Denormalize bbox coordinates
        xmin = int(bbox[0] * img_w)
        ymin = int(bbox[1] * img_h)
        box_w = int(bbox[2] * img_w)
        box_h = int(bbox[3] * img_h)
        # Square the box (use max dimension)
        box_size = max(box_w, box_h)
        # Add padding (prevents over-enlargement of small animals)
        box_size = self._pad_crop(box_size)
        # Center the detection within the squared crop; clamp with the
        # ORIGINAL box_w/box_h (pre-clipping), matching upstream MegaDetector.
        xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
        ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))
        # Clip to image boundaries
        box_w = min(img_w, box_size)
        box_h = min(img_h, box_size)
        if box_w == 0 or box_h == 0:
            raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")
        # Crop (PIL fills any out-of-bounds region with black) and pad to square
        crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
        crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)
        return crop

    def _pad_crop(self, box_size: int) -> int:
        """
        Calculate padded crop size to prevent over-enlargement of small animals.

        YOLOv8 expects 224x224 input. This function ensures small detections aren't
        excessively upscaled while adding consistent padding to larger detections.

        Args:
            box_size: Original bounding box size (max of width/height)

        Returns:
            Padded box size
        """
        input_size_network = 224
        default_padding = 30
        if box_size >= input_size_network:
            # Large detection: add default padding
            return box_size + default_padding
        else:
            # Small detection: ensure minimum size without excessive enlargement
            diff_size = input_size_network - box_size
            if diff_size < default_padding:
                return box_size + default_padding
            else:
                return input_size_network

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run YOLOv8 classification on cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, in model order.
            Example: [["aardwolf", 0.00001], ["giraffe", 0.99985], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        try:
            # Run YOLOv8 classification (verbose=False suppresses progress bar)
            results = self.model(crop, verbose=False)
            # Extract class names dict (YOLOv8 uses alphabetical order)
            # Example: {0: "aardwolf", 1: "african wild cat", ..., 13: "giraffe", ...}
            names_dict = results[0].names
            # Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
            probs = results[0].probs.data.tolist()
            # Build list of [class_name, confidence] pairs (as lists, not tuples!)
            # Return YOLOv8's class names (which will be mapped to taxonomy IDs later)
            classifications = []
            for idx, class_name in names_dict.items():
                confidence = probs[idx]
                classifications.append([class_name, confidence])
            # NOTE: Sorting by confidence is handled by classification_worker.py
            # Model developers don't need to sort - just return all class predictions
            return classifications
        except Exception as e:
            raise RuntimeError(f"YOLOv8 classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names from YOLOv8 model.

        YOLOv8 stores class names in alphabetical order internally. This function
        extracts those names and creates a 1-indexed mapping for the JSON format.
        NOTE: taxonomy.csv is NOT used here - it's only for UI taxonomy tree display.
        The class IDs here are YOLOv8's alphabetical indices (0-based) + 1.

        Returns:
            Dict mapping class ID (1-indexed string) to common name
            Example: {"1": "aardwolf", "2": "african wild cat", ..., "14": "giraffe", ...}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        try:
            # YOLOv8 names dict (alphabetical order): {0: "aardwolf", ...}
            yolo_names = self.model.names
            # Convert to 1-indexed dict for JSON compatibility
            return {str(idx + 1): name for idx, name in yolo_names.items()}
        except Exception as e:
            raise RuntimeError(f"Failed to extract class names from model: {e}") from e