# SAH-DRY-ADS-v1 / inference.py
# Uploaded by Addax-Data-Science (commit 4512f1d, verified)
"""
Inference script for SAH-DRY-ADS-v1 (Sub-Saharan Drylands Species Classifier)
This model classifies 328 categories across eastern and southern African ecosystems,
with taxonomic fallback for uncertain species-level predictions. Trained on 2.8+ million
camera trap images from savannas, dry forests, arid shrublands, and semi-desert habitats
across 9 countries. All training data is open-source via LILA BC (https://lila.science/).
Model: Sub-Saharan Drylands v1
Input: Variable size (extracted from checkpoint, typically 480x480)
Framework: PyTorch (EfficientNet V2 Medium architecture)
Classes: 328 species and higher-level taxa with taxonomic fallback
Developer: Addax Data Science
Citation: https://joss.theoj.org/papers/10.21105/joss.05581
License: CC BY-NC-SA 4.0
Info: https://addaxdatascience.com/
Training regions: South Africa, Tanzania, Kenya, Mozambique, Botswana, Namibia,
Rwanda, Madagascar, Uganda
Author: Peter van Lunteren
Created: 2026-01-14
"""
from __future__ import annotations
import pathlib
import platform
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageFile, ImageOps
from torchvision import transforms
from torchvision.models import efficientnet
# Allow PIL to open images whose files were cut short (common with interrupted
# camera-trap transfers); without this, Image.open() raises OSError on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Checkpoints pickled on Windows can embed pathlib.WindowsPath objects, which
# cannot be instantiated on Unix. Aliasing WindowsPath to PosixPath lets
# torch.load() unpickle such checkpoints. NOTE: this is a process-wide
# monkeypatch of pathlib, not scoped to this module.
plt = platform.system()  # NOTE(review): name shadows the common matplotlib alias
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
class EfficientNetV2M(nn.Module):
    """EfficientNet V2 Medium backbone with a replaced classification head.

    Wraps torchvision's ImageNet-pretrained EfficientNet V2-M and swaps its
    final linear layer for one sized to ``num_classes`` outputs.
    """

    def __init__(self, num_classes: int, tune: bool = True):
        """
        Args:
            num_classes: Number of output classes for the new head.
            tune: When True, explicitly mark every backbone parameter as
                trainable. NOTE(review): ``tune=False`` does not freeze the
                backbone — parameters simply keep their default (trainable)
                state; harmless for inference-only use with ``eval()``.
        """
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Start from ImageNet-pretrained weights.
        self.model = efficientnet.efficientnet_v2_m(
            weights=efficientnet.EfficientNet_V2_M_Weights.DEFAULT
        )
        if tune:
            for parameter in self.model.parameters():
                parameter.requires_grad = True
        # Replace the stock 1000-class head with one for our label set.
        in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features=in_features, out_features=num_classes)

    def forward(self, x):
        """Return raw (pre-softmax) logits for a batch of images."""
        features = self.model.features(x)
        pooled = self.avgpool(features)
        flattened = torch.flatten(pooled, 1)
        return self.model.classifier(flattened)
class ModelInference:
    """PyTorch inference implementation for the Sub-Saharan Drylands species classifier.

    Typical usage:
        infer = ModelInference(model_dir, model_path)
        infer.load_model('cuda' if infer.check_gpu() else 'cpu')
        crop = infer.get_crop(image, bbox)
        predictions = infer.get_classification(crop)
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths. No I/O happens here; call load_model().

        Args:
            model_dir: Directory containing model files
            model_path: Path to sub_saharan_drylands_v1.pt checkpoint file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None        # EfficientNetV2M, set by load_model()
        self.device = None       # torch.device, set by load_model()
        self.image_size = None   # input size tuple read from the checkpoint
        self.classes = []        # class names in model output order
        self.preprocess = None   # torchvision transform pipeline, set by load_model()

    def check_gpu(self) -> bool:
        """
        Check GPU availability for PyTorch inference.

        Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.

        Returns:
            True if a GPU backend is available, False otherwise
        """
        # Apple MPS (Apple Silicon). Wrapped in try/except because older
        # torch builds may lack the mps backend attribute entirely.
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass
        # NVIDIA CUDA
        return torch.cuda.is_available()

    def load_model(self, device_str: str = 'cpu') -> None:
        """
        Load the PyTorch model from the checkpoint file.

        The checkpoint contains:
        - model: State dict with trained weights
        - categories: Dict mapping class names to indices
        - image_size: Tuple with input dimensions

        Args:
            device_str: Device to load model on ('cpu', 'cuda', or 'mps')

        Raises:
            FileNotFoundError: If model_path does not exist
            RuntimeError: If model loading fails for any other reason
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        try:
            self.device = torch.device(device_str)
            # SECURITY NOTE: torch.load unpickles arbitrary objects; only load
            # trusted checkpoints. On torch >= 2.6 the weights_only default
            # changed to True — confirm this checkpoint still loads there.
            checkpoint = torch.load(str(self.model_path), map_location=self.device)
            # Metadata: input size and ordered class names.
            self.image_size = tuple(checkpoint['image_size'])
            categories = checkpoint['categories']
            # dict preserves insertion order, so this matches model output order.
            self.classes = list(categories.keys())
            # Build the architecture and load the trained weights.
            self.model = EfficientNetV2M(len(self.classes), tune=False)
            self.model.load_state_dict(checkpoint['model'])
            self.model.to(self.device)
            self.model.eval()
            # Preprocessing: resize to the checkpoint's input size, then
            # convert to a float tensor in [0, 1]. No normalization —
            # must match how the model was trained.
            self.preprocess = transforms.Compose([
                transforms.Resize(self.image_size),
                transforms.ToTensor(),
            ])
        except Exception as e:
            raise RuntimeError(f"Failed to load PyTorch model from {self.model_path}: {e}") from e

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using model-specific preprocessing.

        This cropping method was developed by Dan Morris for MegaDetector and:
        1. Squares the bounding box (max of width/height)
        2. Adds padding to prevent over-enlargement of small animals
        3. Centers the detection within the crop
        4. Pads with black (0) to maintain a square aspect ratio

        The arithmetic below intentionally mirrors the training-time crop
        logic; do not "simplify" it or classification accuracy will drift.

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped and padded PIL Image ready for classification

        Raises:
            ValueError: If bbox is invalid (zero size)
        """
        img_w, img_h = image.size
        # Denormalize bbox coordinates to pixels
        xmin = int(bbox[0] * img_w)
        ymin = int(bbox[1] * img_h)
        box_w = int(bbox[2] * img_w)
        box_h = int(bbox[3] * img_h)
        # Square the box (use max dimension), then pad it
        box_size = self._pad_crop(max(box_w, box_h))
        # Center the detection within the squared crop, clamped to the image
        xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
        ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))
        # Clip crop extent to image boundaries
        box_w = min(img_w, box_size)
        box_h = min(img_h, box_size)
        if box_w == 0 or box_h == 0:
            raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")
        # Crop, then pad with black to a square of box_size x box_size
        crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
        crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)
        return crop

    def _pad_crop(self, box_size: int) -> int:
        """
        Calculate padded crop size to prevent over-enlargement of small animals.

        Standard network input is 224x224. This ensures small detections
        aren't excessively upscaled while adding consistent padding to
        larger detections.

        Args:
            box_size: Original bounding box size (max of width/height)

        Returns:
            Padded box size
        """
        input_size_network = 224
        default_padding = 30
        if box_size >= input_size_network:
            # Large detection: add default padding
            return box_size + default_padding
        # Small detection: grow to at least the network input size, but never
        # by less than the default padding.
        diff_size = input_size_network - box_size
        if diff_size < default_padding:
            return box_size + default_padding
        return input_size_network

    def _scores_to_pairs(self, scores) -> list:
        """Pair each class name with its confidence, in model output order."""
        return [[name, float(score)] for name, score in zip(self.classes, scores)]

    def get_classification(self, crop: Image.Image) -> list[list[str, float]]:
        """
        Run PyTorch classification on a cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, in model
            order, e.g. [["lion", 0.85], ["leopard", 0.10], ...].
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        try:
            # Resize + to-tensor, add batch dimension, move to device
            input_batch = self.preprocess(crop).unsqueeze(0).to(self.device)
            with torch.no_grad():
                output = self.model(input_batch)
            # Softmax over the class dimension → per-class probabilities
            probabilities = F.softmax(output, dim=1)
            confidence_scores = probabilities.cpu().detach().numpy()[0]
            return self._scores_to_pairs(confidence_scores)
        except Exception as e:
            raise RuntimeError(f"PyTorch classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        Returns:
            Dict mapping class ID (1-indexed string) to species/taxon name,
            e.g. {"1": "aardvark", ..., "328": "zebra"}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        # 1-indexed string keys, matching detector output conventions
        return {str(i + 1): name for i, name in enumerate(self.classes)}

    def get_tensor(self, crop: Image.Image):
        """
        Preprocess a crop into a numpy array for batch inference.

        Args:
            crop: Cropped PIL Image

        Returns:
            numpy float32 array of shape (C, H, W), values in [0, 1]

        Raises:
            RuntimeError: If model not loaded (preprocess pipeline unset)
        """
        if self.preprocess is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        return self.preprocess(crop).numpy()

    def classify_batch(self, batch):
        """
        Run inference on a batch of preprocessed numpy arrays.

        Args:
            batch: numpy array of shape (N, C, H, W), as stacked from
                get_tensor() outputs

        Returns:
            List of length N; one [class_name, confidence] list per crop,
            covering all classes in model order (unsorted).

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        tensor = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            output = self.model(tensor)
        probs = F.softmax(output, dim=1).cpu().numpy()
        return [self._scores_to_pairs(p) for p in probs]