| """ |
| Inference script for SAH-DRY-ADS-v1 (Sub-Saharan Drylands Species Classifier) |
| |
| This model classifies 328 categories across eastern and southern African ecosystems, |
| with taxonomic fallback for uncertain species-level predictions. Trained on 2.8+ million |
| camera trap images from savannas, dry forests, arid shrublands, and semi-desert habitats |
| across 9 countries. All training data is open-source via LILA BC (https://lila.science/). |
| |
| Model: Sub-Saharan Drylands v1 |
| Input: Variable size (extracted from checkpoint, typically 480x480) |
| Framework: PyTorch (EfficientNet V2 Medium architecture) |
| Classes: 328 species and higher-level taxa with taxonomic fallback |
| Developer: Addax Data Science |
| Citation: https://joss.theoj.org/papers/10.21105/joss.05581 |
| License: CC BY-NC-SA 4.0 |
| Info: https://addaxdatascience.com/ |
| |
| Training regions: South Africa, Tanzania, Kenya, Mozambique, Botswana, Namibia, |
| Rwanda, Madagascar, Uganda |
| |
| Author: Peter van Lunteren |
| Created: 2026-01-14 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import pathlib |
| import platform |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image, ImageFile, ImageOps |
| from torchvision import transforms |
| from torchvision.models import efficientnet |
|
|
| |
# Camera trap images are sometimes partially written/corrupted; allow Pillow
# to decode truncated files instead of raising an exception mid-batch.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
| |
# Checkpoints saved on Windows can contain pickled pathlib.WindowsPath
# objects; on POSIX systems alias WindowsPath to PosixPath so torch.load can
# unpickle them. NOTE(review): this rebinds pathlib globally for the process.
if platform.system() != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
|
|
|
|
class EfficientNetV2M(nn.Module):
    """EfficientNet V2 Medium architecture for wildlife classification.

    Wraps torchvision's pretrained EfficientNet V2 M backbone and replaces the
    final classifier layer to match the number of target classes.
    """

    def __init__(self, num_classes: int, tune: bool = True):
        """
        Args:
            num_classes: Number of output classes for the replacement head.
            tune: If True, explicitly mark every backbone parameter as
                trainable (fine-tuning). Inference code passes False.
        """
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Downloads ImageNet-pretrained weights on first use; they are
        # overwritten by the checkpoint's state dict during load_model().
        self.model = efficientnet.efficientnet_v2_m(
            weights=efficientnet.EfficientNet_V2_M_Weights.DEFAULT
        )
        if tune:
            for param in self.model.parameters():
                param.requires_grad = True
        # Swap the classification head for one sized to our label set.
        num_ftrs = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features=num_ftrs, out_features=num_classes)

    def forward(self, x):
        """Forward pass: backbone features -> global pool -> classifier logits."""
        x = self.model.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.model.classifier(x)
|
|
class ModelInference:
    """PyTorch inference implementation for Sub-Saharan Drylands species classifier."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files
            model_path: Path to sub_saharan_drylands_v1.pt checkpoint file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None        # set by load_model()
        self.device = None       # torch.device, set by load_model()
        self.image_size = None   # input size tuple read from the checkpoint
        self.classes = []        # class names in model output order
        self.preprocess = None   # torchvision transform pipeline

    def check_gpu(self) -> bool:
        """
        Check GPU availability for PyTorch inference.

        Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.

        Returns:
            True if GPU available, False otherwise
        """
        # MPS queries can raise on some torch builds; treat failure as "no MPS".
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass

        return torch.cuda.is_available()

    def load_model(self, device_str: str = 'cpu') -> None:
        """
        Load PyTorch model from checkpoint.

        The checkpoint contains:
        - model: State dict with trained weights
        - categories: Dict mapping class names to indices
        - image_size: Tuple with input dimensions

        Args:
            device_str: Device to load model on ('cpu', 'cuda', or 'mps')

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If model_path is invalid
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            self.device = torch.device(device_str)

            # NOTE(review): torch.load unpickles arbitrary objects — only load
            # trusted checkpoints. map_location puts tensors straight on the
            # target device.
            checkpoint = torch.load(str(self.model_path), map_location=self.device)

            # Input size and label set are stored alongside the weights.
            self.image_size = tuple(checkpoint['image_size'])
            categories = checkpoint['categories']
            self.classes = list(categories.keys())

            # Rebuild the architecture, then restore the trained weights.
            self.model = EfficientNetV2M(len(self.classes), tune=False)
            self.model.load_state_dict(checkpoint['model'])
            self.model.to(self.device)
            self.model.eval()

            # Resize to the checkpoint's native input size; ToTensor scales
            # pixel values into [0, 1].
            self.preprocess = transforms.Compose([
                transforms.Resize(self.image_size),
                transforms.ToTensor(),
            ])

        except Exception as e:
            raise RuntimeError(f"Failed to load PyTorch model from {self.model_path}: {e}") from e

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using model-specific preprocessing.

        This cropping method was developed by Dan Morris for MegaDetector and is
        designed to:
        1. Square the bounding box (max of width/height)
        2. Add padding to prevent over-enlargement of small animals
        3. Center the detection within the crop
        4. Pad with black (0) to maintain square aspect ratio

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped and padded PIL Image ready for classification

        Raises:
            ValueError: If bbox is invalid (zero size)
        """
        img_w, img_h = image.size

        # Convert the normalized bbox to pixel coordinates.
        xmin = int(bbox[0] * img_w)
        ymin = int(bbox[1] * img_h)
        box_w = int(bbox[2] * img_w)
        box_h = int(bbox[3] * img_h)

        # Square the box on its longer side, then pad (see _pad_crop).
        box_size = max(box_w, box_h)
        box_size = self._pad_crop(box_size)

        # Re-center the (possibly enlarged) square crop on the detection while
        # keeping the origin inside the image. NOTE(review): the clamp uses the
        # original box_w/box_h, so a crop near the right/bottom edge can still
        # extend past the image; PIL fills the overhang with black, which is
        # the intended padding behavior.
        xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
        ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))

        # The crop window never exceeds the source image dimensions.
        box_w = min(img_w, box_size)
        box_h = min(img_h, box_size)

        if box_w == 0 or box_h == 0:
            raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")

        # Crop, then letterbox-pad with black to a box_size square.
        crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
        crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)

        return crop

    def _pad_crop(self, box_size: int) -> int:
        """
        Calculate padded crop size to prevent over-enlargement of small animals.

        Standard network input is 224x224. This function ensures small detections
        aren't excessively upscaled while adding consistent padding to larger detections.

        Args:
            box_size: Original bounding box size (max of width/height)

        Returns:
            Padded box size
        """
        input_size_network = 224
        default_padding = 30

        if box_size >= input_size_network:
            # Detection already at/above network size: add fixed padding.
            return box_size + default_padding
        else:
            # Small detection: grow to network size, but never add less than
            # the default padding around the animal.
            diff_size = input_size_network - box_size
            if diff_size < default_padding:
                return box_size + default_padding
            else:
                return input_size_network

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run PyTorch classification on cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, in model order.
            Example: [["lion", 0.85], ["leopard", 0.10], ["cheetah", 0.02], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # PIL image -> normalized tensor -> batch of one on the target device.
            input_batch = self.preprocess(crop).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.model(input_batch)

            # Logits -> per-class probabilities for the single batch element.
            probabilities = F.softmax(output, dim=1)
            confidence_scores = probabilities[0].cpu().numpy()

            # Pair every class name with its probability, preserving model order.
            return [
                [class_name, float(score)]
                for class_name, score in zip(self.classes, confidence_scores)
            ]

        except Exception as e:
            raise RuntimeError(f"PyTorch classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        Returns:
            Dict mapping class ID (1-indexed string) to species/taxon name
            Example: {"1": "aardvark", "2": "african wild cat", ..., "328": "zebra"}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Class IDs are 1-indexed strings, matching the docstring example.
            return {str(i + 1): name for i, name in enumerate(self.classes)}
        except Exception as e:
            raise RuntimeError(f"Failed to extract class names: {e}") from e

    def get_tensor(self, crop: Image.Image) -> np.ndarray:
        """Preprocess a crop into a numpy array for batch inference.

        Args:
            crop: Cropped PIL Image.

        Returns:
            Numpy array view of the preprocessed tensor, suitable for stacking
            into a batch for classify_batch().

        Raises:
            RuntimeError: If model not loaded
        """
        if self.preprocess is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        return self.preprocess(crop).numpy()

    def classify_batch(self, batch: np.ndarray) -> list[list[list[str | float]]]:
        """Run inference on a batch of preprocessed numpy arrays.

        Args:
            batch: Stacked array of preprocessed crops (see get_tensor()).

        Returns:
            One result per input image, each a list of [class_name, confidence]
            lists covering ALL classes in model order — the same shape as a
            single get_classification() result.

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        tensor = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            output = self.model(tensor)
        probs = F.softmax(output, dim=1).cpu().numpy()

        return [
            [[name, float(score)] for name, score in zip(self.classes, row)]
            for row in probs
        ]
|
|