""" Inference script for SAH-DRY-ADS-v1 (Sub-Saharan Drylands Species Classifier) This model classifies 328 categories across eastern and southern African ecosystems, with taxonomic fallback for uncertain species-level predictions. Trained on 2.8+ million camera trap images from savannas, dry forests, arid shrublands, and semi-desert habitats across 9 countries. All training data is open-source via LILA BC (https://lila.science/). Model: Sub-Saharan Drylands v1 Input: Variable size (extracted from checkpoint, typically 480x480) Framework: PyTorch (EfficientNet V2 Medium architecture) Classes: 328 species and higher-level taxa with taxonomic fallback Developer: Addax Data Science Citation: https://joss.theoj.org/papers/10.21105/joss.05581 License: CC BY-NC-SA 4.0 Info: https://addaxdatascience.com/ Training regions: South Africa, Tanzania, Kenya, Mozambique, Botswana, Namibia, Rwanda, Madagascar, Uganda Author: Peter van Lunteren Created: 2026-01-14 """ from __future__ import annotations import pathlib import platform from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image, ImageFile, ImageOps from torchvision import transforms from torchvision.models import efficientnet # Don't freak out over truncated images ImageFile.LOAD_TRUNCATED_IMAGES = True # Make sure Windows-trained models work on Unix plt = platform.system() if plt != 'Windows': pathlib.WindowsPath = pathlib.PosixPath class EfficientNetV2M(nn.Module): """EfficientNet V2 Medium architecture for wildlife classification.""" def __init__(self, num_classes: int, tune: bool = True): super(EfficientNetV2M, self).__init__() self.avgpool = nn.AdaptiveAvgPool2d(1) self.model = efficientnet.efficientnet_v2_m( weights=efficientnet.EfficientNet_V2_M_Weights.DEFAULT ) if tune: for params in self.model.parameters(): params.requires_grad = True num_ftrs = self.model.classifier[1].in_features self.model.classifier[1] = nn.Linear(in_features=num_ftrs, out_features=num_classes) def forward(self, x): x = self.model.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) prediction = self.model.classifier(x) return prediction class ModelInference: """PyTorch inference implementation for Sub-Saharan Drylands species classifier.""" def __init__(self, model_dir: Path, model_path: Path): """ Initialize with model paths. Args: model_dir: Directory containing model files model_path: Path to sub_saharan_drylands_v1.pt checkpoint file """ self.model_dir = model_dir self.model_path = model_path self.model = None self.device = None self.image_size = None self.classes = [] self.preprocess = None def check_gpu(self) -> bool: """ Check GPU availability for PyTorch inference. Checks both Apple Metal Performance Shaders (MPS) and CUDA availability. Returns: True if GPU available, False otherwise """ # Check Apple MPS (Apple Silicon) try: if torch.backends.mps.is_built() and torch.backends.mps.is_available(): return True except Exception: pass # Check CUDA (NVIDIA) return torch.cuda.is_available() def load_model(self, device_str: str = 'cpu') -> None: """ Load PyTorch model from checkpoint. The checkpoint contains: - model: State dict with trained weights - categories: Dict mapping class names to indices - image_size: Tuple with input dimensions Args: device_str: Device to load model on ('cpu', 'cuda', or 'mps') Raises: RuntimeError: If model loading fails FileNotFoundError: If model_path is invalid """ if not self.model_path.exists(): raise FileNotFoundError(f"Model file not found: {self.model_path}") try: # Set device self.device = torch.device(device_str) # Load checkpoint checkpoint = torch.load(str(self.model_path), map_location=self.device) # Extract metadata self.image_size = tuple(checkpoint['image_size']) categories = checkpoint['categories'] self.classes = list(categories.keys()) # Initialize EfficientNet V2 Medium architecture num_classes = len(self.classes) self.model = EfficientNetV2M(num_classes, tune=False) # Load weights self.model.load_state_dict(checkpoint['model']) self.model.to(self.device) self.model.eval() # Setup preprocessing self.preprocess = transforms.Compose([ transforms.Resize(self.image_size), transforms.ToTensor(), ]) except Exception as e: raise RuntimeError(f"Failed to load PyTorch model from {self.model_path}: {e}") from e def get_crop( self, image: Image.Image, bbox: tuple[float, float, float, float] ) -> Image.Image: """ Crop image using model-specific preprocessing. This cropping method was developed by Dan Morris for MegaDetector and is designed to: 1. Square the bounding box (max of width/height) 2. Add padding to prevent over-enlargement of small animals 3. Center the detection within the crop 4. Pad with black (0) to maintain square aspect ratio Args: image: PIL Image (full resolution) bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0] Returns: Cropped and padded PIL Image ready for classification Raises: ValueError: If bbox is invalid (zero size) """ img_w, img_h = image.size # Denormalize bbox coordinates xmin = int(bbox[0] * img_w) ymin = int(bbox[1] * img_h) box_w = int(bbox[2] * img_w) box_h = int(bbox[3] * img_h) # Square the box (use max dimension) box_size = max(box_w, box_h) # Add padding (prevents over-enlargement of small animals) box_size = self._pad_crop(box_size) # Center the detection within the squared crop xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w)) ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h)) # Clip to image boundaries box_w = min(img_w, box_size) box_h = min(img_h, box_size) if box_w == 0 or box_h == 0: raise ValueError(f"Invalid bbox size: {box_w}x{box_h}") # Crop and pad to square crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h]) crop = ImageOps.pad(crop, size=(box_size, box_size), color=0) return crop def _pad_crop(self, box_size: int) -> int: """ Calculate padded crop size to prevent over-enlargement of small animals. Standard network input is 224x224. This function ensures small detections aren't excessively upscaled while adding consistent padding to larger detections. Args: box_size: Original bounding box size (max of width/height) Returns: Padded box size """ input_size_network = 224 default_padding = 30 if box_size >= input_size_network: # Large detection: add default padding return box_size + default_padding else: # Small detection: ensure minimum size without excessive enlargement diff_size = input_size_network - box_size if diff_size < default_padding: return box_size + default_padding else: return input_size_network def get_classification(self, crop: Image.Image) -> list[list[str, float]]: """ Run PyTorch classification on cropped image. Args: crop: Cropped and preprocessed PIL Image Returns: List of [class_name, confidence] lists for ALL classes, in model order. Example: [["lion", 0.85], ["leopard", 0.10], ["cheetah", 0.02], ...] NOTE: Sorting by confidence is handled by classification_worker.py Raises: RuntimeError: If model not loaded or inference fails """ if self.model is None: raise RuntimeError("Model not loaded - call load_model() first") try: # Preprocess image (resize and convert to tensor) input_tensor = self.preprocess(crop) input_batch = input_tensor.unsqueeze(0) # Add batch dimension input_batch = input_batch.to(self.device) # Run inference with torch.no_grad(): output = self.model(input_batch) # Apply softmax to get probabilities probabilities = F.softmax(output, dim=1) probabilities_np = probabilities.cpu().detach().numpy() confidence_scores = probabilities_np[0] # Build list of [class_name, confidence] pairs classifications = [] for i in range(len(confidence_scores)): pred_class = self.classes[i] pred_conf = float(confidence_scores[i]) classifications.append([pred_class, pred_conf]) return classifications except Exception as e: raise RuntimeError(f"PyTorch classification failed: {e}") from e def get_class_names(self) -> dict[str, str]: """ Get mapping of class IDs to species names. Returns: Dict mapping class ID (1-indexed string) to species/taxon name Example: {"1": "aardvark", "2": "african wild cat", ..., "328": "zebra"} Raises: RuntimeError: If model not loaded """ if self.model is None: raise RuntimeError("Model not loaded - call load_model() first") try: # Create 1-indexed mapping of class IDs to names class_names = {} for i, class_name in enumerate(self.classes): class_id_str = str(i + 1) # 1-indexed class_names[class_id_str] = class_name return class_names except Exception as e: raise RuntimeError(f"Failed to extract class names: {e}") from e def get_tensor(self, crop: Image.Image): """Preprocess a crop into a numpy array for batch inference.""" tensor = self.preprocess(crop) return tensor.numpy() def classify_batch(self, batch): """Run inference on a batch of preprocessed numpy arrays.""" tensor = torch.from_numpy(batch).to(self.device) with torch.no_grad(): output = self.model(tensor) probs = F.softmax(output, dim=1).cpu().numpy() results = [] for p in probs: classifications = [ [self.classes[i], float(p[i])] for i in range(len(self.classes)) ] results.append(classifications) return results