"""
Inference script for TKM-ADS-v1 (Turkmenistan Species Classifier)

This model identifies 14 species or higher-level taxa present in Southern Turkmenistan.
Trained on ~1 million camera trap images, it achieves 95% validation accuracy, 93% precision,
and 94% recall. Note: accuracy has not been verified on an out-of-sample local dataset, as
local images were not available.

Model: Turkmenistan v1
Input: 640x640 RGB images
Framework: PyTorch (YOLOv8 classification)
Classes: 14 species and taxonomic groups
Developer: Addax Data Science
Citation: https://joss.theoj.org/papers/10.21105/joss.05581
License: CC BY-NC-SA 4.0
Info: https://addaxdatascience.com/

Author: Peter van Lunteren
Created: 2026-01-14
"""

from __future__ import annotations

import pathlib
import platform
from pathlib import Path

import torch
from PIL import Image, ImageFile, ImageOps
from ultralytics import YOLO

# Allow PIL to load truncated image files instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Remap WindowsPath to PosixPath so checkpoints pickled on Windows unpickle on Unix
if platform.system() != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath


class ModelInference:
    """YOLOv8 inference implementation for Turkmenistan species classifier."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files
            model_path: Path to tkm_v1.pt file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model: YOLO | None = None

    def check_gpu(self) -> bool:
        """
        Check GPU availability for YOLOv8 inference.

        Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.

        Returns:
            True if GPU available, False otherwise
        """
        # Check Apple MPS (Apple Silicon)
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass

        # Check CUDA (NVIDIA)
        return torch.cuda.is_available()

    def load_model(self) -> None:
        """
        Load YOLOv8 classification model into memory.

        This function is called once during worker initialization.
        The model is stored in self.model and reused for all subsequent
        classification requests.

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If model_path is invalid
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            self.model = YOLO(str(self.model_path))
        except Exception as e:
            raise RuntimeError(f"Failed to load YOLOv8 model from {self.model_path}: {e}") from e

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using model-specific preprocessing.

        This cropping method was developed by Dan Morris for MegaDetector and is
        designed to:
        1. Square the bounding box (max of width/height)
        2. Add padding to prevent over-enlargement of small animals
        3. Center the detection within the crop
        4. Pad with black (0) to maintain square aspect ratio

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped and padded PIL Image ready for classification

        Raises:
            ValueError: If bbox is invalid (zero size)
        """
        img_w, img_h = image.size

        # Denormalize bbox coordinates
        xmin = int(bbox[0] * img_w)
        ymin = int(bbox[1] * img_h)
        box_w = int(bbox[2] * img_w)
        box_h = int(bbox[3] * img_h)

        # Square the box (use max dimension)
        box_size = max(box_w, box_h)

        # Add padding (prevents over-enlargement of small animals)
        box_size = self._pad_crop(box_size)

        # Center the detection within the squared crop
        xmin = max(0, min(xmin - int((box_size - box_w) / 2), img_w - box_w))
        ymin = max(0, min(ymin - int((box_size - box_h) / 2), img_h - box_h))

        # Clip to image boundaries
        box_w = min(img_w, box_size)
        box_h = min(img_h, box_size)

        if box_w == 0 or box_h == 0:
            raise ValueError(f"Invalid bbox size: {box_w}x{box_h}")

        # Crop and pad to square
        crop = image.crop(box=[xmin, ymin, xmin + box_w, ymin + box_h])
        crop = ImageOps.pad(crop, size=(box_size, box_size), color=0)

        return crop

    def _pad_crop(self, box_size: int) -> int:
        """
        Calculate padded crop size to prevent over-enlargement of small animals.

        YOLOv8 expects 224x224 input. This function ensures small detections aren't
        excessively upscaled while adding consistent padding to larger detections.

        Args:
            box_size: Original bounding box size (max of width/height)

        Returns:
            Padded box size
        """
        input_size_network = 224
        default_padding = 30

        if box_size >= input_size_network:
            # Large detection: add default padding
            return box_size + default_padding
        else:
            # Small detection: ensure minimum size without excessive enlargement
            diff_size = input_size_network - box_size
            if diff_size < default_padding:
                return box_size + default_padding
            else:
                return input_size_network

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run YOLOv8 classification on cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, in model order.
            Example: [["goitered gazelle", 0.92], ["urial", 0.05], ["wolf", 0.02], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Run YOLOv8 classification (verbose=False suppresses progress bar)
            results = self.model(crop, verbose=False)

            # Extract class names dict (YOLOv8 uses alphabetical order)
            # Example: {0: "bird", 1: "goitered gazelle", ..., 13: "wolf"}
            names_dict = results[0].names

            # Extract probabilities: [0.0001, 0.0002, ..., 0.9998, ...]
            probs = results[0].probs.data.tolist()

            # Build list of [class_name, confidence] pairs (as lists, not tuples!)
            # Return YOLOv8's class names (which will be mapped to taxonomy IDs later)
            classifications = []
            for idx, class_name in names_dict.items():
                confidence = probs[idx]
                classifications.append([class_name, confidence])

            # NOTE: Sorting by confidence is handled by classification_worker.py
            # Model developers don't need to sort - just return all class predictions
            return classifications

        except Exception as e:
            raise RuntimeError(f"YOLOv8 classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names from YOLOv8 model.

        YOLOv8 stores class names in alphabetical order internally. This function
        extracts those names and creates a 1-indexed mapping for the JSON format.

        NOTE: taxonomy.csv is NOT used here - it's only for UI taxonomy tree display.
        The class IDs here are YOLOv8's alphabetical indices (0-based) + 1.

        Returns:
            Dict mapping class ID (1-indexed string) to common name
            Example: {"1": "bird", "2": "goitered gazelle", ..., "14": "wolf"}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # YOLOv8 names dict (alphabetical order): {0: "bird", 1: "goitered gazelle", ...}
            yolo_names = self.model.names

            # Convert to 1-indexed dict for JSON compatibility
            class_names = {}
            for idx, name in yolo_names.items():
                class_id_str = str(idx + 1)  # 1-indexed
                class_names[class_id_str] = name

            return class_names

        except Exception as e:
            raise RuntimeError(f"Failed to extract class names from model: {e}") from e
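

# Minimal usage sketch. The paths below are hypothetical placeholders, not
# guaranteed locations. get_crop() needs no model weights, so the cropping
# step can be exercised without the tkm_v1.pt file being present.
if __name__ == "__main__":
    model_dir = Path("models/tkm_v1")  # hypothetical install location
    inference = ModelInference(model_dir, model_dir / "tkm_v1.pt")

    # Crop a small synthetic detection out of a blank 1920x1080 frame:
    # bbox (x, y, w, h) is normalized, so this is a 96x86 px detection.
    frame = Image.new("RGB", (1920, 1080))
    crop = inference.get_crop(frame, (0.4, 0.4, 0.05, 0.08))
    print(crop.size)  # square crop; _pad_crop() enforces the 224 px minimum

    # Classification requires the actual weights file:
    # inference.load_model()
    # print(inference.get_classification(crop))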