|
|
""" |
|
|
Inference script for KIR-HEX-v1 (Hex-Data OSI-Panthera Classification Model) |
|
|
|
|
|
This model uses a TorchScript JIT compiled model to classify wildlife detections. |
|
|
Developed by the Hex-Data team (https://www.hex-data.io/). |
|
|
|
|
|
Model: OSI-Panthera classification model |
|
|
Input: 316x316 RGB images |
|
|
Framework: PyTorch (TorchScript) |
|
|
Classes: Loaded from pickle file |
|
|
|
|
|
Author: Peter van Lunteren |
|
|
Created: 2026-01-14 |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from pathlib import Path |
|
|
import pickle |
|
|
import platform |
|
|
import pathlib |
|
|
|
|
|
import torch |
|
|
from torchvision import transforms |
|
|
from PIL import Image, ImageFile |
|
|
|
|
|
|
|
|
# Tell PIL to load truncated/partially-corrupt image files instead of raising
# an error mid-read; camera-trap datasets commonly contain such files.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
|
|
|
|
# Checkpoints pickled on Windows can embed pathlib.WindowsPath objects,
# which cannot be instantiated on POSIX systems. Alias WindowsPath to
# PosixPath on non-Windows hosts so torch.jit.load can deserialize them.
# (Previous code bound platform.system() to a module-level name `plt`,
# which shadowed the conventional matplotlib alias — inlined here.)
if platform.system() != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
|
|
|
|
|
|
|
|
class ModelInference:
    """
    Inference class for the Hex-Data OSI-Panthera classification model.

    This model uses a TorchScript JIT compiled model with a simple preprocessing
    pipeline. Note that MPS (Apple Silicon GPU) is not supported for this model
    architecture, so it will always run on CPU or CUDA.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize the inference class.

        Note: this does not load the model; call ``load_model()`` before
        running any inference.

        Args:
            model_dir: Path to the model directory
            model_path: Path to the model file (.pt)
        """
        self.model_dir = model_dir
        self.model_path = model_path

        # Populated by load_model(); all remain None until then.
        self.model = None
        self.device = None
        self.class_labels = None
        self.transform = None

        # The model expects square 316x316 RGB inputs (see module docstring).
        self.img_resize = 316

    def check_gpu(self) -> bool:
        """
        Check if GPU is available for inference.

        Note: This model architecture is not compatible with MPS (Apple Silicon),
        so we only check for CUDA availability.

        Returns:
            True if CUDA GPU is available, False otherwise
        """
        return torch.cuda.is_available()

    def load_model(self, device_str: str = 'cpu') -> None:
        """
        Load the TorchScript model, class labels, and preprocessing pipeline.

        Args:
            device_str: Device to load the model on ('cpu' or 'cuda')

        Raises:
            FileNotFoundError: If model file or pickle file not found
            RuntimeError: If model loading fails
        """
        self.device = torch.device(device_str)

        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        # TorchScript archive: torch.jit.load handles deserialization and
        # device placement in one step; eval() disables dropout/batchnorm
        # training behavior.
        self.model = torch.jit.load(str(self.model_path), map_location=self.device)
        self.model.eval()

        # Class labels ship alongside the model as a pickled list whose index
        # order matches the model's output logits.
        class_pickle_path = self.model_dir / 'classes_Fri_Sep__1_18_50_55_2023.pickle'
        if not class_pickle_path.exists():
            raise FileNotFoundError(f"Class labels file not found: {class_pickle_path}")

        with open(class_pickle_path, "rb") as f:
            self.class_labels = pickle.load(f)

        # Preprocessing: resize to the fixed input size, convert to tensor,
        # then normalize with the standard ImageNet mean/std.
        self.transform = transforms.Compose([
            transforms.Resize([self.img_resize, self.img_resize]),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)
            )
        ])

    def get_crop(self, image: Image.Image, bbox_norm: list[float]) -> Image.Image:
        """
        Crop detection from image using normalized bounding box.

        This implementation uses a simple direct crop without any padding or squaring.

        Args:
            image: Full PIL Image
            bbox_norm: Normalized bounding box [x_min, y_min, width, height]
                where all values are in range [0, 1]

        Returns:
            Cropped PIL Image
        """
        img_w, img_h = image.size

        # Convert normalized coordinates to absolute pixels; width/height are
        # truncated independently of the origin, matching the original export.
        xmin = int(bbox_norm[0] * img_w)
        ymin = int(bbox_norm[1] * img_h)
        xmax = xmin + int(bbox_norm[2] * img_w)
        ymax = ymin + int(bbox_norm[3] * img_h)

        crop = image.crop(box=[xmin, ymin, xmax, ymax])
        return crop

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run classification inference on a cropped detection.

        Args:
            crop: Cropped PIL Image containing the detection

        Returns:
            List of [class_name, confidence] pairs for ALL classes (unsorted).
            Example: [['lion', 0.92], ['leopard', 0.05], ['cheetah', 0.02], ...]

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        # Fail fast with a clear message instead of a cryptic TypeError when
        # inference is attempted before load_model().
        if self.model is None or self.transform is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Preprocess: resize/normalize, add the batch dimension, move to device.
        img_tensor = self.transform(crop)
        img_tensor = img_tensor.unsqueeze(0)
        img_tensor = img_tensor.to(self.device)

        # no_grad: inference only — skip autograd bookkeeping.
        with torch.no_grad():
            output = self.model(img_tensor)

        # Convert raw logits to class probabilities.
        softmax_output = torch.nn.functional.softmax(output, dim=1)

        # Pair every class label with its probability, preserving model order.
        predictions = []
        for idx, prob in enumerate(softmax_output[0]):
            predictions.append([self.class_labels[idx], prob.item()])

        return predictions

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to class names.

        Returns:
            Dictionary mapping 1-indexed class ID strings to class names.
            Example: {'1': 'lion', '2': 'leopard', '3': 'cheetah', ...}

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if self.class_labels is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        # Class IDs are 1-indexed strings by convention of the calling code.
        return {
            str(idx + 1): class_label
            for idx, class_label in enumerate(self.class_labels)
        }
|
|
|