"""
Inference script for SWUSA-SDZWA-v3 (Southwest USA Species Classifier)

This model distinguishes between 27 species native to the Southwest United States.
Training data collected by SDZWA and California Mountain Lion Project, with examples
from NACTI and CCT datasets. Trained on 91,662 images (70/20/10 split) achieving
88% accuracy on test set.

Model: Southwest USA v3
Input: 299x299 RGB images
Framework: PyTorch (EfficientNet V2 Medium architecture)
Classes: 27 species and categories
Developer: San Diego Zoo Wildlife Alliance (Kyra Swanson)
License: MIT
Info: https://github.com/conservationtechlab

Author: Peter van Lunteren
Created: 2026-01-14
"""

from __future__ import annotations

import pathlib
import platform
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageFile
from torchvision import transforms
from torchvision.models import efficientnet

# Don't freak out over truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Make sure Windows-trained models work on Unix. Presumably the checkpoint may
# reference pathlib.WindowsPath objects, which cannot be instantiated on
# non-Windows systems. NOTE: this rebinding is process-wide, not local to this module.
plt = platform.system()
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath


class EfficientNetV2M(nn.Module):
    """EfficientNet V2 Medium architecture for SDZWA wildlife classification."""

    def __init__(
        self,
        num_classes: int,
        pretrained_weights_path: Path,
        device_str: str = 'cpu',
        tune: bool = True
    ):
        """
        Initialize EfficientNet V2 Medium model.

        Args:
            num_classes: Number of output classes (size of the replacement head)
            pretrained_weights_path: Path to ImageNet pretrained weights (.pth file)
            device_str: Device to load model on ('cpu', 'cuda', 'mps')
            tune: If True, explicitly mark every backbone parameter as trainable
                (requires_grad=True). If False, requires_grad is left untouched;
                this does NOT freeze the backbone.
        """
        super().__init__()
        device = torch.device(device_str)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Build the architecture without downloading anything, then load the
        # ImageNet weights from the local .pth file.
        self.model = efficientnet.efficientnet_v2_m(weights=None)
        self.model.load_state_dict(
            torch.load(str(pretrained_weights_path), map_location=device)
        )

        if tune:
            for param in self.model.parameters():
                param.requires_grad = True

        # Replace the ImageNet classifier head with one sized for this model's classes
        num_ftrs = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features=num_ftrs, out_features=num_classes)

        self.model.to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass (prediction).

        Args:
            x: Input image batch; assumed (batch, 3, H, W) float tensor — TODO confirm
               for callers other than ModelInference.get_classification

        Returns:
            Raw (un-normalized) logits of shape (batch, num_classes)
        """
        x = self.model.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.model.classifier(x)


class ModelInference:
    """PyTorch inference implementation for Southwest USA species classifier."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths. No files are read until load_model().

        Args:
            model_dir: Directory containing model files (classes.csv, pretrained weights)
            model_path: Path to southwest_v3.pt checkpoint file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        # The following are populated together by a successful load_model() call.
        self.model = None       # EfficientNetV2M in eval mode
        self.device = None      # torch.device the model lives on
        self.classes = None     # pandas DataFrame read from classes.csv
        self.preprocess = None  # torchvision transform: resize 299x299 + ToTensor

    def check_gpu(self) -> bool:
        """
        Check GPU availability for PyTorch inference.

        Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.

        Returns:
            True if GPU available, False otherwise
        """
        # Check Apple MPS (Apple Silicon). Best-effort: builds of PyTorch without
        # usable MPS support may raise here, in which case we fall through to CUDA.
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass

        # Check CUDA (NVIDIA)
        return torch.cuda.is_available()

    def load_model(self, device_str: str = 'cpu') -> None:
        """
        Load PyTorch EfficientNet model and class labels.

        This SDZWA model uses EfficientNet V2 Medium architecture with ImageNet
        pretrained weights, fine-tuned on Southwest USA wildlife data.

        The load is all-or-nothing: instance attributes (model, device, classes,
        preprocess) are only updated after every step has succeeded, so a failed
        load never leaves a half-initialized model (e.g. one with an untrained
        classifier head) behind.

        Args:
            device_str: Device to load model on ('cpu', 'cuda', or 'mps')

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If required files are missing
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        # Check for required files
        classes_csv = self.model_dir / 'classes.csv'
        efficientnet_weights = self.model_dir / 'efficientnet_v2_m-dc08266a.pth'

        if not classes_csv.exists():
            raise FileNotFoundError(f"Classes file not found: {classes_csv}")
        if not efficientnet_weights.exists():
            raise FileNotFoundError(f"EfficientNet weights not found: {efficientnet_weights}")

        try:
            device = torch.device(device_str)

            # Load class labels from CSV
            # CSV format: id,Code,Species,Common
            # Class names are taken from the 'Code' column (index 1)
            classes = pd.read_csv(str(classes_csv))

            # Initialize model with ImageNet pretrained weights
            model = EfficientNetV2M(
                num_classes=len(classes),
                pretrained_weights_path=efficientnet_weights,
                device_str=device_str,
                tune=False
            )

            # Load fine-tuned checkpoint; the state dict lives under the 'model' key
            checkpoint = torch.load(str(self.model_path), map_location=device)
            model.load_state_dict(checkpoint['model'])
            model.to(device)
            model.eval()

            # Setup preprocessing (SDZWA animl-py framework uses 299x299)
            # Based on: https://github.com/conservationtechlab/animl-py
            preprocess = transforms.Compose([
                transforms.Resize((299, 299)),
                transforms.ToTensor(),
            ])
        except Exception as e:
            raise RuntimeError(f"Failed to load PyTorch model from {self.model_path}: {e}") from e

        # Commit state only once everything above has succeeded.
        self.device = device
        self.classes = classes
        self.model = model
        self.preprocess = preprocess

    def _class_codes(self) -> list:
        """Return the 'Code' column (positional index 1) of classes.csv as a list, in row order."""
        return self.classes.iloc[:, 1].tolist()

    def get_crop(
        self,
        image: Image.Image,
        bbox: tuple[float, float, float, float],
        *,
        buffer: int = 0,
    ) -> Image.Image:
        """
        Crop image using SDZWA animl-py preprocessing.

        This cropping method follows the San Diego Zoo Wildlife Alliance's
        animl-py framework approach with minimal buffering (0 pixels by default).
        Based on: https://github.com/conservationtechlab/animl-py/blob/main/src/animl/generator.py

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]
            buffer: Pixels of padding added on each side before clipping to the
                image bounds (keyword-only; SDZWA uses 0)

        Returns:
            Cropped PIL Image (not resized - resizing happens in get_classification)

        Raises:
            ValueError: If bbox is invalid
        """
        width, height = image.size
        x, y, w, h = bbox

        # Denormalize to pixel coordinates, truncate to int, pad by `buffer`,
        # then clip to the image boundaries.
        left = max(0, int(width * x) - buffer)
        top = max(0, int(height * y) - buffer)
        right = min(width, int(width * (x + w)) + buffer)
        bottom = min(height, int(height * (y + h)) + buffer)

        # Validate crop dimensions
        if left >= right or top >= bottom:
            raise ValueError(
                f"Invalid bbox dimensions after cropping: "
                f"left={left}, top={top}, right={right}, bottom={bottom}"
            )

        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run PyTorch/EfficientNet classification on cropped image.

        Preprocessing follows SDZWA animl-py framework:
        - Convert to RGB if needed (the network expects 3 channels)
        - Resize to 299x299 (as per animl-py specifications)
        - Convert to tensor
        - No normalization

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, in the row
            order of classes.csv. Illustrative example:
            [["cougar", 0.85], ["bobcat", 0.10], ["coyote", 0.02], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Grayscale / RGBA / palette crops would otherwise yield a tensor
            # with the wrong channel count and fail inside the network.
            if crop.mode != 'RGB':
                crop = crop.convert('RGB')

            # Resize + to tensor, add batch dimension, move to model device
            input_batch = self.preprocess(crop).unsqueeze(0).to(self.device)

            with torch.no_grad():
                logits = self.model(input_batch)
                probabilities = F.softmax(logits, dim=1)

            confidence_scores = probabilities[0].cpu().numpy()

            # Pair each class code with its probability (same order as the head)
            return [
                [class_name, float(confidence)]
                for class_name, confidence in zip(self._class_codes(), confidence_scores)
            ]
        except Exception as e:
            raise RuntimeError(f"PyTorch classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names from CSV.

        Returns:
            Dict mapping class ID (1-indexed string) to species code
            Example: {"1": "badger", "2": "beaver", ..., "27": "weasel"}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None or self.classes is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # 1-indexed string IDs -> 'Code' column values
            return {
                str(class_id): class_name
                for class_id, class_name in enumerate(self._class_codes(), start=1)
            }
        except Exception as e:
            raise RuntimeError(f"Failed to extract class names: {e}") from e