File size: 8,291 Bytes
de23f27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""
Inference script for PAM-SDZWA-v1 (Peruvian Amazon Species Classifier)

This model classifies 53 species found in Peruvian Amazon rainforest habitats.
Developed by Mathias Tobler from the San Diego Zoo Wildlife Alliance Conservation
Technology Lab using their animl-py framework.

Model: Peru Amazon v0.86
Input: Variable size (extracted from model config)
Framework: TensorFlow/Keras (TensorFlow 1.x compatible)
Classes: 53 Amazonian species and taxonomic groups
Developer: San Diego Zoo Wildlife Alliance (Mathias Tobler)
License: MIT
Info: https://github.com/conservationtechlab

Author: Peter van Lunteren
Created: 2026-01-14
"""

from __future__ import annotations

import os
from pathlib import Path

import cv2
import numpy as np
import tensorflow as tf
from PIL import Image, ImageFile
from tensorflow.keras.models import load_model

# Don't freak out over truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True


class ModelInference:
    """TensorFlow/Keras inference implementation for Peruvian Amazon species classifier.

    Lifecycle: construct with paths, call load_model() once, then call
    get_crop() / get_classification() per detection. get_class_names()
    exposes the 1-indexed ID -> species-name mapping used in results.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths; no I/O happens until load_model().

        Args:
            model_dir: Directory containing model files and class labels
            model_path: Path to Peru-Amazon_0.86.h5 file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None  # Keras model, populated by load_model()
        self.img_size = None  # square input edge length, read from model config
        self.class_map: dict[str, str] = {}  # label-file ID -> species name
        self.class_ids_sorted: list[str] = []  # species names in model-output order

    def check_gpu(self) -> bool:
        """
        Check GPU availability for TensorFlow inference.

        Returns:
            True if at least one logical GPU device is visible, False otherwise
        """
        return len(tf.config.list_logical_devices('GPU')) > 0

    def load_model(self) -> None:
        """
        Load TensorFlow/Keras model and class labels into memory.

        Called once during worker initialization. The model is stored in
        self.model and reused for all subsequent classification requests.

        Raises:
            RuntimeError: If model loading or label parsing fails
            FileNotFoundError: If model_path or label file is invalid
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            # Load Keras model
            self.model = load_model(str(self.model_path))

            # The model expects square inputs (e.g. 299x299); read the edge
            # length from the first layer's batch_input_shape (index 1 = height).
            self.img_size = self.model.get_config()["layers"][0]["config"]["batch_input_shape"][1]

        except Exception as e:
            raise RuntimeError(f"Failed to load Keras model from {self.model_path}: {e}") from e

        # Load class labels from text file
        label_file = self.model_dir / "Peru-Amazon_0.86.txt"
        if not label_file.exists():
            raise FileNotFoundError(f"Class label file not found: {label_file}")

        try:
            # Each useful label line carries two quoted fields: "ID" "name".
            # Splitting on '"' yields [prefix, ID, separator, name, ...], so
            # parts[1] is the numeric ID and parts[3] the species name.
            # Explicit utf-8 so species names decode identically on every platform.
            with open(label_file, 'r', encoding='utf-8') as file:
                for line in file:
                    parts = line.strip().split('"')
                    if len(parts) >= 4:
                        identifier = parts[1].strip()
                        animal_name = parts[3].strip()
                        if identifier.isdigit():
                            self.class_map[identifier] = animal_name

            # Class names sorted alphabetically; this ordering defines the
            # mapping from model output index to species name.
            # NOTE(review): this sorts by NAME, not by the numeric label-file
            # ID. The two orders coincide only if the label file assigns IDs
            # in alphabetical name order — verify against the training pipeline.
            self.class_ids_sorted = sorted(self.class_map.values())

        except Exception as e:
            raise RuntimeError(f"Failed to load class labels from {label_file}: {e}") from e

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using SDZWA animl-py preprocessing.

        Follows the San Diego Zoo Wildlife Alliance's animl-py framework
        approach with minimal buffering (0 pixels by default).

        Based on: https://github.com/conservationtechlab/animl-py/blob/main/src/animl/generator.py

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image (not resized - resizing happens in get_classification)

        Raises:
            ValueError: If the bbox degenerates to an empty region after clipping
        """
        buffer = 0  # SDZWA uses 0 pixel buffer
        width, height = image.size

        # Denormalize bbox (x, y, w, h) to absolute pixel corners, then apply
        # the buffer and clip to the image boundaries.
        x, y, w, h = bbox
        left = max(0, int(width * x) - buffer)
        top = max(0, int(height * y) - buffer)
        right = min(width, int(width * (x + w)) + buffer)
        bottom = min(height, int(height * (y + h)) + buffer)

        # A zero-area crop means the bbox was invalid or fully off-image.
        if left >= right or top >= bottom:
            raise ValueError(f"Invalid bbox dimensions after cropping: left={left}, top={top}, right={right}, bottom={bottom}")

        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run TensorFlow/Keras classification on cropped image.

        Preprocessing follows SDZWA animl-py framework:
        - Resize to model input size (extracted from model config)
        - Convert to numpy array
        - No normalization or augmentation (except potential horizontal flip during training)

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] pairs for ALL classes, sorted by class ID.
            Example: [["Black-headed squirrel monkey", 0.001], ["Brazilian rabbit", 0.002], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # PIL -> numpy, then resize to the square model input size.
            # NOTE(review): assumes the crop is 3-channel RGB; grayscale or
            # RGBA inputs would yield a different channel count — confirm upstream.
            img = np.array(crop)
            img = cv2.resize(img, (self.img_size, self.img_size))
            img = np.expand_dims(img, axis=0)  # add batch dimension

            # Per animl-py, no special preprocessing is applied at inference
            # time (horizontal flip is a training-only augmentation).
            pred = self.model.predict(img, verbose=0)[0]

            # Guard against a label/model mismatch instead of silently
            # truncating results or raising a bare IndexError.
            if len(pred) != len(self.class_ids_sorted):
                raise ValueError(
                    f"Model produced {len(pred)} scores but "
                    f"{len(self.class_ids_sorted)} class labels are loaded"
                )

            # Pair each score with its class name in consistent sorted order.
            return [
                [class_name, float(confidence)]
                for class_name, confidence in zip(self.class_ids_sorted, pred)
            ]

        except Exception as e:
            raise RuntimeError(f"Keras classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        Class IDs are 1-indexed and correspond to the sorted order of class names.

        Returns:
            Dict mapping class ID (1-indexed string) to species name
            Example: {"1": "Black-headed squirrel monkey", "2": "Brazilian rabbit", ...}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        # Pure dict construction over an in-memory list cannot fail, so no
        # try/except wrapper is needed here.
        return {str(i + 1): name for i, name in enumerate(self.class_ids_sorted)}