"""
Inference script for KIR-HEX-v1 (Hex-Data OSI-Panthera Classification Model)
This model uses a TorchScript JIT compiled model to classify wildlife detections.
Developed by the Hex-Data team (https://www.hex-data.io/).
Model: OSI-Panthera classification model
Input: 316x316 RGB images
Framework: PyTorch (TorchScript)
Classes: Loaded from pickle file
Author: Peter van Lunteren
Created: 2026-01-14
"""
from __future__ import annotations
from pathlib import Path
import pickle
import platform
import pathlib
import torch
from torchvision import transforms
from PIL import Image, ImageFile
# Allow PIL to load images whose files are truncated on disk instead of
# raising an error mid-read (camera-trap uploads are often incomplete).
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Make sure Windows-trained models work on Unix systems — presumably the
# serialized model references pathlib.WindowsPath objects, which cannot be
# instantiated on POSIX, so they are aliased to PosixPath (TODO confirm).
# NOTE(review): this monkeypatches pathlib process-wide for the lifetime of
# the interpreter, affecting every other module in the process.
plt = platform.system()
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
class ModelInference:
    """
    Inference class for the Hex-Data OSI-Panthera classification model.

    The model is a TorchScript JIT-compiled network with a simple
    resize/normalize preprocessing pipeline. MPS (Apple Silicon GPU) is not
    supported for this model architecture, so inference always runs on CPU
    or CUDA.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize the inference class.

        Args:
            model_dir: Path to the model directory (must contain the class
                labels pickle file).
            model_path: Path to the TorchScript model file (.pt).
        """
        self.model_dir = model_dir
        self.model_path = model_path
        # The following are populated by load_model(); None until then.
        self.model = None
        self.device = None
        self.class_labels = None
        self.transform = None
        # Model-specific constant: input images are resized to 316x316.
        self.img_resize = 316

    def check_gpu(self) -> bool:
        """
        Check if a GPU is available for inference.

        Note: This model architecture is not compatible with MPS (Apple
        Silicon), so only CUDA availability is checked.

        Returns:
            True if a CUDA GPU is available, False otherwise.
        """
        return torch.cuda.is_available()

    def load_model(self, device_str: str = 'cpu') -> None:
        """
        Load the TorchScript model and class labels.

        Args:
            device_str: Device to load the model on ('cpu' or 'cuda').

        Raises:
            FileNotFoundError: If the model file or pickle file is not found.
            RuntimeError: If model loading fails.
        """
        self.device = torch.device(device_str)

        # Load the TorchScript model directly onto the target device.
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        self.model = torch.jit.load(str(self.model_path), map_location=self.device)
        self.model.eval()

        # Load class labels from the pickle file shipped with the model.
        # NOTE(review): pickle.load is only acceptable because this file
        # ships inside the trusted model bundle — never point it at
        # untrusted data.
        class_pickle_path = self.model_dir / 'classes_Fri_Sep__1_18_50_55_2023.pickle'
        if not class_pickle_path.exists():
            raise FileNotFoundError(f"Class labels file not found: {class_pickle_path}")
        with open(class_pickle_path, "rb") as f:
            self.class_labels = pickle.load(f)

        # Preprocessing pipeline: resize to the model's fixed input size,
        # convert to tensor, and normalize with the standard ImageNet
        # mean/std constants.
        self.transform = transforms.Compose([
            transforms.Resize([self.img_resize, self.img_resize]),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)
            )
        ])

    def get_crop(self, image: Image.Image, bbox_norm: list[float]) -> Image.Image:
        """
        Crop a detection from an image using a normalized bounding box.

        This implementation uses a simple direct crop without any padding
        or squaring.

        Args:
            image: Full PIL Image.
            bbox_norm: Normalized bounding box [x_min, y_min, width, height]
                where all values are in range [0, 1].

        Returns:
            Cropped PIL Image.
        """
        img_w, img_h = image.size
        # Convert normalized coordinates to absolute pixel coordinates.
        # Width/height are truncated independently of x_min/y_min, matching
        # the original conversion exactly.
        xmin = int(bbox_norm[0] * img_w)
        ymin = int(bbox_norm[1] * img_h)
        xmax = xmin + int(bbox_norm[2] * img_w)
        ymax = ymin + int(bbox_norm[3] * img_h)
        return image.crop(box=[xmin, ymin, xmax, ymax])

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run classification inference on a cropped detection.

        Args:
            crop: Cropped PIL Image containing the detection.

        Returns:
            List of [class_name, confidence] pairs for ALL classes (unsorted).
            Example: [['lion', 0.92], ['leopard', 0.05], ['cheetah', 0.02], ...]

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        # Guard consistent with get_class_names(): fail with a clear error
        # instead of an opaque AttributeError/TypeError when load_model()
        # was never called.
        if self.model is None or self.transform is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Preprocess: transform, add batch dimension, move to device.
        img_tensor = self.transform(crop).unsqueeze(0).to(self.device)

        # Inference without gradient tracking.
        with torch.no_grad():
            output = self.model(img_tensor)

        # Softmax over the class dimension to turn logits into probabilities.
        probs = torch.nn.functional.softmax(output, dim=1)[0]

        # Pair every probability with its class label, in model index order.
        return [[self.class_labels[idx], prob.item()]
                for idx, prob in enumerate(probs)]

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to class names.

        Returns:
            Dictionary mapping 1-indexed class ID strings to class names.
            Example: {'1': 'lion', '2': 'leopard', '3': 'cheetah', ...}

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if self.class_labels is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")
        # 1-indexed string IDs, preserving the model's class ordering.
        return {str(idx + 1): label
                for idx, label in enumerate(self.class_labels)}