""" Inference script for KIR-HEX-v1 (Hex-Data OSI-Panthera Classification Model) This model uses a TorchScript JIT compiled model to classify wildlife detections. Developed by the Hex-Data team (https://www.hex-data.io/). Model: OSI-Panthera classification model Input: 316x316 RGB images Framework: PyTorch (TorchScript) Classes: Loaded from pickle file Author: Peter van Lunteren Created: 2026-01-14 """ from __future__ import annotations from pathlib import Path import pickle import platform import pathlib import torch from torchvision import transforms from PIL import Image, ImageFile # Allow loading truncated images ImageFile.LOAD_TRUNCATED_IMAGES = True # Make sure Windows-trained models work on Unix systems plt = platform.system() if plt != 'Windows': pathlib.WindowsPath = pathlib.PosixPath class ModelInference: """ Inference class for the Hex-Data OSI-Panthera classification model. This model uses a TorchScript JIT compiled model with a simple preprocessing pipeline. Note that MPS (Apple Silicon GPU) is not supported for this model architecture, so it will always run on CPU or CUDA. """ def __init__(self, model_dir: Path, model_path: Path): """ Initialize the inference class. Args: model_dir: Path to the model directory model_path: Path to the model file (.pt) """ self.model_dir = model_dir self.model_path = model_path self.model = None self.device = None self.class_labels = None self.transform = None # Model-specific constants self.img_resize = 316 def check_gpu(self) -> bool: """ Check if GPU is available for inference. Note: This model architecture is not compatible with MPS (Apple Silicon), so we only check for CUDA availability. Returns: True if CUDA GPU is available, False otherwise """ return torch.cuda.is_available() def load_model(self, device_str: str = 'cpu') -> None: """ Load the TorchScript model and class labels. Args: device_str: Device to load the model on ('cpu' or 'cuda') Raises: FileNotFoundError: If model file or pickle file not found RuntimeError: If model loading fails """ # Set device self.device = torch.device(device_str) # Load TorchScript model if not self.model_path.exists(): raise FileNotFoundError(f"Model file not found: {self.model_path}") self.model = torch.jit.load(str(self.model_path), map_location=self.device) self.model.eval() # Load class labels from pickle file class_pickle_path = self.model_dir / 'classes_Fri_Sep__1_18_50_55_2023.pickle' if not class_pickle_path.exists(): raise FileNotFoundError(f"Class labels file not found: {class_pickle_path}") with open(class_pickle_path, "rb") as f: self.class_labels = pickle.load(f) # Define image transforms self.transform = transforms.Compose([ transforms.Resize([self.img_resize, self.img_resize]), transforms.ToTensor(), transforms.Normalize( mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) ) ]) def get_crop(self, image: Image.Image, bbox_norm: list[float]) -> Image.Image: """ Crop detection from image using normalized bounding box. This implementation uses a simple direct crop without any padding or squaring. Args: image: Full PIL Image bbox_norm: Normalized bounding box [x_min, y_min, width, height] where all values are in range [0, 1] Returns: Cropped PIL Image """ img_w, img_h = image.size # Convert normalized coordinates to absolute pixel coordinates xmin = int(bbox_norm[0] * img_w) ymin = int(bbox_norm[1] * img_h) xmax = xmin + int(bbox_norm[2] * img_w) ymax = ymin + int(bbox_norm[3] * img_h) # Crop and return crop = image.crop(box=[xmin, ymin, xmax, ymax]) return crop def get_classification(self, crop: Image.Image) -> list[list[str, float]]: """ Run classification inference on a cropped detection. Args: crop: Cropped PIL Image containing the detection Returns: List of [class_name, confidence] pairs for ALL classes (unsorted). Example: [['lion', 0.92], ['leopard', 0.05], ['cheetah', 0.02], ...] """ # Preprocess image img_tensor = self.transform(crop) img_tensor = img_tensor.unsqueeze(0) # Add batch dimension img_tensor = img_tensor.to(self.device) # Run inference with torch.no_grad(): output = self.model(img_tensor) # Apply softmax to get probabilities softmax_output = torch.nn.functional.softmax(output, dim=1) # Format predictions as list of [class_name, confidence] predictions = [] for idx, prob in enumerate(softmax_output[0]): class_label = self.class_labels[idx] confidence = prob.item() predictions.append([class_label, confidence]) return predictions def get_class_names(self) -> dict[str, str]: """ Get mapping of class IDs to class names. Returns: Dictionary mapping 1-indexed class ID strings to class names. Example: {'1': 'lion', '2': 'leopard', '3': 'cheetah', ...} """ if self.class_labels is None: raise RuntimeError("Model not loaded. Call load_model() first.") class_names = {} for idx, class_label in enumerate(self.class_labels): class_id_str = str(idx + 1) # 1-indexed class_names[class_id_str] = class_label return class_names