|
|
""" |
|
|
Inference script for TAS-BB-v1 (Tasmania MEWC Species Classifier) |
|
|
|
|
|
MEWC (Mega Efficient Wildlife Classifier) for Tasmania trained on 2.5 million labelled |
|
|
images from 96 classes. Includes all non-volant terrestrial mammals (native and introduced) |
|
|
and 50+ commonly observed bird species. Overall accuracy and F1 scores exceed 99%. |
|
|
|
|
|
Model: Tasmania MEWC Ensemble |
|
|
Input: 224x224 RGB images |
|
|
Framework: Keras 3 with JAX backend (EfficientNet v2 Small architecture) |
|
|
Classes: 96 Tasmanian terrestrial mammals and birds |
|
|
Developer: Barry Brook (University of Tasmania) |
|
|
Citation: https://ecoevorxiv.org/repository/view/6405/ |
|
|
License: CC BY 4.0 |
|
|
Info: https://github.com/zaandahl/mewc |
|
|
|
|
|
Author: Peter van Lunteren |
|
|
Created: 2026-01-14 |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os
import sys
from pathlib import Path

# Keras picks its backend from KERAS_BACKEND once, at import time, and the
# choice cannot be changed afterwards. The assignment therefore has to run
# before tensorflow/keras are imported below; placed after the imports it
# would be silently ignored. (If another module has already imported keras
# earlier in the process, this still has no effect.)
os.environ["KERAS_BACKEND"] = "jax"

import cv2  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
import yaml  # noqa: E402
from keras import saving  # noqa: E402
from PIL import Image, ImageFile  # noqa: E402

# Let Pillow decode truncated image files instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
|
|
|
|
class ModelInference:
    """MEWC-Keras inference implementation for Tasmania species classifier.

    Typical use: construct with the model locations, call ``load_model()``
    once, then call ``get_crop()`` and ``get_classification()`` per detection.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths. Nothing is loaded until ``load_model()``.

        Args:
            model_dir: Directory containing model files (including class_list.yaml)
            model_path: Path to tas_ens_mewc.keras file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        # Populated by load_model()
        self.model = None
        # Side length (pixels) every crop is resized to before prediction
        self.img_size = 384
        # Raw content of class_list.yaml
        self.class_map: dict[str, str] | None = None
        # Species names ordered by model output index
        self.class_ids: list[str] | None = None

    def check_gpu(self) -> bool:
        """
        Check GPU availability by asking TensorFlow for logical GPU devices.

        Returns:
            True if at least one GPU device is reported, False otherwise
            (including when the query itself fails).
        """
        try:
            return bool(tf.config.list_logical_devices("GPU"))
        except Exception:
            # Best-effort probe: any failure is reported as "no GPU".
            return False

    def load_model(self) -> None:
        """
        Load the Keras classification model and class list into memory.

        Stores the model in ``self.model``, the parsed class_list.yaml in
        ``self.class_map`` and the ordered species names in ``self.class_ids``.

        Raises:
            FileNotFoundError: If model_path or class_list.yaml does not exist
            RuntimeError: If the model cannot be loaded, or class_list.yaml
                cannot be parsed or does not contain a usable mapping
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            # compile=False: no optimizer/loss state is requested when loading.
            self.model = saving.load_model(str(self.model_path), compile=False)
        except Exception as e:
            raise RuntimeError(f"Failed to load Keras model from {self.model_path}: {e}") from e

        class_list_path = self.model_dir / "class_list.yaml"
        if not class_list_path.exists():
            raise FileNotFoundError(
                f"class_list.yaml not found: {class_list_path}\n"
                f"MEWC models require class_list.yaml in the model directory."
            )

        try:
            # Explicit UTF-8 so parsing does not depend on the platform's
            # default text encoding.
            with open(class_list_path, "r", encoding="utf-8") as f:
                class_map = yaml.safe_load(f)
        except Exception as e:
            raise RuntimeError(f"Failed to load class_list.yaml: {e}") from e

        # safe_load returns None for an empty file and may return a list or
        # scalar; fail with a clear message instead of an AttributeError later.
        if not isinstance(class_map, dict):
            raise RuntimeError(
                f"Failed to load class_list.yaml: expected a mapping, "
                f"got {type(class_map).__name__}"
            )

        try:
            class_ids = self._order_class_ids(class_map)
        except TypeError as e:
            # Raised when keys/values cannot be sorted together or hashed.
            raise RuntimeError(
                f"Failed to load class_list.yaml: inconsistent key/value types ({e})"
            ) from e

        self.class_map = class_map
        self.class_ids = class_ids

    def _order_class_ids(self, class_map: dict) -> list[str]:
        """
        Turn the class_list.yaml mapping into a list of names in output order.

        Two layouts are handled:
        - every key is int-convertible: names are emitted in sorted-key order
          (keys are sorted as loaded, so string keys sort lexicographically);
        - otherwise the mapping is treated as name -> index and names are
          emitted in sorted-index order.

        Args:
            class_map: Parsed content of class_list.yaml

        Returns:
            Species names, one per class, in the derived order
        """
        inverse = {value: key for key, value in class_map.items()}
        if self._can_all_keys_be_converted_to_int(class_map):
            return [class_map[key] for key in sorted(inverse.values())]
        return [inverse[index] for index in sorted(inverse)]

    def _can_all_keys_be_converted_to_int(self, d: dict) -> bool:
        """
        Check if all dictionary keys can be converted to integers.

        Used to determine class_list.yaml format.

        Args:
            d: Dictionary to check

        Returns:
            True if all keys are convertible to int (also for an empty dict),
            False otherwise
        """
        for key in d:
            try:
                int(key)
            except (TypeError, ValueError, OverflowError):
                # TypeError: e.g. a None key; OverflowError: an infinite float key.
                return False
        return True

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image | None:
        """
        Crop image using MEWC-specific preprocessing.

        Steps:
        1. Denormalize the bbox coordinates to pixels
        2. Clip to image boundaries
        3. Return the cropped region (no padding or squaring)

        Reference: https://github.com/zaandahl/mewc-snip/blob/main/src/mewc_snip.py#L29
        Reference: https://github.com/agentmorris/MegaDetector/blob/main/megadetector/visualization/visualization_utils.py#L352

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image, or None if the box has non-positive size before
            or after clipping (a message is written to stderr in that case)
        """
        x1, y1, w_box, h_box = bbox

        if w_box <= 0 or h_box <= 0:
            print(f"[TAS get_crop] Rejecting bbox with zero/negative dims: w={w_box}, h={h_box}", file=sys.stderr, flush=True)
            return None

        im_width, im_height = image.size

        def to_pixels(fraction, extent):
            # Scale to pixels, then clamp into [0, extent - 1].
            return min(max(fraction * extent, 0), extent - 1)

        left = to_pixels(x1, im_width)
        right = to_pixels(x1 + w_box, im_width)
        top = to_pixels(y1, im_height)
        bottom = to_pixels(y1 + h_box, im_height)

        crop_width = right - left
        crop_height = bottom - top

        # A box lying entirely outside the image collapses to zero size here.
        if crop_width <= 0 or crop_height <= 0:
            print(
                f"[TAS get_crop] Rejecting bbox after clipping - crop size {crop_width:.1f}x{crop_height:.1f}\n"
                f"  Original bbox: x={x1:.4f}, y={y1:.4f}, w={w_box:.4f}, h={h_box:.4f}\n"
                f"  Image size: {im_width}x{im_height}\n"
                f"  Pixel coords after clip: ({left:.1f},{top:.1f}) to ({right:.1f},{bottom:.1f})",
                file=sys.stderr, flush=True
            )
            return None

        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run MEWC-Keras classification on cropped image.

        Workflow:
        1. Convert PIL Image to numpy array (forcing RGB if needed)
        2. Resize to img_size x img_size
        3. Run model prediction on a batch of one
        4. Return all class probabilities (unsorted)

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, in model order.
            Example: [["unknown_animal", 0.00234], ["tasmanian_pademelon", 0.50674], ...]
            An empty list is returned for a None or zero-size crop.

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        if self.class_ids is None:
            raise RuntimeError("Class IDs not loaded - call load_model() first")

        if crop is None:
            print("[TAS get_classification] Received None crop, returning empty", file=sys.stderr, flush=True)
            return []

        try:
            img = np.array(crop)

            if img.size == 0:
                print("[TAS get_classification] Zero-size numpy array, returning empty", file=sys.stderr, flush=True)
                return []

            # The model takes RGB input; grayscale, palette or RGBA crops
            # would otherwise reach it with the wrong channel layout.
            # getattr keeps non-PIL inputs (no .mode) on the original path.
            if getattr(crop, "mode", "RGB") != "RGB":
                img = np.array(crop.convert("RGB"))

            img = cv2.resize(img, (self.img_size, self.img_size))
            batch = np.expand_dims(img, axis=0)
            pred = self.model.predict(batch, verbose=0)[0]

            # Pair each score with its class name by index.
            return [[self.class_ids[i], float(score)] for i, score in enumerate(pred)]

        except Exception as e:
            raise RuntimeError(f"MEWC-Keras classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        IDs are 1-indexed strings that follow the order of ``self.class_ids``:
        {"1": class_ids[0], "2": class_ids[1], ...}

        Returns:
            Dict mapping class ID (1-indexed string) to species name
            Example: {"1": "unknown_animal", "2": "tasmanian_pademelon", ..., "10": "fallow_deer", ...}

        Raises:
            RuntimeError: If class_ids not loaded
        """
        if self.class_ids is None:
            raise RuntimeError("Class IDs not loaded - call load_model() first")

        return {str(position): name for position, name in enumerate(self.class_ids, start=1)}
|
|
|