Addax-Data-Science committed on
Commit
2a11bfb
·
verified ·
1 Parent(s): 2b37f4a

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +45 -14
inference.py CHANGED
@@ -20,6 +20,7 @@ Updated: 2026-01-13 - Migrated to class-based interface
20
  from __future__ import annotations
21
 
22
  import os
 
23
  from pathlib import Path
24
 
25
  import cv2
@@ -106,19 +107,29 @@ class ModelInference:
106
  except Exception as e:
107
  raise RuntimeError(f"Failed to load class_list.yaml: {e}") from e
108
 
109
- # Build class_ids list based on YAML format
110
- # The YAML can be formatted as either {int: str} or {str: int}
111
- inv_class = {v: k for k, v in self.class_map.items()}
 
 
112
 
113
  # Check if keys are numeric (int:label format) or string (label:int format)
114
  formatted_int_label = self._can_all_keys_be_converted_to_int(self.class_map)
115
 
116
  if formatted_int_label:
117
- # Format: {0: "species1", 1: "species2", ...}
118
- self.class_ids = [self.class_map[i] for i in sorted(inv_class.values())]
 
 
 
 
 
 
119
  else:
120
  # Format: {"species1": 0, "species2": 1, ...}
121
- self.class_ids = sorted(inv_class.values())
 
 
122
 
123
  def _can_all_keys_be_converted_to_int(self, d: dict) -> bool:
124
  """
@@ -168,6 +179,7 @@ class ModelInference:
168
 
169
  # Check for invalid bounding boxes (zero or negative dimensions)
170
  if w_box <= 0 or h_box <= 0:
 
171
  return None
172
 
173
  # Convert normalized coordinates to pixel coordinates
@@ -197,13 +209,20 @@ class ModelInference:
197
  crop_height = bottom - top
198
 
199
  if crop_width <= 0 or crop_height <= 0:
 
 
 
 
 
 
 
200
  return None
201
 
202
  # Crop image
203
  image_cropped = image.crop((left, top, right, bottom))
204
  return image_cropped
205
 
206
- def get_classification(self, crop: Image.Image) -> list[tuple[str, float]]:
207
  """
208
  Run MEWC-Keras classification on cropped image.
209
 
@@ -211,14 +230,15 @@ class ModelInference:
211
  1. Convert PIL Image to numpy array
212
  2. Resize to 384x384 (MEWC input size)
213
  3. Run model prediction
214
- 4. Return all class probabilities
215
 
216
  Args:
217
  crop: Cropped PIL Image
218
 
219
  Returns:
220
- List of (class_name, confidence) tuples for ALL classes, in model order.
221
- Example: [("tasmanian_pademelon", 0.50674), ("bennetts_wallaby", 0.46682), ...]
 
222
 
223
  Raises:
224
  RuntimeError: If model not loaded or inference fails
@@ -230,6 +250,7 @@ class ModelInference:
230
  raise RuntimeError("Class IDs not loaded - call load_model() first")
231
 
232
  if crop is None:
 
233
  return []
234
 
235
  try:
@@ -237,6 +258,7 @@ class ModelInference:
237
  img = np.array(crop)
238
 
239
  if img.size == 0:
 
240
  return []
241
 
242
  # Resize to MEWC input size (384x384)
@@ -248,14 +270,16 @@ class ModelInference:
248
  # Run prediction (verbose=0 suppresses progress bar)
249
  pred = self.model.predict(img, verbose=0)[0]
250
 
251
- # Build list of (class_name, confidence) tuples
252
  # class_ids is already in the correct order from class_list.yaml
253
  classifications = []
254
  for i in range(len(pred)):
255
  class_name = self.class_ids[i] # Get species name from class_list.yaml
256
  confidence = float(pred[i])
257
- classifications.append((class_name, confidence))
258
 
 
 
259
  return classifications
260
 
261
  except Exception as e:
@@ -265,9 +289,15 @@ class ModelInference:
265
  """
266
  Get mapping of class IDs to species names from class_list.yaml.
267
 
 
 
 
 
 
 
268
  Returns:
269
  Dict mapping class ID (1-indexed string) to species name
270
- Example: {"1": "bait", "2": "unknown_animal", ...}
271
 
272
  Raises:
273
  RuntimeError: If class_ids not loaded
@@ -275,7 +305,8 @@ class ModelInference:
275
  if self.class_ids is None:
276
  raise RuntimeError("Class IDs not loaded - call load_model() first")
277
 
278
- # Build 1-indexed mapping
 
279
  class_names = {}
280
  for i, class_name in enumerate(self.class_ids):
281
  class_id_str = str(i + 1) # 1-indexed
 
20
  from __future__ import annotations
21
 
22
  import os
23
+ import sys
24
  from pathlib import Path
25
 
26
  import cv2
 
107
  except Exception as e:
108
  raise RuntimeError(f"Failed to load class_list.yaml: {e}") from e
109
 
110
+ # The YAML can be formatted as either {int_str: species_name} or {species_name: int}
111
+ # IMPORTANT: The model was trained using LEXICOGRAPHIC sorting of YAML keys!
112
+ # This means '10' comes before '2' in the sorted order, which creates a specific
113
+ # class ordering that the model learned during training. We MUST use the same
114
+ # lexicographic sort to match the model's expectations.
115
 
116
  # Check if keys are numeric (int:label format) or string (label:int format)
117
  formatted_int_label = self._can_all_keys_be_converted_to_int(self.class_map)
118
 
119
  if formatted_int_label:
120
+ # Format: {'0': 'species1', '1': 'species2', ...}
121
+ # Sort keys LEXICOGRAPHICALLY (as strings) to match model training
122
+ # This creates: ['0', '1', '10', '100', '108', '11', '117', '118', '12', '13', '14', ...]
123
+ inv_class = {v: k for k, v in self.class_map.items()}
124
+ yaml_keys_sorted = sorted(inv_class.values()) # Lexicographic sort on string keys
125
+
126
+ # Build dense list: model_output[i] → class_ids[i] = species at yaml_keys_sorted[i]
127
+ self.class_ids = [self.class_map[yaml_key] for yaml_key in yaml_keys_sorted]
128
  else:
129
  # Format: {"species1": 0, "species2": 1, ...}
130
+ # Invert to create list where list[i] = species
131
+ inv_class = {v: k for k, v in self.class_map.items()}
132
+ self.class_ids = [inv_class[i] for i in sorted(inv_class.keys())]
133
 
134
  def _can_all_keys_be_converted_to_int(self, d: dict) -> bool:
135
  """
 
179
 
180
  # Check for invalid bounding boxes (zero or negative dimensions)
181
  if w_box <= 0 or h_box <= 0:
182
+ print(f"[TAS get_crop] Rejecting bbox with zero/negative dims: w={w_box}, h={h_box}", file=sys.stderr, flush=True)
183
  return None
184
 
185
  # Convert normalized coordinates to pixel coordinates
 
209
  crop_height = bottom - top
210
 
211
  if crop_width <= 0 or crop_height <= 0:
212
+ print(
213
+ f"[TAS get_crop] Rejecting bbox after clipping - crop size {crop_width:.1f}x{crop_height:.1f}\n"
214
+ f" Original bbox: x={x1:.4f}, y={y1:.4f}, w={w_box:.4f}, h={h_box:.4f}\n"
215
+ f" Image size: {im_width}x{im_height}\n"
216
+ f" Pixel coords after clip: ({left:.1f},{top:.1f}) to ({right:.1f},{bottom:.1f})",
217
+ file=sys.stderr, flush=True
218
+ )
219
  return None
220
 
221
  # Crop image
222
  image_cropped = image.crop((left, top, right, bottom))
223
  return image_cropped
224
 
225
+ def get_classification(self, crop: Image.Image) -> list[list[str, float]]:
226
  """
227
  Run MEWC-Keras classification on cropped image.
228
 
 
230
  1. Convert PIL Image to numpy array
231
  2. Resize to 384x384 (MEWC input size)
232
  3. Run model prediction
233
+ 4. Return all class probabilities (unsorted - worker handles sorting)
234
 
235
  Args:
236
  crop: Cropped PIL Image
237
 
238
  Returns:
239
+ List of [class_name, confidence] lists for ALL classes, in model order.
240
+ Example: [["unknown_animal", 0.00234], ["tasmanian_pademelon", 0.50674], ...]
241
+ NOTE: Sorting by confidence is handled by classification_worker.py
242
 
243
  Raises:
244
  RuntimeError: If model not loaded or inference fails
 
250
  raise RuntimeError("Class IDs not loaded - call load_model() first")
251
 
252
  if crop is None:
253
+ print("[TAS get_classification] Received None crop, returning empty", file=sys.stderr, flush=True)
254
  return []
255
 
256
  try:
 
258
  img = np.array(crop)
259
 
260
  if img.size == 0:
261
+ print("[TAS get_classification] Zero-size numpy array, returning empty", file=sys.stderr, flush=True)
262
  return []
263
 
264
  # Resize to MEWC input size (384x384)
 
270
  # Run prediction (verbose=0 suppresses progress bar)
271
  pred = self.model.predict(img, verbose=0)[0]
272
 
273
+ # Build list of [class_name, confidence] pairs (as lists, not tuples!)
274
  # class_ids is already in the correct order from class_list.yaml
275
  classifications = []
276
  for i in range(len(pred)):
277
  class_name = self.class_ids[i] # Get species name from class_list.yaml
278
  confidence = float(pred[i])
279
+ classifications.append([class_name, confidence])
280
 
281
+ # NOTE: Sorting by confidence is handled by classification_worker.py
282
+ # Model developers don't need to sort - just return all class predictions
283
  return classifications
284
 
285
  except Exception as e:
 
289
  """
290
  Get mapping of class IDs to species names from class_list.yaml.
291
 
292
+ Returns a 1-indexed contiguous mapping that matches the model's output order.
293
+ The model was trained with lexicographic sorting of YAML keys, so we create
294
+ a simple 1-indexed mapping: {1: species_at_position_0, 2: species_at_position_1, ...}
295
+
296
+ This matches the MegaDetector JSON format and the original MEWC implementation.
297
+
298
  Returns:
299
  Dict mapping class ID (1-indexed string) to species name
300
+ Example: {"1": "unknown_animal", "2": "tasmanian_pademelon", ..., "10": "fallow_deer", ...}
301
 
302
  Raises:
303
  RuntimeError: If class_ids not loaded
 
305
  if self.class_ids is None:
306
  raise RuntimeError("Class IDs not loaded - call load_model() first")
307
 
308
+ # Build 1-indexed mapping: model position i → JSON ID str(i+1)
309
+ # class_ids[0] → "1", class_ids[1] → "2", ..., class_ids[9] → "10" (fallow_deer)
310
  class_names = {}
311
  for i, class_name in enumerate(self.class_ids):
312
  class_id_str = str(i + 1) # 1-indexed