""" Inference script for NZI-ADS-v1 (New Zealand Invasive Species Classifier) This model identifies 17 species or higher-level taxons present in New Zealand, designed to expedite monitoring of invasive species (deer, possum, pig, cat, rodent, mustelid). Trained on ~2 million camera trap images from diverse habitats across the country using multiple camera brands. Overall validation accuracy: 98%, test set performance: 95%. Model: New Zealand Invasives v1 Input: 640x640 RGB images Framework: PyTorch (YOLOv8 classification) Classes: 17 species and taxonomic groups (focus on invasives) Developer: Addax Data Science Owner: New Zealand Department of Conservation License: CC BY-NC-SA 4.0 Info: https://www.doc.govt.nz/ Author: Peter van Lunteren Created: 2026-01-14 """ from __future__ import annotations import pathlib import platform from pathlib import Path import torch from PIL import Image, ImageFile, ImageOps from ultralytics import YOLO # Don't freak out over truncated images ImageFile.LOAD_TRUNCATED_IMAGES = True # Make sure Windows-trained models work on Unix plt = platform.system() if plt != 'Windows': pathlib.WindowsPath = pathlib.PosixPath class ModelInference: """YOLOv8 inference implementation for New Zealand invasive species classifier.""" def __init__(self, model_dir: Path, model_path: Path): """ Initialize with model paths. Args: model_dir: Directory containing model files model_path: Path to new_zealand_v1.pt file """ self.model_dir = model_dir self.model_path = model_path self.model: YOLO | None = None def check_gpu(self) -> bool: """ Check GPU availability for YOLOv8 inference. Checks both Apple Metal Performance Shaders (MPS) and CUDA availability. Returns: True if GPU available, False otherwise """ # Check Apple MPS (Apple Silicon) try: if torch.backends.mps.is_built() and torch.backends.mps.is_available(): return True except Exception: pass # Check CUDA (NVIDIA) return torch.cuda.is_available() def load_model(self) -> None: """ Load YOLOv8 classification model into memory. This function is called once during worker initialization. The model is stored in self.model and reused for all subsequent classification requests. Raises: RuntimeError: If model loading fails FileNotFoundError: If model_path is invalid """ if not self.model_path.exists(): raise FileNotFoundError(f"Model file not found: {self.model_path}") try: self.model = YOLO(str(self.model_path)) except Exception as e: raise RuntimeError(f"Failed to load YOLOv8 model from {self.model_path}: {e}") from e def get_crop( self, image: Image.Image, bbox: tuple[float, float, float, float] ) -> Image.Image: """ Crop image using model-specific preprocessing. This cropping method was developed by Dan Morris for MegaDetector and is designed to: 1. Square the bounding box (max of width/height) 2. Add padding to prevent over-enlargement of small animals 3. Center the detection within the crop 4. Pad with black (0) to maintain square aspect ratio Args: image: PIL Image (full resolution) bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0] Returns: Cropped and padded PIL Image ready for classification Raises: ValueError: If bbox is invalid (zero size) """ img_w, img_h = image.size # Denormalize bbox coordinates xmin = int(bbox[0] * img_w) ymin = int(bbox[1] * img_h) box_w = int(bbox[2] * img_w) box_h = int(bbox[3] * img_h) # Square the box (use max dimension) box_size = max(box_w, box_h) # Add padding (prevents over-enlargement of small animals) box_size = self._pad_crop(box_size) # Center the detection within the squared crop xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w)) ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h)) # Clip to image boundaries box_w = min(img_w, box_size) box_h = min(img_h, box_size) if box_w == 0 or box_h == 0: raise ValueError(f"Invalid bbox size: {box_w}x{box_h}") # Crop and pad to square crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h]) crop = ImageOps.pad(crop, size=(box_size, box_size), color=0) return crop def _pad_crop(self, box_size: int) -> int: """ Calculate padded crop size to prevent over-enlargement of small animals. YOLOv8 expects 224x224 input. This function ensures small detections aren't excessively upscaled while adding consistent padding to larger detections. Args: box_size: Original bounding box size (max of width/height) Returns: Padded box size """ input_size_network = 224 default_padding = 30 if box_size >= input_size_network: # Large detection: add default padding return box_size + default_padding else: # Small detection: ensure minimum size without excessive enlargement diff_size = input_size_network - box_size if diff_size < default_padding: return box_size + default_padding else: return input_size_network def get_classification(self, crop: Image.Image) -> list[list[str, float]]: """ Run YOLOv8 classification on cropped image. Args: crop: Cropped and preprocessed PIL Image Returns: List of [class_name, confidence] lists for ALL classes, in model order. Example: [["cat", 0.00001], ["possum", 0.99985], ...] NOTE: Sorting by confidence is handled by classification_worker.py Raises: RuntimeError: If model not loaded or inference fails """ if self.model is None: raise RuntimeError("Model not loaded - call load_model() first") try: # Run YOLOv8 classification (verbose=False suppresses progress bar) results = self.model(crop, verbose=False) # Extract class names dict (YOLOv8 uses alphabetical order) # Example: {0: "bird", 1: "cat", ..., 10: "possum", ...} names_dict = results[0].names # Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...] probs = results[0].probs.data.tolist() # Build list of [class_name, confidence] pairs (as lists, not tuples!) # Return YOLOv8's class names (which will be mapped to taxonomy IDs later) classifications = [] for idx, class_name in names_dict.items(): confidence = probs[idx] classifications.append([class_name, confidence]) # NOTE: Sorting by confidence is handled by classification_worker.py # Model developers don't need to sort - just return all class predictions return classifications except Exception as e: raise RuntimeError(f"YOLOv8 classification failed: {e}") from e def get_class_names(self) -> dict[str, str]: """ Get mapping of class IDs to species names from YOLOv8 model. YOLOv8 stores class names in alphabetical order internally. This function extracts those names and creates a 1-indexed mapping for the JSON format. NOTE: taxonomy.csv is NOT used here - it's only for UI taxonomy tree display. The class IDs here are YOLOv8's alphabetical indices (0-based) + 1. Returns: Dict mapping class ID (1-indexed string) to common name Example: {"1": "bird", "2": "cat", ..., "11": "possum", ...} Raises: RuntimeError: If model not loaded """ if self.model is None: raise RuntimeError("Model not loaded - call load_model() first") try: # YOLOv8 names dict (alphabetical order): {0: "bird", 1: "cat", ...} yolo_names = self.model.names # Convert to 1-indexed dict for JSON compatibility class_names = {} for idx, name in yolo_names.items(): class_id_str = str(idx + 1) # 1-indexed class_names[class_id_str] = name return class_names except Exception as e: raise RuntimeError(f"Failed to extract class names from model: {e}") from e