|
|
""" |
|
|
Inference script for TERRAI-NEP-v1 (Terai Region Nepal Species Classifier) |
|
|
|
|
|
This model identifies 10 species in the Terai region of Nepal, designed to support |
|
|
Bengal tiger conservation. Based on EfficientNetV2M following MEWC methodology for |
|
|
image preparation and training. Trained on 2,000 images per class (mostly from LILA BC) |
|
|
achieving 90% accuracy/precision/recall/F1 on test set (250 images per class). |
|
|
|
|
|
Note: Some training data used substitute species due to availability; local Terai |
|
|
images were not available for training, so generalization to the target region remains |
|
|
to be tested. |
|
|
|
|
|
Model: Terai Nepal v1 |
|
|
Input: 224x224 RGB images |
|
|
Framework: Keras 3 with JAX backend (EfficientNetV2M architecture) |
|
|
Classes: 10 species including Bengal tiger, leopard, rhino, elephant |
|
|
Developer: Alexander Merdian-Tarko |
|
|
License: MIT |
|
|
Info: https://alexvmt.github.io/ |
|
|
|
|
|
Author: Peter van Lunteren |
|
|
Created: 2026-01-14 |
|
|
""" |
|
|
|
|
|
from __future__ import annotations

import os

# IMPORTANT: Keras 3 selects its backend by reading KERAS_BACKEND at import
# time. The variable must therefore be set BEFORE any keras import below;
# setting it after `from keras import saving` would have no effect and the
# model would silently run on the default backend instead of JAX.
os.environ["KERAS_BACKEND"] = "jax"

from pathlib import Path

import cv2
import numpy as np
import tensorflow as tf
import yaml
from keras import saving
from PIL import Image, ImageFile

# Allow PIL to decode images whose data stream is truncated (common with
# camera-trap files) instead of raising an error mid-read.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
|
|
|
|
class ModelInference:
    """Keras/JAX inference implementation for Terai Nepal species classifier."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files (including class_list.yaml)
            model_path: Path to model.keras file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None        # populated by load_model()
        self.img_size = 224      # model input resolution (224x224 RGB)
        self.class_ids = []      # ordered class names, populated by load_model()

    def check_gpu(self) -> bool:
        """
        Check GPU availability for TensorFlow/JAX inference.

        Returns:
            True if GPU available, False otherwise
        """
        return len(tf.config.list_logical_devices('GPU')) > 0

    def load_model(self) -> None:
        """
        Load Keras model with JAX backend and class labels.

        This function is called once during worker initialization.
        The model is stored in self.model and reused for all subsequent
        classification requests.

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If model_path or class_list.yaml is invalid
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            # compile=False: inference only, no optimizer/loss state needed.
            self.model = saving.load_model(str(self.model_path), compile=False)
        except Exception as e:
            raise RuntimeError(f"Failed to load Keras model from {self.model_path}: {e}") from e

        class_list_path = self.model_dir / "class_list.yaml"
        if not class_list_path.exists():
            raise FileNotFoundError(f"Class list file not found: {class_list_path}")

        try:
            with open(class_list_path, 'r') as f:
                class_map = yaml.safe_load(f)

            if self._can_all_keys_be_converted_to_int(class_map):
                # Keys are integer class indices (possibly YAML strings):
                # order class names numerically. Sorting with key=int avoids
                # the lexicographic trap where "10" sorts before "2".
                self.class_ids = [class_map[k] for k in sorted(class_map, key=int)]
            else:
                # Keys are class names: fall back to alphabetical order of
                # the names (matches Keras directory-iterator label order).
                inv_class = {v: k for k, v in class_map.items()}
                self.class_ids = sorted(inv_class.values())

        except Exception as e:
            raise RuntimeError(f"Failed to load class labels from {class_list_path}: {e}") from e

    def _can_all_keys_be_converted_to_int(self, d: dict) -> bool:
        """
        Check if all dictionary keys can be converted to integers.

        Args:
            d: Dictionary to check

        Returns:
            True if all keys are integers or integer strings, False otherwise
        """
        for key in d.keys():
            try:
                int(key)
            except (TypeError, ValueError):
                # TypeError covers non-convertible key types such as None.
                return False
        return True

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using MEWC preprocessing method.

        This cropping method follows the MEWC-snip approach which is based on
        MegaDetector's visualization utilities. It performs a direct crop without
        padding or squaring.

        Based on:
        - https://github.com/zaandahl/mewc-snip/blob/main/src/mewc_snip.py
        - https://github.com/agentmorris/MegaDetector/blob/main/megadetector/visualization/visualization_utils.py

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image (resizing happens in get_classification)

        Raises:
            ValueError: If bbox is invalid
        """
        x1, y1, w_box, h_box = bbox
        ymin, xmin, ymax, xmax = y1, x1, y1 + h_box, x1 + w_box

        im_width, im_height = image.size

        # Scale normalized coordinates to absolute pixel values.
        (left, right, top, bottom) = (
            xmin * im_width,
            xmax * im_width,
            ymin * im_height,
            ymax * im_height
        )

        # Clamp all edges into the image bounds (MEWC/MegaDetector convention
        # caps at size - 1).
        left = max(left, 0)
        right = max(right, 0)
        top = max(top, 0)
        bottom = max(bottom, 0)
        left = min(left, im_width - 1)
        right = min(right, im_width - 1)
        top = min(top, im_height - 1)
        bottom = min(bottom, im_height - 1)

        # A degenerate box (zero or negative extent after clamping) cannot
        # be classified.
        if left >= right or top >= bottom:
            raise ValueError(
                f"Invalid bbox dimensions after cropping: "
                f"left={left}, top={top}, right={right}, bottom={bottom}"
            )

        image_cropped = image.crop((left, top, right, bottom))
        return image_cropped

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run Keras/JAX classification on cropped image.

        Preprocessing:
        - Force 3-channel RGB
        - Convert PIL to numpy array
        - Resize to 224x224 using OpenCV
        - Add batch dimension
        - No normalization (handled internally by model)

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes.
            Example: [["tiger", 0.85], ["leopard", 0.10], ["black_bear", 0.02], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Ensure 3-channel RGB: an RGBA or greyscale crop would otherwise
            # yield a tensor whose channel count the model rejects.
            if crop.mode != "RGB":
                crop = crop.convert("RGB")

            img = np.array(crop)
            img = cv2.resize(img, (self.img_size, self.img_size))
            img = np.expand_dims(img, axis=0)  # add batch dimension

            pred = self.model.predict(img, verbose=0)[0]

            # Pair every class name with its raw confidence; downstream code
            # performs any sorting/thresholding.
            return [
                [class_name, float(confidence)]
                for class_name, confidence in zip(self.class_ids, pred)
            ]

        except Exception as e:
            raise RuntimeError(f"Keras classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        Returns:
            Dict mapping class ID (1-indexed string) to species name
            Example: {"1": "tiger", "2": "leopard", ..., "10": "bird"}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # 1-indexed string IDs in model-output order.
            return {
                str(i + 1): class_name
                for i, class_name in enumerate(self.class_ids)
            }

        except Exception as e:
            raise RuntimeError(f"Failed to extract class names: {e}") from e
|
|
|