|
|
""" |
|
|
Inference script for NAM-ADS-v1 (Namibian Desert Species Classifier) |
|
|
|
|
|
This model identifies 30 species or higher-level taxa present in the desert biome
|
|
of the Skeleton Coast National Park, northern Namibia. Trained on 850,000+ camera trap images.
|
|
|
|
|
Model: Namibian Desert v1 |
|
|
Input: 640x640 RGB images |
|
|
Framework: PyTorch (YOLOv8 classification) |
|
|
Classes: 30 desert-adapted species and taxonomic groups |
|
|
Developer: Addax Data Science |
|
|
Citation: https://joss.theoj.org/papers/10.21105/joss.05581 |
|
|
License: CC BY-NC-SA 4.0 |
|
|
Info: https://addaxdatascience.com/projects/2023-01-dlc/ |
|
|
|
|
|
Author: Peter van Lunteren |
|
|
Created: 2026-01-14 |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import pathlib |
|
|
import platform |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
from PIL import Image, ImageFile, ImageOps |
|
|
from ultralytics import YOLO |
|
|
|
|
|
|
|
|
# Allow PIL to open images whose files are truncated (common with interrupted
# camera-trap transfers) instead of raising OSError mid-batch.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# NOTE(review): aliasing WindowsPath to PosixPath is a standard workaround for
# unpickling a checkpoint that was saved on Windows (it embeds WindowsPath
# objects, which cannot be instantiated on POSIX systems) — presumably that is
# the case for this model file; confirm against the checkpoint's origin.
plt = platform.system()  # OS name, e.g. 'Windows', 'Linux', 'Darwin'
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
|
|
|
|
|
|
|
|
class ModelInference:
    """YOLOv8 inference implementation for the Namibian Desert species classifier.

    Wraps model loading, MegaDetector-style cropping, and classification into a
    single object. Call load_model() once before any inference method.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files
            model_path: Path to namib_desert_v1.pt file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        # Populated by load_model(); inference methods raise until then.
        self.model: YOLO | None = None

    def check_gpu(self) -> bool:
        """
        Check GPU availability for YOLOv8 inference.

        Checks Apple Metal Performance Shaders (MPS) first, then CUDA.

        Returns:
            True if a GPU backend is available, False otherwise
        """
        try:
            # MPS introspection can raise on torch builds that predate the
            # backend, so any failure is treated as "MPS unavailable" and we
            # fall through to the CUDA check.
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass

        return torch.cuda.is_available()

    def load_model(self) -> None:
        """
        Load the YOLOv8 classification model into memory.

        This function is called once during worker initialization. The model
        is stored in self.model and reused for all subsequent classification
        requests.

        Raises:
            FileNotFoundError: If model_path does not exist
            RuntimeError: If model loading fails
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            self.model = YOLO(str(self.model_path))
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLOv8 model from {self.model_path}: {e}") from e

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using model-specific preprocessing.

        This cropping method was developed by Dan Morris for MegaDetector and is
        designed to:
        1. Square the bounding box (max of width/height)
        2. Add padding to prevent over-enlargement of small animals
        3. Center the detection within the crop
        4. Pad with black (0) to maintain square aspect ratio

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped and padded PIL Image ready for classification

        Raises:
            ValueError: If bbox is invalid (zero size)
        """
        img_w, img_h = image.size

        # Convert normalized coordinates to absolute pixel values.
        xmin = int(bbox[0] * img_w)
        ymin = int(bbox[1] * img_h)
        box_w = int(bbox[2] * img_w)
        box_h = int(bbox[3] * img_h)

        # Square the box on its longer side, then enlarge it per the
        # small-detection padding policy.
        box_size = max(box_w, box_h)
        box_size = self._pad_crop(box_size)

        # Re-center the (possibly enlarged) box on the detection, clamping so
        # the crop origin stays inside the image.
        xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
        ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))

        # The crop itself can never exceed the image dimensions.
        box_w = min(img_w, box_size)
        box_h = min(img_h, box_size)

        if box_w == 0 or box_h == 0:
            raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")

        crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
        # Letterbox with black so the result is exactly box_size x box_size.
        crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)

        return crop

    def _pad_crop(self, box_size: int) -> int:
        """
        Calculate padded crop size to prevent over-enlargement of small animals.

        YOLOv8 expects 224x224 input. This function ensures small detections
        aren't excessively upscaled while adding consistent padding to larger
        detections.

        Args:
            box_size: Original bounding box size (max of width/height)

        Returns:
            Padded box size
        """
        input_size_network = 224
        default_padding = 30

        if box_size >= input_size_network:
            # Detection is already network-sized or larger: just add margin.
            return box_size + default_padding
        else:
            diff_size = input_size_network - box_size
            if diff_size < default_padding:
                # Close to network size: keep the uniform margin instead of
                # snapping to exactly 224.
                return box_size + default_padding
            else:
                # Small detection: grow the crop to 224 so the animal is not
                # over-enlarged when the network resizes the input.
                return input_size_network

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run YOLOv8 classification on cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image

        Returns:
            List of [class_name, confidence] pairs for ALL classes, in model
            order. Example: [["aardwolf", 0.00001], ["giraffe", 0.99985], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            results = self.model(crop, verbose=False)

            # Index-to-name mapping and per-class probabilities from the first
            # (and only) result.
            names_dict = results[0].names
            probs = results[0].probs.data.tolist()

            # Pair every class name with its probability, preserving the
            # model's native class order.
            return [[class_name, probs[idx]] for idx, class_name in names_dict.items()]

        except Exception as e:
            raise RuntimeError(f"YOLOv8 classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names from YOLOv8 model.

        YOLOv8 stores class names in alphabetical order internally. This
        function extracts those names and creates a 1-indexed mapping for the
        JSON format.

        NOTE: taxonomy.csv is NOT used here - it's only for UI taxonomy tree
        display. The class IDs here are YOLOv8's alphabetical indices (0-based) + 1.

        Returns:
            Dict mapping class ID (1-indexed string) to common name
            Example: {"1": "aardwolf", "2": "african wild cat", ..., "14": "giraffe", ...}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Shift YOLOv8's 0-based indices to the 1-based string IDs the
            # output JSON format expects.
            return {str(idx + 1): name for idx, name in self.model.names.items()}

        except Exception as e:
            raise RuntimeError(f"Failed to extract class names from model: {e}") from e
|
|
|