Addax-Data-Science committed on
Commit
de23f27
·
verified ·
1 Parent(s): 7de371c

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +232 -0
inference.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for PAM-SDZWA-v1 (Peruvian Amazon Species Classifier)
3
+
4
+ This model classifies 53 species found in Peruvian Amazon rainforest habitats.
5
+ Developed by Mathias Tobler from the San Diego Zoo Wildlife Alliance Conservation
6
+ Technology Lab using their animl-py framework.
7
+
8
+ Model: Peru Amazon v0.86
9
+ Input: Variable size (extracted from model config)
10
+ Framework: TensorFlow/Keras (TensorFlow 1.x compatible)
11
+ Classes: 53 Amazonian species and taxonomic groups
12
+ Developer: San Diego Zoo Wildlife Alliance (Mathias Tobler)
13
+ License: MIT
14
+ Info: https://github.com/conservationtechlab
15
+
16
+ Author: Peter van Lunteren
17
+ Created: 2026-01-14
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import os
23
+ from pathlib import Path
24
+
25
+ import cv2
26
+ import numpy as np
27
+ import tensorflow as tf
28
+ from PIL import Image, ImageFile
29
+ from tensorflow.keras.models import load_model
30
+
31
+ # Don't freak out over truncated images
32
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
33
+
34
+
35
class ModelInference:
    """TensorFlow/Keras inference implementation for Peruvian Amazon species classifier.

    Typical use: construct the object, call ``load_model()`` once, then call
    ``get_crop()`` and ``get_classification()`` for every detection.

    Attributes set by ``load_model()``:
        model: The loaded Keras model (``None`` until loading succeeds).
        img_size: Edge length, in pixels, of the square model input.
        class_map: Label-file ID (string) -> species name, as read from the
            label file.
        class_ids_sorted: Species names in ALPHABETICAL order. Despite its
            name this list holds names, not IDs; position ``i`` is used as
            the label of model output ``i``.
    """

    # File name of the class-label list expected inside ``model_dir``.
    LABEL_FILENAME = "Peru-Amazon_0.86.txt"

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Nothing is read from disk here; call ``load_model()`` to do that.

        Args:
            model_dir: Directory containing model files and class labels
            model_path: Path to Peru-Amazon_0.86.h5 file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None
        self.img_size = None
        self.class_map = {}
        self.class_ids_sorted = []

    def check_gpu(self) -> bool:
        """
        Check GPU availability for TensorFlow inference.

        Returns:
            True if TensorFlow reports at least one logical GPU device,
            False otherwise
        """
        return len(tf.config.list_logical_devices('GPU')) > 0

    def load_model(self) -> None:
        """
        Load TensorFlow/Keras model and class labels into memory.

        The instance attributes (``model``, ``img_size``, ``class_map``,
        ``class_ids_sorted``) are only assigned after BOTH the model and the
        label file were read successfully, so a failed call never leaves a
        model without labels behind, and a repeated call never keeps labels
        from a previous load.

        Raises:
            RuntimeError: If the model or the label file cannot be loaded/parsed
            FileNotFoundError: If model_path or label file is invalid
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            model = load_model(str(self.model_path))

            # The input edge length is taken from the first layer's config.
            # The model is expected to take square images (e.g., 299x299).
            input_config = model.get_config()["layers"][0]["config"]
            try:
                input_shape = input_config["batch_input_shape"]
            except KeyError:
                # Newer Keras releases store the same value as "batch_shape".
                input_shape = input_config["batch_shape"]
            img_size = input_shape[1]

        except Exception as e:
            raise RuntimeError(f"Failed to load Keras model from {self.model_path}: {e}") from e

        label_file = self.model_dir / self.LABEL_FILENAME
        if not label_file.exists():
            raise FileNotFoundError(f"Class label file not found: {label_file}")

        try:
            class_map = self._parse_class_labels(label_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load class labels from {label_file}: {e}") from e

        # Commit everything at once, only after every step succeeded.
        self.model = model
        self.img_size = img_size
        self.class_map = class_map
        # Alphabetical by species NAME (not by label-file ID). Output index i
        # of the model is labelled with entry i of this list.
        # NOTE(review): presumably this matches the class order used at
        # training time - confirm against the training pipeline.
        self.class_ids_sorted = sorted(class_map.values())

    @staticmethod
    def _parse_class_labels(label_file: Path) -> dict[str, str]:
        """Read a label file and return a {numeric ID string: species name} dict.

        A line is used when splitting it on double quotes yields at least four
        parts and the second part (the first quoted field) consists of digits
        only; the fourth part (the second quoted field) is the species name.
        Every other line (e.g. a header) is ignored.
        """
        class_map: dict[str, str] = {}
        # NOTE(review): the file is opened with the platform default encoding;
        # confirm the label file is plain ASCII or pass an explicit encoding.
        with open(label_file, 'r') as file:
            for line in file:
                parts = line.strip().split('"')
                if len(parts) < 4:
                    continue
                identifier = parts[1].strip()
                if identifier.isdigit():
                    class_map[identifier] = parts[3].strip()
        return class_map

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using SDZWA animl-py preprocessing.

        This cropping method follows the San Diego Zoo Wildlife Alliance's animl-py
        framework approach with minimal buffering (0 pixels by default).

        Based on: https://github.com/conservationtechlab/animl-py/blob/main/src/animl/generator.py

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image (not resized - resizing happens in get_classification)

        Raises:
            ValueError: If the box is empty once converted to whole pixels and
                clipped to the image
        """
        buffer = 0  # SDZWA uses 0 pixel buffer
        width, height = image.size
        x, y, box_width, box_height = bbox

        # Denormalize, truncate to whole pixels, apply the buffer and clip
        # to the image boundaries.
        left = max(0, int(width * x) - buffer)
        top = max(0, int(height * y) - buffer)
        right = min(width, int(width * (x + box_width)) + buffer)
        bottom = min(height, int(height * (y + box_height)) + buffer)

        if left >= right or top >= bottom:
            raise ValueError(f"Invalid bbox dimensions after cropping: left={left}, top={top}, right={right}, bottom={bottom}")

        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run TensorFlow/Keras classification on cropped image.

        Preprocessing:
        - Convert non-RGB PIL crops (grayscale, RGBA, palette, ...) to RGB
        - Convert to numpy array
        - Resize to model input size (extracted from model config)
        - No normalization (according to animl-py none is needed)

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists, one per model output, in
            the order of ``self.class_ids_sorted`` (alphabetical by class name).
            Example: [["Black-headed squirrel monkey", 0.001], ["Brazilian rabbit", 0.002], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails (including the
                model producing more scores than there are class labels)
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # A non-RGB crop would give an array with the wrong number of
            # channels. Assumes the model takes 3-channel RGB input - TODO confirm.
            if isinstance(crop, Image.Image) and crop.mode != "RGB":
                crop = crop.convert("RGB")

            img = np.array(crop)
            img = cv2.resize(img, (self.img_size, self.img_size))
            img = np.expand_dims(img, axis=0)  # add batch dimension

            pred = self.model.predict(img, verbose=0)[0]

            labels = self.class_ids_sorted
            if len(pred) > len(labels):
                # Without this check the failure is an opaque "list index out
                # of range"; say what is actually wrong instead.
                raise ValueError(
                    f"model produced {len(pred)} scores but only "
                    f"{len(labels)} class labels are loaded"
                )

            # zip stops at the shorter sequence, i.e. at len(pred).
            return [[name, float(score)] for name, score in zip(labels, pred)]

        except Exception as e:
            raise RuntimeError(f"Keras classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        The IDs returned here are 1-based positions in the alphabetically
        sorted list of class names. They are NOT the numeric IDs read from the
        label file (those are kept in ``self.class_map``).

        Returns:
            Dict mapping class ID (1-indexed string) to species name
            Example: {"1": "Black-headed squirrel monkey", "2": "Brazilian rabbit", ...}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        return {
            str(position): class_name
            for position, class_name in enumerate(self.class_ids_sorted, start=1)
        }