Addax-Data-Science committed on
Commit
2b37f4a
·
verified ·
1 Parent(s): 614bcb8

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +220 -224
inference.py CHANGED
@@ -14,6 +14,7 @@ MEWC - Mega Efficient Wildlife Classifier - University of Tasmania
14
  Original source: streamlit-AddaxAI/classification/model_types/mewc-keras/classify_detections.py
15
  Reference: https://github.com/zaandahl/mewc
16
  Adapted by: Claude Code on 2026-01-11
 
17
  """
18
 
19
  from __future__ import annotations
@@ -31,258 +32,253 @@ from PIL import Image, ImageFile
31
  # Set Keras backend to JAX (as per original MEWC code)
32
  os.environ["KERAS_BACKEND"] = "jax"
33
 
34
- # Module-level variables (injected by AddaxAI framework)
35
- MODEL_DIR: Path | None = None # Set by CustomInferenceLoader
36
- MODEL_PATH: Path | None = None # Set by CustomInferenceLoader
37
-
38
- # Module-level model instance (loaded once at startup)
39
- animal_model = None
40
- img_size = 384 # MEWC uses 384x384 images
41
-
42
- # Class mapping variables
43
- class_map: dict[str, str] | None = None
44
- class_ids: list[str] | None = None
45
-
46
  # Don't freak out over truncated images
47
  ImageFile.LOAD_TRUNCATED_IMAGES = True
48
 
49
 
50
- def check_gpu() -> bool:
51
- """
52
- Check GPU availability for TensorFlow/Keras inference.
53
-
54
- TensorFlow can detect GPUs, Metal (Apple Silicon), and CUDA.
55
-
56
- Returns:
57
- True if GPU available, False otherwise
58
- """
59
- try:
60
- gpus = tf.config.list_logical_devices('GPU')
61
- return len(gpus) > 0
62
- except Exception:
63
- return False
64
-
65
-
66
- def load_model() -> None:
67
- """
68
- Load Keras classification model into memory.
69
-
70
- This function is called once during worker initialization.
71
- The model is stored in the global `animal_model` variable and reused
72
- for all subsequent classification requests.
73
-
74
- Also loads the class_list.yaml file which maps class indices to species names.
75
-
76
- Raises:
77
- RuntimeError: If model loading fails
78
- FileNotFoundError: If MODEL_PATH or class_list.yaml is invalid
79
- """
80
- global animal_model, class_map, class_ids, MODEL_PATH, MODEL_DIR
81
 
82
- # MODEL_PATH and MODEL_DIR are injected by framework before this function is called
83
- # Check that the paths exist (framework guarantees they're not None)
84
- if not MODEL_PATH.exists():
85
- raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
86
 
87
- # Load the Keras model (without compilation for inference only)
88
- try:
89
- animal_model = saving.load_model(str(MODEL_PATH), compile=False)
90
- except Exception as e:
91
- raise RuntimeError(f"Failed to load Keras model from {MODEL_PATH}: {e}") from e
 
 
 
 
 
92
 
93
- # Load class_list.yaml
94
- class_list_path = MODEL_DIR / "class_list.yaml"
95
- if not class_list_path.exists():
96
- raise FileNotFoundError(
97
- f"class_list.yaml not found: {class_list_path}\n"
98
- f"MEWC models require class_list.yaml in the model directory."
99
- )
100
 
101
- try:
102
- with open(class_list_path, 'r') as f:
103
- class_map = yaml.safe_load(f)
104
- except Exception as e:
105
- raise RuntimeError(f"Failed to load class_list.yaml: {e}") from e
106
 
107
- # Build class_ids list based on YAML format
108
- # The YAML can be formatted as either {int: str} or {str: int}
109
- inv_class = {v: k for k, v in class_map.items()}
110
-
111
- # Check if keys are numeric (int:label format) or string (label:int format)
112
- formatted_int_label = _can_all_keys_be_converted_to_int(class_map)
113
-
114
- if formatted_int_label:
115
- # Format: {0: "species1", 1: "species2", ...}
116
- class_ids = [class_map[i] for i in sorted(inv_class.values())]
117
- else:
118
- # Format: {"species1": 0, "species2": 1, ...}
119
- class_ids = sorted(inv_class.values())
120
-
121
-
122
- def _can_all_keys_be_converted_to_int(d: dict) -> bool:
123
- """
124
- Check if all dictionary keys can be converted to integers.
125
-
126
- Used to determine class_list.yaml format.
127
-
128
- Args:
129
- d: Dictionary to check
130
-
131
- Returns:
132
- True if all keys are convertible to int, False otherwise
133
- """
134
- for key in d.keys():
135
  try:
136
- int(key)
137
- except ValueError:
 
138
  return False
139
- return True
140
-
141
-
142
- def get_crop(image: Image.Image, bbox: tuple[float, float, float, float]) -> Image.Image | None:
143
- """
144
- Crop image using MEWC-specific preprocessing.
145
-
146
- This cropping method is used by MEWC and follows the MegaDetector
147
- visualization_utils approach. It:
148
- 1. Denormalizes the bbox coordinates
149
- 2. Clips to image boundaries
150
- 3. Returns the cropped region (no padding or squaring)
151
 
152
- Reference: https://github.com/zaandahl/mewc-snip/blob/main/src/mewc_snip.py#L29
153
- Reference: https://github.com/agentmorris/MegaDetector/blob/main/megadetector/visualization/visualization_utils.py#L352
 
154
 
155
- Args:
156
- image: PIL Image (full resolution)
157
- bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]
158
 
159
- Returns:
160
- Cropped PIL Image, or None if bbox is invalid
161
 
162
- Raises:
163
- None - Returns None for invalid boxes (graceful degradation)
164
- """
165
- x1, y1, w_box, h_box = bbox
 
 
166
 
167
- # Check for invalid bounding boxes (zero or negative dimensions)
168
- if w_box <= 0 or h_box <= 0:
169
- return None
170
-
171
- # Convert normalized coordinates to pixel coordinates
172
- ymin, xmin, ymax, xmax = y1, x1, y1 + h_box, x1 + w_box
173
- im_width, im_height = image.size
174
-
175
- # Denormalize
176
- left = xmin * im_width
177
- right = xmax * im_width
178
- top = ymin * im_height
179
- bottom = ymax * im_height
180
-
181
- # Clip to image boundaries (ensure non-negative)
182
- left = max(left, 0)
183
- right = max(right, 0)
184
- top = max(top, 0)
185
- bottom = max(bottom, 0)
186
-
187
- # Clip to image boundaries (ensure within image)
188
- left = min(left, im_width - 1)
189
- right = min(right, im_width - 1)
190
- top = min(top, im_height - 1)
191
- bottom = min(bottom, im_height - 1)
192
-
193
- # Final check - ensure crop has valid dimensions
194
- crop_width = right - left
195
- crop_height = bottom - top
196
-
197
- if crop_width <= 0 or crop_height <= 0:
198
- return None
199
-
200
- # Crop image
201
- image_cropped = image.crop((left, top, right, bottom))
202
- return image_cropped
203
-
204
-
205
- def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
206
- """
207
- Run MEWC-Keras classification on cropped image.
208
-
209
- Workflow:
210
- 1. Convert PIL Image to numpy array
211
- 2. Resize to 384x384 (MEWC input size)
212
- 3. Run model prediction
213
- 4. Return all class probabilities
214
-
215
- Args:
216
- crop: Cropped PIL Image
217
-
218
- Returns:
219
- List of (class_name, confidence) tuples for ALL classes, in model order.
220
- Example: [("tasmanian_pademelon", 0.50674), ("bennetts_wallaby", 0.46682), ...]
221
-
222
- Raises:
223
- RuntimeError: If model not loaded or inference fails
224
- """
225
- global animal_model, class_ids
226
-
227
- if animal_model is None:
228
- raise RuntimeError("Model not loaded - call load_model() first")
229
-
230
- if class_ids is None:
231
- raise RuntimeError("Class IDs not loaded - call load_model() first")
232
-
233
- if crop is None:
234
- return []
235
-
236
- try:
237
- # Convert PIL Image to numpy array
238
- img = np.array(crop)
239
 
240
- if img.size == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  return []
242
 
243
- # Resize to MEWC input size (384x384)
244
- img = cv2.resize(img, (img_size, img_size))
 
245
 
246
- # Add batch dimension
247
- img = np.expand_dims(img, axis=0)
248
 
249
- # Run prediction (verbose=0 suppresses progress bar)
250
- pred = animal_model.predict(img, verbose=0)[0]
251
 
252
- # Build list of (class_name, confidence) tuples
253
- # class_ids is already in the correct order from class_list.yaml
254
- classifications = []
255
- for i in range(len(pred)):
256
- class_name = class_ids[i] # Get species name from class_list.yaml
257
- confidence = float(pred[i])
258
- classifications.append((class_name, confidence))
259
 
260
- return classifications
 
261
 
262
- except Exception as e:
263
- raise RuntimeError(f"MEWC-Keras classification failed: {e}") from e
 
 
 
 
 
264
 
 
265
 
266
- def get_class_names() -> dict[str, str]:
267
- """
268
- Get mapping of class IDs to species names from class_list.yaml.
269
 
270
- Returns:
271
- Dict mapping class ID (1-indexed string) to species name
272
- Example: {"1": "bait", "2": "unknown_animal", ...}
273
 
274
- Raises:
275
- RuntimeError: If class_map not loaded
276
- """
277
- global class_ids
278
 
279
- if class_ids is None:
280
- raise RuntimeError("Class IDs not loaded - call load_model() first")
 
 
 
281
 
282
- # Build 1-indexed mapping
283
- class_names = {}
284
- for i, class_name in enumerate(class_ids):
285
- class_id_str = str(i + 1) # 1-indexed
286
- class_names[class_id_str] = class_name
287
 
288
- return class_names
 
14
  Original source: streamlit-AddaxAI/classification/model_types/mewc-keras/classify_detections.py
15
  Reference: https://github.com/zaandahl/mewc
16
  Adapted by: Claude Code on 2026-01-11
17
+ Updated: 2026-01-13 - Migrated to class-based interface
18
  """
19
 
20
  from __future__ import annotations
 
32
  # Set Keras backend to JAX (as per original MEWC code)
33
  os.environ["KERAS_BACKEND"] = "jax"
34
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # Don't freak out over truncated images
36
  ImageFile.LOAD_TRUNCATED_IMAGES = True
37
 
38
 
39
+ class ModelInference:
40
+ """MEWC-Keras inference implementation for Tasmania species classifier."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ def __init__(self, model_dir: Path, model_path: Path):
43
+ """
44
+ Initialize with model paths.
 
45
 
46
+ Args:
47
+ model_dir: Directory containing model files (including class_list.yaml)
48
+ model_path: Path to tas_ens_mewc.keras file
49
+ """
50
+ self.model_dir = model_dir
51
+ self.model_path = model_path
52
+ self.model = None
53
+ self.img_size = 384 # MEWC uses 384x384 images
54
+ self.class_map: dict[str, str] | None = None
55
+ self.class_ids: list[str] | None = None
56
 
57
+ def check_gpu(self) -> bool:
58
+ """
59
+ Check GPU availability for TensorFlow/Keras inference.
 
 
 
 
60
 
61
+ TensorFlow can detect GPUs, Metal (Apple Silicon), and CUDA.
 
 
 
 
62
 
63
+ Returns:
64
+ True if GPU available, False otherwise
65
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  try:
67
+ gpus = tf.config.list_logical_devices('GPU')
68
+ return len(gpus) > 0
69
+ except Exception:
70
  return False
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ def load_model(self) -> None:
73
+ """
74
+ Load Keras classification model into memory.
75
 
76
+ This function is called once during worker initialization.
77
+ The model is stored in self.model and reused for all subsequent
78
+ classification requests.
79
 
80
+ Also loads the class_list.yaml file which maps class indices to species names.
 
81
 
82
+ Raises:
83
+ RuntimeError: If model loading fails
84
+ FileNotFoundError: If model_path or class_list.yaml is invalid
85
+ """
86
+ if not self.model_path.exists():
87
+ raise FileNotFoundError(f"Model file not found: {self.model_path}")
88
 
89
+ # Load the Keras model (without compilation for inference only)
90
+ try:
91
+ self.model = saving.load_model(str(self.model_path), compile=False)
92
+ except Exception as e:
93
+ raise RuntimeError(f"Failed to load Keras model from {self.model_path}: {e}") from e
94
+
95
+ # Load class_list.yaml
96
+ class_list_path = self.model_dir / "class_list.yaml"
97
+ if not class_list_path.exists():
98
+ raise FileNotFoundError(
99
+ f"class_list.yaml not found: {class_list_path}\n"
100
+ f"MEWC models require class_list.yaml in the model directory."
101
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ try:
104
+ with open(class_list_path, 'r') as f:
105
+ self.class_map = yaml.safe_load(f)
106
+ except Exception as e:
107
+ raise RuntimeError(f"Failed to load class_list.yaml: {e}") from e
108
+
109
+ # Build class_ids list based on YAML format
110
+ # The YAML can be formatted as either {int: str} or {str: int}
111
+ inv_class = {v: k for k, v in self.class_map.items()}
112
+
113
+ # Check if keys are numeric (int:label format) or string (label:int format)
114
+ formatted_int_label = self._can_all_keys_be_converted_to_int(self.class_map)
115
+
116
+ if formatted_int_label:
117
+ # Format: {0: "species1", 1: "species2", ...}
118
+ self.class_ids = [self.class_map[i] for i in sorted(inv_class.values())]
119
+ else:
120
+ # Format: {"species1": 0, "species2": 1, ...}
121
+ self.class_ids = sorted(inv_class.values())
122
+
123
+ def _can_all_keys_be_converted_to_int(self, d: dict) -> bool:
124
+ """
125
+ Check if all dictionary keys can be converted to integers.
126
+
127
+ Used to determine class_list.yaml format.
128
+
129
+ Args:
130
+ d: Dictionary to check
131
+
132
+ Returns:
133
+ True if all keys are convertible to int, False otherwise
134
+ """
135
+ for key in d.keys():
136
+ try:
137
+ int(key)
138
+ except ValueError:
139
+ return False
140
+ return True
141
+
142
+ def get_crop(
143
+ self, image: Image.Image, bbox: tuple[float, float, float, float]
144
+ ) -> Image.Image | None:
145
+ """
146
+ Crop image using MEWC-specific preprocessing.
147
+
148
+ This cropping method is used by MEWC and follows the MegaDetector
149
+ visualization_utils approach. It:
150
+ 1. Denormalizes the bbox coordinates
151
+ 2. Clips to image boundaries
152
+ 3. Returns the cropped region (no padding or squaring)
153
+
154
+ Reference: https://github.com/zaandahl/mewc-snip/blob/main/src/mewc_snip.py#L29
155
+ Reference: https://github.com/agentmorris/MegaDetector/blob/main/megadetector/visualization/visualization_utils.py#L352
156
+
157
+ Args:
158
+ image: PIL Image (full resolution)
159
+ bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]
160
+
161
+ Returns:
162
+ Cropped PIL Image, or None if bbox is invalid
163
+
164
+ Raises:
165
+ None - Returns None for invalid boxes (graceful degradation)
166
+ """
167
+ x1, y1, w_box, h_box = bbox
168
+
169
+ # Check for invalid bounding boxes (zero or negative dimensions)
170
+ if w_box <= 0 or h_box <= 0:
171
+ return None
172
+
173
+ # Convert normalized coordinates to pixel coordinates
174
+ ymin, xmin, ymax, xmax = y1, x1, y1 + h_box, x1 + w_box
175
+ im_width, im_height = image.size
176
+
177
+ # Denormalize
178
+ left = xmin * im_width
179
+ right = xmax * im_width
180
+ top = ymin * im_height
181
+ bottom = ymax * im_height
182
+
183
+ # Clip to image boundaries (ensure non-negative)
184
+ left = max(left, 0)
185
+ right = max(right, 0)
186
+ top = max(top, 0)
187
+ bottom = max(bottom, 0)
188
+
189
+ # Clip to image boundaries (ensure within image)
190
+ left = min(left, im_width - 1)
191
+ right = min(right, im_width - 1)
192
+ top = min(top, im_height - 1)
193
+ bottom = min(bottom, im_height - 1)
194
+
195
+ # Final check - ensure crop has valid dimensions
196
+ crop_width = right - left
197
+ crop_height = bottom - top
198
+
199
+ if crop_width <= 0 or crop_height <= 0:
200
+ return None
201
+
202
+ # Crop image
203
+ image_cropped = image.crop((left, top, right, bottom))
204
+ return image_cropped
205
+
206
+ def get_classification(self, crop: Image.Image) -> list[tuple[str, float]]:
207
+ """
208
+ Run MEWC-Keras classification on cropped image.
209
+
210
+ Workflow:
211
+ 1. Convert PIL Image to numpy array
212
+ 2. Resize to 384x384 (MEWC input size)
213
+ 3. Run model prediction
214
+ 4. Return all class probabilities
215
+
216
+ Args:
217
+ crop: Cropped PIL Image
218
+
219
+ Returns:
220
+ List of (class_name, confidence) tuples for ALL classes, in model order.
221
+ Example: [("tasmanian_pademelon", 0.50674), ("bennetts_wallaby", 0.46682), ...]
222
+
223
+ Raises:
224
+ RuntimeError: If model not loaded or inference fails
225
+ """
226
+ if self.model is None:
227
+ raise RuntimeError("Model not loaded - call load_model() first")
228
+
229
+ if self.class_ids is None:
230
+ raise RuntimeError("Class IDs not loaded - call load_model() first")
231
+
232
+ if crop is None:
233
  return []
234
 
235
+ try:
236
+ # Convert PIL Image to numpy array
237
+ img = np.array(crop)
238
 
239
+ if img.size == 0:
240
+ return []
241
 
242
+ # Resize to MEWC input size (384x384)
243
+ img = cv2.resize(img, (self.img_size, self.img_size))
244
 
245
+ # Add batch dimension
246
+ img = np.expand_dims(img, axis=0)
 
 
 
 
 
247
 
248
+ # Run prediction (verbose=0 suppresses progress bar)
249
+ pred = self.model.predict(img, verbose=0)[0]
250
 
251
+ # Build list of (class_name, confidence) tuples
252
+ # class_ids is already in the correct order from class_list.yaml
253
+ classifications = []
254
+ for i in range(len(pred)):
255
+ class_name = self.class_ids[i] # Get species name from class_list.yaml
256
+ confidence = float(pred[i])
257
+ classifications.append((class_name, confidence))
258
 
259
+ return classifications
260
 
261
+ except Exception as e:
262
+ raise RuntimeError(f"MEWC-Keras classification failed: {e}") from e
 
263
 
264
+ def get_class_names(self) -> dict[str, str]:
265
+ """
266
+ Get mapping of class IDs to species names from class_list.yaml.
267
 
268
+ Returns:
269
+ Dict mapping class ID (1-indexed string) to species name
270
+ Example: {"1": "bait", "2": "unknown_animal", ...}
 
271
 
272
+ Raises:
273
+ RuntimeError: If class_ids not loaded
274
+ """
275
+ if self.class_ids is None:
276
+ raise RuntimeError("Class IDs not loaded - call load_model() first")
277
 
278
+ # Build 1-indexed mapping
279
+ class_names = {}
280
+ for i, class_name in enumerate(self.class_ids):
281
+ class_id_str = str(i + 1) # 1-indexed
282
+ class_names[class_id_str] = class_name
283
 
284
+ return class_names