# KIR-HEX-v1 / inference.py
# Provenance: uploaded by Addax-Data-Science ("Upload inference.py", commit aebd820, verified)
"""
Inference script for KIR-HEX-v1 (Hex-Data OSI-Panthera Classification Model)
This model uses a TorchScript JIT compiled model to classify wildlife detections.
Developed by the Hex-Data team (https://www.hex-data.io/).
Model: OSI-Panthera classification model
Input: 316x316 RGB images
Framework: PyTorch (TorchScript)
Classes: Loaded from pickle file
Author: Peter van Lunteren
Created: 2026-01-14
"""
from __future__ import annotations
from pathlib import Path
import pickle
import platform
import pathlib
import torch
from torchvision import transforms
from PIL import Image, ImageFile
# Allow PIL to decode images whose file data is truncated instead of raising
# an error mid-read (camera-trap images are often partially corrupted).
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Make sure Windows-trained models work on Unix systems: checkpoints saved on
# Windows may embed pathlib.WindowsPath objects, which cannot be instantiated
# on POSIX, so alias WindowsPath to PosixPath on non-Windows platforms.
plt = platform.system()
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
class ModelInference:
    """
    Inference class for the Hex-Data OSI-Panthera classification model.

    This model uses a TorchScript JIT compiled model with a simple preprocessing
    pipeline. Note that MPS (Apple Silicon GPU) is not supported for this model
    architecture, so it will always run on CPU or CUDA.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize the inference class.

        Args:
            model_dir: Path to the model directory
            model_path: Path to the model file (.pt)
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None         # TorchScript module, populated by load_model()
        self.device = None        # torch.device, populated by load_model()
        self.class_labels = None  # sequence of class names, populated by load_model()
        self.transform = None     # preprocessing pipeline, populated by load_model()
        # Model-specific constants
        self.img_resize = 316     # the model expects 316x316 RGB input

    def check_gpu(self) -> bool:
        """
        Check if GPU is available for inference.

        Note: This model architecture is not compatible with MPS (Apple Silicon),
        so we only check for CUDA availability.

        Returns:
            True if CUDA GPU is available, False otherwise
        """
        return torch.cuda.is_available()

    def load_model(self, device_str: str = 'cpu') -> None:
        """
        Load the TorchScript model and class labels.

        Args:
            device_str: Device to load the model on ('cpu' or 'cuda')

        Raises:
            FileNotFoundError: If model file or pickle file not found
            RuntimeError: If model loading fails
        """
        # Set device
        self.device = torch.device(device_str)

        # Load TorchScript model
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        self.model = torch.jit.load(str(self.model_path), map_location=self.device)
        self.model.eval()

        # Load class labels from pickle file.
        # NOTE(review): pickle.load can execute arbitrary code — only load the
        # class file that ships with this trusted model distribution.
        class_pickle_path = self.model_dir / 'classes_Fri_Sep__1_18_50_55_2023.pickle'
        if not class_pickle_path.exists():
            raise FileNotFoundError(f"Class labels file not found: {class_pickle_path}")
        with open(class_pickle_path, "rb") as f:
            self.class_labels = pickle.load(f)

        # Define image transforms (mean/std are the standard ImageNet constants)
        self.transform = transforms.Compose([
            transforms.Resize([self.img_resize, self.img_resize]),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)
            )
        ])

    def get_crop(self, image: Image.Image, bbox_norm: list[float]) -> Image.Image:
        """
        Crop detection from image using normalized bounding box.

        This implementation uses a simple direct crop without any padding or squaring.

        Args:
            image: Full PIL Image
            bbox_norm: Normalized bounding box [x_min, y_min, width, height]
                where all values are in range [0, 1]

        Returns:
            Cropped PIL Image
        """
        img_w, img_h = image.size
        # Convert normalized coordinates to absolute pixel coordinates; width and
        # height are truncated independently so the box matches the original math.
        xmin = int(bbox_norm[0] * img_w)
        ymin = int(bbox_norm[1] * img_h)
        xmax = xmin + int(bbox_norm[2] * img_w)
        ymax = ymin + int(bbox_norm[3] * img_h)
        return image.crop(box=(xmin, ymin, xmax, ymax))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run classification inference on a cropped detection.

        Args:
            crop: Cropped PIL Image containing the detection

        Returns:
            List of [class_name, confidence] pairs for ALL classes (unsorted).
            Example: [['lion', 0.92], ['leopard', 0.05], ['cheetah', 0.02], ...]

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        # Guard matches get_class_names(): fail with a clear message instead of
        # an opaque TypeError when load_model() was never called.
        if self.model is None or self.transform is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Preprocess: resize/normalize, add batch dimension, move to device
        img_tensor = self.transform(crop).unsqueeze(0).to(self.device)

        # Run inference without tracking gradients
        with torch.no_grad():
            output = self.model(img_tensor)

        # Softmax over the class dimension to get probabilities
        probs = torch.nn.functional.softmax(output, dim=1)[0]

        # Pair each class label with its confidence, preserving class order
        return [[label, prob.item()] for label, prob in zip(self.class_labels, probs)]

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to class names.

        Returns:
            Dictionary mapping 1-indexed class ID strings to class names.
            Example: {'1': 'lion', '2': 'leopard', '3': 'cheetah', ...}

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if self.class_labels is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        # IDs are 1-indexed strings
        return {str(idx + 1): label for idx, label in enumerate(self.class_labels)}