|
|
""" |
|
|
Inference script for PAM-SDZWA-v1 (Peruvian Amazon Species Classifier) |
|
|
|
|
|
This model classifies 53 species found in Peruvian Amazon rainforest habitats. |
|
|
Developed by Mathias Tobler from the San Diego Zoo Wildlife Alliance Conservation |
|
|
Technology Lab using their animl-py framework. |
|
|
|
|
|
Model: Peru Amazon v0.86 |
|
|
Input: Variable size (extracted from model config) |
|
|
Framework: TensorFlow/Keras (TensorFlow 1.x compatible) |
|
|
Classes: 53 Amazonian species and taxonomic groups |
|
|
Developer: San Diego Zoo Wildlife Alliance (Mathias Tobler) |
|
|
License: MIT |
|
|
Info: https://github.com/conservationtechlab |
|
|
|
|
|
Author: Peter van Lunteren |
|
|
Created: 2026-01-14 |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import tensorflow as tf |
|
|
from PIL import Image, ImageFile |
|
|
from tensorflow.keras.models import load_model |
|
|
|
|
|
|
|
|
# Allow PIL to load images whose data stream is cut short (common with
# partially transferred camera-trap files) instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
|
|
|
|
class ModelInference:
    """TensorFlow/Keras inference implementation for the Peruvian Amazon species classifier.

    Wraps model loading, SDZWA animl-py-style cropping, and per-crop
    classification for the 53-class Peru-Amazon Keras model.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths. No file I/O happens here; call
        load_model() before classifying.

        Args:
            model_dir: Directory containing model files and class labels
            model_path: Path to Peru-Amazon_0.86.h5 file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None  # Keras model instance, populated by load_model()
        self.img_size = None  # square input edge length, read from the model config
        self.class_map = {}  # class ID (digit string) -> species name, from the label file
        self.class_ids_sorted = []  # species names sorted alphabetically

    def check_gpu(self) -> bool:
        """
        Check GPU availability for TensorFlow inference.

        Returns:
            True if at least one logical GPU device is visible, False otherwise
        """
        return len(tf.config.list_logical_devices('GPU')) > 0

    def load_model(self) -> None:
        """
        Load the TensorFlow/Keras model and class labels into memory.

        Called once during worker initialization; the model is stored in
        self.model and reused for all subsequent classification requests.

        Raises:
            RuntimeError: If model or label-file loading fails
            FileNotFoundError: If model_path or the label file is missing
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            self.model = load_model(str(self.model_path))
            # The expected input size is encoded in the first layer's
            # batch_input_shape (None, H, W, C); index 1 is the spatial size.
            self.img_size = self.model.get_config()["layers"][0]["config"]["batch_input_shape"][1]
        except Exception as e:
            raise RuntimeError(f"Failed to load Keras model from {self.model_path}: {e}") from e

        label_file = self.model_dir / "Peru-Amazon_0.86.txt"
        if not label_file.exists():
            raise FileNotFoundError(f"Class label file not found: {label_file}")

        try:
            with open(label_file, 'r') as file:
                for line in file:
                    # Each usable line carries two double-quoted fields:
                    # "<id>" ... "<species name>" — split on the quotes so
                    # parts[1] is the ID and parts[3] is the name.
                    parts = line.strip().split('"')
                    if len(parts) >= 4:
                        identifier = parts[1].strip()
                        animal_name = parts[3].strip()
                        if identifier.isdigit():
                            self.class_map[identifier] = animal_name

            # NOTE(review): the model's output vector is assumed to follow the
            # alphabetical order of species names (animl-py convention) —
            # confirm against the training pipeline if predictions look wrong.
            self.class_ids_sorted = sorted(self.class_map.values())

        except Exception as e:
            raise RuntimeError(f"Failed to load class labels from {label_file}: {e}") from e

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using SDZWA animl-py preprocessing.

        This cropping method follows the San Diego Zoo Wildlife Alliance's
        animl-py framework approach with minimal buffering (0 pixels by default).

        Based on: https://github.com/conservationtechlab/animl-py/blob/main/src/animl/generator.py

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image (not resized - resizing happens in get_classification)

        Raises:
            ValueError: If bbox has zero or negative area after clamping
        """
        buffer = 0
        width, height = image.size

        # Convert the normalized (x, y, w, h) box to absolute pixel corners.
        x, y, w, h = bbox
        left = width * x
        top = height * y
        right = width * (x + w)
        bottom = height * (y + h)

        # Truncate to ints, apply the (currently zero) pixel buffer, and clamp
        # to the image bounds.
        left = max(0, int(left) - buffer)
        top = max(0, int(top) - buffer)
        right = min(width, int(right) + buffer)
        bottom = min(height, int(bottom) + buffer)

        if left >= right or top >= bottom:
            raise ValueError(f"Invalid bbox dimensions after cropping: left={left}, top={top}, right={right}, bottom={bottom}")

        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run TensorFlow/Keras classification on a cropped image.

        Preprocessing follows the SDZWA animl-py framework:
        - Resize to model input size (extracted from model config)
        - Convert to numpy array
        - No normalization or augmentation

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, ordered to
            match self.class_ids_sorted.
            Example: [["Black-headed squirrel monkey", 0.001], ["Brazilian rabbit", 0.002], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            img = np.array(crop)
            # cv2.resize takes (width, height); the model input is square.
            img = cv2.resize(img, (self.img_size, self.img_size))
            img = np.expand_dims(img, axis=0)  # add batch dimension -> (1, H, W, C)

            pred = self.model.predict(img, verbose=0)[0]

            # Pair each probability with its species name. Explicit indexing
            # (rather than zip) keeps a loud IndexError if the model emits
            # more outputs than there are known classes.
            return [
                [self.class_ids_sorted[i], float(p)] for i, p in enumerate(pred)
            ]

        except Exception as e:
            raise RuntimeError(f"Keras classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        Class IDs are 1-indexed and correspond to the sorted order of class names.

        Returns:
            Dict mapping class ID (1-indexed string) to species name
            Example: {"1": "Black-headed squirrel monkey", "2": "Brazilian rabbit", ...}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            return {str(i + 1): name for i, name in enumerate(self.class_ids_sorted)}
        except Exception as e:
            raise RuntimeError(f"Failed to extract class names: {e}") from e
|
|