Addax-Data-Science committed on
Commit
462b9b8
·
verified ·
1 Parent(s): ddf1426

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +292 -0
inference.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TAS-BB-v1 MEWC-Keras Classification Model - Custom Inference Script
3
+
4
+ This script provides model-specific inference code for the Tasmania
5
+ species classifier (96 classes). It follows the AddaxAI-WebUI interface contract
6
+ for custom classification models.
7
+
8
+ Model: Tasmania MEWC Ensemble (tas_ens_mewc.keras)
9
+ Framework: Keras 3 with JAX backend (TensorFlow compatible)
10
+ Classes: 96 classes (Tasmanian terrestrial mammals and birds)
11
+ Training data: 2.5+ million images
12
+
13
+ MEWC - Mega Efficient Wildlife Classifier - University of Tasmania
14
+ Original source: streamlit-AddaxAI/classification/model_types/mewc-keras/classify_detections.py
15
+ Reference: https://github.com/zaandahl/mewc
16
+ Adapted by: Claude Code on 2026-01-11
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
import os
from pathlib import Path

# Select the Keras backend (JAX, as in the original MEWC code).
# Keras reads KERAS_BACKEND once, while it is being imported, so the variable
# has to be set BEFORE `keras` (and `tensorflow`, which can pull Keras in) are
# imported; assigning it afterwards has no effect.
# NOTE(review): if the host process already imported keras before loading this
# module, the backend is already fixed and this assignment cannot change it.
os.environ["KERAS_BACKEND"] = "jax"

import cv2  # noqa: E402  (must follow the KERAS_BACKEND assignment)
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
import yaml  # noqa: E402
from keras import saving  # noqa: E402
from PIL import Image, ImageFile  # noqa: E402

# Module-level variables (injected by AddaxAI framework)
MODEL_DIR: Path | None = None  # Set by CustomInferenceLoader
MODEL_PATH: Path | None = None  # Set by CustomInferenceLoader

# Module-level model instance (populated by load_model())
animal_model = None
img_size = 384  # MEWC uses 384x384 images

# Class mapping variables (populated by load_model())
class_map: dict[str, str] | None = None  # raw content of class_list.yaml
class_ids: list[str] | None = None  # species names ordered by class index

# Let Pillow decode truncated image files instead of raising
ImageFile.LOAD_TRUNCATED_IMAGES = True
48
+
49
+
50
def check_gpu() -> bool:
    """
    Report whether TensorFlow can see at least one GPU device.

    The check asks TensorFlow for its logical "GPU" devices. Any error raised
    while doing so is treated as "no GPU available".

    NOTE(review): the classifier is meant to run on the Keras JAX backend, so
    TensorFlow's device list may not reflect what inference actually uses -
    confirm whether a JAX-side check is wanted.

    Returns:
        True if one or more GPU devices are listed, False otherwise
        (including when the query itself fails)
    """
    try:
        return bool(tf.config.list_logical_devices("GPU"))
    except Exception:
        # Best-effort probe: a failing query must never break the caller
        return False
64
+
65
+
66
def _ordered_class_names(mapping: dict, keys_are_indices: bool) -> list:
    """Return the labels of a class_list mapping ordered by numeric class index."""
    if keys_are_indices:
        # {index: label}: sort on the numeric value of the key so that a quoted
        # key "10" comes after "2", and read labels straight from the mapping
        # so that two indices sharing one label are both kept.
        return [mapping[key] for key in sorted(mapping, key=int)]
    try:
        # {label: index}: the declared index decides the position
        return sorted(mapping, key=lambda label: int(mapping[label]))
    except (TypeError, ValueError):
        # Values are not usable as indices - keep the previous behaviour
        # (alphabetical label order) rather than failing.
        return sorted(mapping)


def load_model() -> None:
    """
    Load the Keras classification model and its class list into memory.

    The model is read from MODEL_PATH and the class list from
    MODEL_DIR / "class_list.yaml". The results are stored in the module-level
    `animal_model`, `class_map` and `class_ids` variables, which
    get_classification() and get_class_names() read. The globals are only
    assigned after both files loaded successfully, so a failure never leaves a
    half-initialised module behind.

    class_list.yaml may map either index -> label or label -> index; in both
    cases `class_ids` ends up as the list of labels ordered by class index.

    Raises:
        RuntimeError: If MODEL_PATH / MODEL_DIR were not injected, if the model
            or the YAML file cannot be loaded, or if the YAML content is not a
            non-empty mapping
        FileNotFoundError: If the model file or class_list.yaml does not exist
    """
    global animal_model, class_map, class_ids

    if MODEL_PATH is None:
        raise RuntimeError("MODEL_PATH not set - must be injected by framework")

    if MODEL_DIR is None:
        raise RuntimeError("MODEL_DIR not set - must be injected by framework")

    if not MODEL_PATH.exists():
        raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

    # Load the Keras model (without compilation: inference only)
    try:
        loaded_model = saving.load_model(str(MODEL_PATH), compile=False)
    except Exception as e:
        raise RuntimeError(f"Failed to load Keras model from {MODEL_PATH}: {e}") from e

    # Load class_list.yaml
    class_list_path = MODEL_DIR / "class_list.yaml"
    if not class_list_path.exists():
        raise FileNotFoundError(
            f"class_list.yaml not found: {class_list_path}\n"
            f"MEWC models require class_list.yaml in the model directory."
        )

    try:
        # Explicit encoding: species names must not depend on the OS locale
        with open(class_list_path, "r", encoding="utf-8") as f:
            loaded_map = yaml.safe_load(f)
    except Exception as e:
        raise RuntimeError(f"Failed to load class_list.yaml: {e}") from e

    # An empty file parses to None and a YAML list parses to a Python list;
    # neither can describe the classes.
    if not isinstance(loaded_map, dict) or not loaded_map:
        raise RuntimeError(
            f"class_list.yaml must contain a non-empty mapping: {class_list_path}"
        )

    # The YAML can be formatted as either {int: str} or {str: int}
    keys_are_indices = _can_all_keys_be_converted_to_int(loaded_map)
    ordered_names = _ordered_class_names(loaded_map, keys_are_indices)

    # Publish everything together once all steps succeeded
    animal_model = loaded_model
    class_map = loaded_map
    class_ids = ordered_names
124
+
125
+
126
def _can_all_keys_be_converted_to_int(d: dict) -> bool:
    """
    Check if all dictionary keys can be converted to integers.

    Used to determine the class_list.yaml format (index -> label versus
    label -> index).

    Args:
        d: Dictionary to check

    Returns:
        True if int(key) succeeds for every key (also for an empty dict),
        False as soon as one key cannot be converted
    """
    for key in d:
        try:
            int(key)
        except (TypeError, ValueError):
            # ValueError: non-numeric string such as "cat".
            # TypeError: key types int() rejects outright (None, tuples, ...).
            return False
    return True
144
+
145
+
146
def get_crop(image: Image.Image, bbox: tuple[float, float, float, float]) -> Image.Image | None:
    """
    Crop image using MEWC-specific preprocessing.

    This cropping method is used by MEWC and follows the MegaDetector
    visualization_utils approach. It:
    1. Denormalizes the bbox coordinates
    2. Clips to image boundaries
    3. Returns the cropped region (no padding or squaring)

    Reference: https://github.com/zaandahl/mewc-snip/blob/main/src/mewc_snip.py#L29
    Reference: https://github.com/agentmorris/MegaDetector/blob/main/megadetector/visualization/visualization_utils.py#L352

    Args:
        image: PIL Image (full resolution)
        bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

    Returns:
        Cropped PIL Image, or None if the bbox is invalid or covers less than
        one whole pixel after clipping and rounding
    """
    x_min, y_min, w_box, h_box = bbox

    # Reject boxes with zero or negative dimensions before touching the image
    if w_box <= 0 or h_box <= 0:
        return None

    im_width, im_height = image.size

    def _to_pixel(fraction: float, extent: int) -> float:
        # Denormalize, then clip into [0, extent - 1]
        return min(max(fraction * extent, 0), extent - 1)

    edges = (
        _to_pixel(x_min, im_width),           # left
        _to_pixel(y_min, im_height),          # top
        _to_pixel(x_min + w_box, im_width),   # right
        _to_pixel(y_min + h_box, im_height),  # bottom
    )

    # Pillow rounds the crop box to whole pixels. Round here the same way so
    # the validity check sees the box that will really be cut out; checking
    # the float coordinates let sub-pixel boxes through as zero-area images.
    left, top, right, bottom = (int(round(edge)) for edge in edges)

    if right <= left or bottom <= top:
        return None

    return image.crop((left, top, right, bottom))
207
+
208
+
209
def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
    """
    Run MEWC-Keras classification on cropped image.

    Workflow:
    1. Convert PIL Image to numpy array (forcing 3-channel RGB if needed)
    2. Resize to img_size x img_size (MEWC input size)
    3. Run model prediction
    4. Return all class probabilities

    Args:
        crop: Cropped PIL Image, or None

    Returns:
        List of (class_id, confidence) tuples for ALL classes, in model output
        order. class_id is the 1-based output position as a string, i.e. the
        same keys get_class_names() produces.
        Example: [("1", 0.50674), ("2", 0.46682), ...]
        An empty list is returned when crop is None or has no pixels.

    Raises:
        RuntimeError: If model not loaded or inference fails
    """
    if animal_model is None:
        raise RuntimeError("Model not loaded - call load_model() first")

    if class_ids is None:
        raise RuntimeError("Class IDs not loaded - call load_model() first")

    if crop is None:
        return []

    try:
        # Convert PIL Image to numpy array
        img = np.array(crop)

        if img.size == 0:
            return []

        # Grayscale, palette or RGBA crops do not produce an (H, W, 3) array
        # and would be rejected by the model; normalise them to RGB.
        # getattr keeps array-like inputs without a `mode` working as before.
        if getattr(crop, "mode", "RGB") != "RGB":
            img = np.array(crop.convert("RGB"))

        # Resize to MEWC input size and add the batch dimension
        img = cv2.resize(img, (img_size, img_size))
        img = np.expand_dims(img, axis=0)

        # Run prediction (verbose=0 suppresses progress bar)
        pred = animal_model.predict(img, verbose=0)[0]

        # 1-indexed IDs for compatibility with get_class_names()
        return [
            (str(position), float(score))
            for position, score in enumerate(pred, start=1)
        ]

    except Exception as e:
        raise RuntimeError(f"MEWC-Keras classification failed: {e}") from e
268
+
269
+
270
def get_class_names() -> dict[str, str]:
    """
    Build the class-ID -> species-name lookup from the loaded class list.

    IDs are the 1-based positions of the names in `class_ids`, rendered as
    strings.

    Returns:
        Dict mapping class ID (1-indexed string) to species name
        Example: {"1": "bait", "2": "unknown_animal", ...}

    Raises:
        RuntimeError: If the class list has not been loaded yet
    """
    if class_ids is None:
        raise RuntimeError("Class IDs not loaded - call load_model() first")

    # enumerate(start=1) yields the 1-indexed position directly
    return {str(position): name for position, name in enumerate(class_ids, start=1)}