Upload inference.py
Browse files- inference.py +220 -224
inference.py
CHANGED
|
@@ -14,6 +14,7 @@ MEWC - Mega Efficient Wildlife Classifier - University of Tasmania
|
|
| 14 |
Original source: streamlit-AddaxAI/classification/model_types/mewc-keras/classify_detections.py
|
| 15 |
Reference: https://github.com/zaandahl/mewc
|
| 16 |
Adapted by: Claude Code on 2026-01-11
|
|
|
|
| 17 |
"""
|
| 18 |
|
| 19 |
from __future__ import annotations
|
|
@@ -31,258 +32,253 @@ from PIL import Image, ImageFile
|
|
| 31 |
# Set Keras backend to JAX (as per original MEWC code)
|
| 32 |
os.environ["KERAS_BACKEND"] = "jax"
|
| 33 |
|
| 34 |
-
# Module-level variables (injected by AddaxAI framework)
|
| 35 |
-
MODEL_DIR: Path | None = None # Set by CustomInferenceLoader
|
| 36 |
-
MODEL_PATH: Path | None = None # Set by CustomInferenceLoader
|
| 37 |
-
|
| 38 |
-
# Module-level model instance (loaded once at startup)
|
| 39 |
-
animal_model = None
|
| 40 |
-
img_size = 384 # MEWC uses 384x384 images
|
| 41 |
-
|
| 42 |
-
# Class mapping variables
|
| 43 |
-
class_map: dict[str, str] | None = None
|
| 44 |
-
class_ids: list[str] | None = None
|
| 45 |
-
|
| 46 |
# Don't freak out over truncated images
|
| 47 |
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 48 |
|
| 49 |
|
| 50 |
-
|
| 51 |
-
"""
|
| 52 |
-
Check GPU availability for TensorFlow/Keras inference.
|
| 53 |
-
|
| 54 |
-
TensorFlow can detect GPUs, Metal (Apple Silicon), and CUDA.
|
| 55 |
-
|
| 56 |
-
Returns:
|
| 57 |
-
True if GPU available, False otherwise
|
| 58 |
-
"""
|
| 59 |
-
try:
|
| 60 |
-
gpus = tf.config.list_logical_devices('GPU')
|
| 61 |
-
return len(gpus) > 0
|
| 62 |
-
except Exception:
|
| 63 |
-
return False
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def load_model() -> None:
|
| 67 |
-
"""
|
| 68 |
-
Load Keras classification model into memory.
|
| 69 |
-
|
| 70 |
-
This function is called once during worker initialization.
|
| 71 |
-
The model is stored in the global `animal_model` variable and reused
|
| 72 |
-
for all subsequent classification requests.
|
| 73 |
-
|
| 74 |
-
Also loads the class_list.yaml file which maps class indices to species names.
|
| 75 |
-
|
| 76 |
-
Raises:
|
| 77 |
-
RuntimeError: If model loading fails
|
| 78 |
-
FileNotFoundError: If MODEL_PATH or class_list.yaml is invalid
|
| 79 |
-
"""
|
| 80 |
-
global animal_model, class_map, class_ids, MODEL_PATH, MODEL_DIR
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
raise FileNotFoundError(
|
| 97 |
-
f"class_list.yaml not found: {class_list_path}\n"
|
| 98 |
-
f"MEWC models require class_list.yaml in the model directory."
|
| 99 |
-
)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
with open(class_list_path, 'r') as f:
|
| 103 |
-
class_map = yaml.safe_load(f)
|
| 104 |
-
except Exception as e:
|
| 105 |
-
raise RuntimeError(f"Failed to load class_list.yaml: {e}") from e
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
# Check if keys are numeric (int:label format) or string (label:int format)
|
| 112 |
-
formatted_int_label = _can_all_keys_be_converted_to_int(class_map)
|
| 113 |
-
|
| 114 |
-
if formatted_int_label:
|
| 115 |
-
# Format: {0: "species1", 1: "species2", ...}
|
| 116 |
-
class_ids = [class_map[i] for i in sorted(inv_class.values())]
|
| 117 |
-
else:
|
| 118 |
-
# Format: {"species1": 0, "species2": 1, ...}
|
| 119 |
-
class_ids = sorted(inv_class.values())
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
def _can_all_keys_be_converted_to_int(d: dict) -> bool:
|
| 123 |
-
"""
|
| 124 |
-
Check if all dictionary keys can be converted to integers.
|
| 125 |
-
|
| 126 |
-
Used to determine class_list.yaml format.
|
| 127 |
-
|
| 128 |
-
Args:
|
| 129 |
-
d: Dictionary to check
|
| 130 |
-
|
| 131 |
-
Returns:
|
| 132 |
-
True if all keys are convertible to int, False otherwise
|
| 133 |
-
"""
|
| 134 |
-
for key in d.keys():
|
| 135 |
try:
|
| 136 |
-
|
| 137 |
-
|
|
|
|
| 138 |
return False
|
| 139 |
-
return True
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
def get_crop(image: Image.Image, bbox: tuple[float, float, float, float]) -> Image.Image | None:
|
| 143 |
-
"""
|
| 144 |
-
Crop image using MEWC-specific preprocessing.
|
| 145 |
-
|
| 146 |
-
This cropping method is used by MEWC and follows the MegaDetector
|
| 147 |
-
visualization_utils approach. It:
|
| 148 |
-
1. Denormalizes the bbox coordinates
|
| 149 |
-
2. Clips to image boundaries
|
| 150 |
-
3. Returns the cropped region (no padding or squaring)
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
|
|
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
|
| 159 |
-
|
| 160 |
-
Cropped PIL Image, or None if bbox is invalid
|
| 161 |
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
# Clip to image boundaries (ensure non-negative)
|
| 182 |
-
left = max(left, 0)
|
| 183 |
-
right = max(right, 0)
|
| 184 |
-
top = max(top, 0)
|
| 185 |
-
bottom = max(bottom, 0)
|
| 186 |
-
|
| 187 |
-
# Clip to image boundaries (ensure within image)
|
| 188 |
-
left = min(left, im_width - 1)
|
| 189 |
-
right = min(right, im_width - 1)
|
| 190 |
-
top = min(top, im_height - 1)
|
| 191 |
-
bottom = min(bottom, im_height - 1)
|
| 192 |
-
|
| 193 |
-
# Final check - ensure crop has valid dimensions
|
| 194 |
-
crop_width = right - left
|
| 195 |
-
crop_height = bottom - top
|
| 196 |
-
|
| 197 |
-
if crop_width <= 0 or crop_height <= 0:
|
| 198 |
-
return None
|
| 199 |
-
|
| 200 |
-
# Crop image
|
| 201 |
-
image_cropped = image.crop((left, top, right, bottom))
|
| 202 |
-
return image_cropped
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def get_classification(crop: Image.Image) -> list[tuple[str, float]]:
|
| 206 |
-
"""
|
| 207 |
-
Run MEWC-Keras classification on cropped image.
|
| 208 |
-
|
| 209 |
-
Workflow:
|
| 210 |
-
1. Convert PIL Image to numpy array
|
| 211 |
-
2. Resize to 384x384 (MEWC input size)
|
| 212 |
-
3. Run model prediction
|
| 213 |
-
4. Return all class probabilities
|
| 214 |
-
|
| 215 |
-
Args:
|
| 216 |
-
crop: Cropped PIL Image
|
| 217 |
-
|
| 218 |
-
Returns:
|
| 219 |
-
List of (class_name, confidence) tuples for ALL classes, in model order.
|
| 220 |
-
Example: [("tasmanian_pademelon", 0.50674), ("bennetts_wallaby", 0.46682), ...]
|
| 221 |
-
|
| 222 |
-
Raises:
|
| 223 |
-
RuntimeError: If model not loaded or inference fails
|
| 224 |
-
"""
|
| 225 |
-
global animal_model, class_ids
|
| 226 |
-
|
| 227 |
-
if animal_model is None:
|
| 228 |
-
raise RuntimeError("Model not loaded - call load_model() first")
|
| 229 |
-
|
| 230 |
-
if class_ids is None:
|
| 231 |
-
raise RuntimeError("Class IDs not loaded - call load_model() first")
|
| 232 |
-
|
| 233 |
-
if crop is None:
|
| 234 |
-
return []
|
| 235 |
-
|
| 236 |
-
try:
|
| 237 |
-
# Convert PIL Image to numpy array
|
| 238 |
-
img = np.array(crop)
|
| 239 |
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
return []
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
|
|
|
| 245 |
|
| 246 |
-
|
| 247 |
-
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
| 251 |
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
classifications = []
|
| 255 |
-
for i in range(len(pred)):
|
| 256 |
-
class_name = class_ids[i] # Get species name from class_list.yaml
|
| 257 |
-
confidence = float(pred[i])
|
| 258 |
-
classifications.append((class_name, confidence))
|
| 259 |
|
| 260 |
-
|
|
|
|
| 261 |
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
|
|
|
| 265 |
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
Get mapping of class IDs to species names from class_list.yaml.
|
| 269 |
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
global class_ids
|
| 278 |
|
| 279 |
-
|
| 280 |
-
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
|
| 288 |
-
|
|
|
|
| 14 |
Original source: streamlit-AddaxAI/classification/model_types/mewc-keras/classify_detections.py
|
| 15 |
Reference: https://github.com/zaandahl/mewc
|
| 16 |
Adapted by: Claude Code on 2026-01-11
|
| 17 |
+
Updated: 2026-01-13 - Migrated to class-based interface
|
| 18 |
"""
|
| 19 |
|
| 20 |
from __future__ import annotations
|
|
|
|
| 32 |
# Set Keras backend to JAX (as per original MEWC code)
|
| 33 |
os.environ["KERAS_BACKEND"] = "jax"
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
# Don't freak out over truncated images
|
| 36 |
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 37 |
|
| 38 |
|
| 39 |
+
class ModelInference:
    """MEWC-Keras inference implementation for Tasmania species classifier."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files (including class_list.yaml)
            model_path: Path to tas_ens_mewc.keras file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None  # Keras model, populated by load_model()
        self.img_size = 384  # MEWC uses 384x384 images
        self.class_map: dict[str, str] | None = None  # raw class_list.yaml contents
        self.class_ids: list[str] | None = None  # species names in model output order

    def check_gpu(self) -> bool:
        """
        Check GPU availability for TensorFlow/Keras inference.

        TensorFlow can detect GPUs, Metal (Apple Silicon), and CUDA.

        Returns:
            True if GPU available, False otherwise
        """
        try:
            gpus = tf.config.list_logical_devices('GPU')
            return len(gpus) > 0
        except Exception:
            # Any TF probing failure is treated as "no GPU" rather than an error.
            return False

    def load_model(self) -> None:
        """
        Load Keras classification model into memory.

        This function is called once during worker initialization.
        The model is stored in self.model and reused for all subsequent
        classification requests.

        Also loads the class_list.yaml file which maps class indices to species names.

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If model_path or class_list.yaml is invalid
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        # Load the Keras model (without compilation for inference only)
        try:
            self.model = saving.load_model(str(self.model_path), compile=False)
        except Exception as e:
            raise RuntimeError(f"Failed to load Keras model from {self.model_path}: {e}") from e

        # Load class_list.yaml
        class_list_path = self.model_dir / "class_list.yaml"
        if not class_list_path.exists():
            raise FileNotFoundError(
                f"class_list.yaml not found: {class_list_path}\n"
                f"MEWC models require class_list.yaml in the model directory."
            )

        try:
            with open(class_list_path, 'r') as f:
                self.class_map = yaml.safe_load(f)
        except Exception as e:
            raise RuntimeError(f"Failed to load class_list.yaml: {e}") from e

        # Build class_ids list based on YAML format.
        # The YAML can be formatted as either {index: species} or {species: index}.
        if self._can_all_keys_be_converted_to_int(self.class_map):
            # Format: {0: "species1", 1: "species2", ...}
            # Sort by numeric key value (key=int also handles YAML files whose
            # indices load as digit strings, where a plain sort would order
            # "10" before "2").
            self.class_ids = [self.class_map[k] for k in sorted(self.class_map, key=int)]
        else:
            # Format: {"species1": 0, "species2": 1, ...}
            # Invert the mapping and order species by their integer class index.
            # (Bug fix: previously this sorted the species NAMES alphabetically,
            # which only matches the model's output order by coincidence.)
            inv_class = {v: k for k, v in self.class_map.items()}
            self.class_ids = [inv_class[i] for i in sorted(inv_class)]

    def _can_all_keys_be_converted_to_int(self, d: dict) -> bool:
        """
        Check if all dictionary keys can be converted to integers.

        Used to determine class_list.yaml format.

        Args:
            d: Dictionary to check

        Returns:
            True if all keys are convertible to int, False otherwise
        """
        for key in d.keys():
            try:
                int(key)
            except (TypeError, ValueError):
                # TypeError covers keys that int() cannot accept at all
                # (e.g. None); ValueError covers non-numeric strings.
                return False
        return True

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image | None:
        """
        Crop image using MEWC-specific preprocessing.

        This cropping method is used by MEWC and follows the MegaDetector
        visualization_utils approach. It:
        1. Denormalizes the bbox coordinates
        2. Clips to image boundaries
        3. Returns the cropped region (no padding or squaring)

        Reference: https://github.com/zaandahl/mewc-snip/blob/main/src/mewc_snip.py#L29
        Reference: https://github.com/agentmorris/MegaDetector/blob/main/megadetector/visualization/visualization_utils.py#L352

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image, or None if bbox is invalid
        """
        x1, y1, w_box, h_box = bbox

        # Check for invalid bounding boxes (zero or negative dimensions)
        if w_box <= 0 or h_box <= 0:
            return None

        # Convert normalized coordinates to pixel coordinates
        ymin, xmin, ymax, xmax = y1, x1, y1 + h_box, x1 + w_box
        im_width, im_height = image.size

        # Denormalize
        left = xmin * im_width
        right = xmax * im_width
        top = ymin * im_height
        bottom = ymax * im_height

        # Clip to image boundaries (ensure non-negative)
        left = max(left, 0)
        right = max(right, 0)
        top = max(top, 0)
        bottom = max(bottom, 0)

        # Clip to image boundaries (ensure within image).
        # NOTE(review): clamping to im_width - 1 / im_height - 1 drops the last
        # pixel row/column; kept as-is to match the MEWC/MegaDetector reference.
        left = min(left, im_width - 1)
        right = min(right, im_width - 1)
        top = min(top, im_height - 1)
        bottom = min(bottom, im_height - 1)

        # Final check - ensure crop has valid dimensions
        crop_width = right - left
        crop_height = bottom - top

        if crop_width <= 0 or crop_height <= 0:
            return None

        # Crop image
        image_cropped = image.crop((left, top, right, bottom))
        return image_cropped

    def get_classification(self, crop: Image.Image) -> list[tuple[str, float]]:
        """
        Run MEWC-Keras classification on cropped image.

        Workflow:
        1. Convert PIL Image to numpy array
        2. Resize to 384x384 (MEWC input size)
        3. Run model prediction
        4. Return all class probabilities

        Args:
            crop: Cropped PIL Image

        Returns:
            List of (class_name, confidence) tuples for ALL classes, in model order.
            Example: [("tasmanian_pademelon", 0.50674), ("bennetts_wallaby", 0.46682), ...]

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        if self.class_ids is None:
            raise RuntimeError("Class IDs not loaded - call load_model() first")

        if crop is None:
            return []

        try:
            # Convert PIL Image to numpy array
            img = np.array(crop)

            if img.size == 0:
                return []

            # Resize to MEWC input size (384x384)
            img = cv2.resize(img, (self.img_size, self.img_size))

            # Add batch dimension
            img = np.expand_dims(img, axis=0)

            # Run prediction (verbose=0 suppresses progress bar)
            pred = self.model.predict(img, verbose=0)[0]

            # Build list of (class_name, confidence) tuples.
            # class_ids is already in the correct order from class_list.yaml
            classifications = []
            for i in range(len(pred)):
                class_name = self.class_ids[i]  # Get species name from class_list.yaml
                confidence = float(pred[i])
                classifications.append((class_name, confidence))

            return classifications

        except Exception as e:
            raise RuntimeError(f"MEWC-Keras classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names from class_list.yaml.

        Returns:
            Dict mapping class ID (1-indexed string) to species name
            Example: {"1": "bait", "2": "unknown_animal", ...}

        Raises:
            RuntimeError: If class_ids not loaded
        """
        if self.class_ids is None:
            raise RuntimeError("Class IDs not loaded - call load_model() first")

        # Build 1-indexed mapping
        class_names = {}
        for i, class_name in enumerate(self.class_ids):
            class_id_str = str(i + 1)  # 1-indexed
            class_names[class_id_str] = class_name

        return class_names
|