File size: 8,291 Bytes
de23f27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
"""
Inference script for PAM-SDZWA-v1 (Peruvian Amazon Species Classifier)
This model classifies 53 species found in Peruvian Amazon rainforest habitats.
Developed by Mathias Tobler from the San Diego Zoo Wildlife Alliance Conservation
Technology Lab using their animl-py framework.
Model: Peru Amazon v0.86
Input: Variable size (extracted from model config)
Framework: TensorFlow/Keras (TensorFlow 1.x compatible)
Classes: 53 Amazonian species and taxonomic groups
Developer: San Diego Zoo Wildlife Alliance (Mathias Tobler)
License: MIT
Info: https://github.com/conservationtechlab
Author: Peter van Lunteren
Created: 2026-01-14
"""
from __future__ import annotations
import os
from pathlib import Path
import cv2
import numpy as np
import tensorflow as tf
from PIL import Image, ImageFile
from tensorflow.keras.models import load_model
# Don't freak out over truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True
class ModelInference:
    """TensorFlow/Keras inference implementation for Peruvian Amazon species classifier.

    Lifecycle: construct the object, call ``load_model()`` once, then call
    ``get_crop()`` and ``get_classification()`` for every detection.
    """

    # Name of the class label file that is looked up inside ``model_dir``.
    LABEL_FILENAME = "Peru-Amazon_0.86.txt"

    # Extra pixels added on every side of a crop (SDZWA uses a 0 pixel buffer).
    CROP_BUFFER_PX = 0

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths. Nothing is loaded until ``load_model()``.

        Args:
            model_dir: Directory containing model files and class labels
            model_path: Path to Peru-Amazon_0.86.h5 file
        """
        # Coerce so that plain strings are accepted as well as Path objects.
        self.model_dir = Path(model_dir)
        self.model_path = Path(model_path)
        self.model = None
        self.img_size: int | None = None
        self.class_map: dict[str, str] = {}
        self.class_ids_sorted: list[str] = []

    def check_gpu(self) -> bool:
        """
        Check GPU availability for TensorFlow inference.

        Returns:
            True if TensorFlow reports at least one logical GPU, False otherwise
        """
        return bool(tf.config.list_logical_devices("GPU"))

    def load_model(self) -> None:
        """
        Load TensorFlow/Keras model and class labels into memory.

        The loaded model is stored in ``self.model`` and reused for all
        subsequent classification requests. Instance state is only updated
        once BOTH the model and the labels were loaded successfully, so a
        failure never leaves a half-initialised object behind.

        Raises:
            RuntimeError: If the model or the label file cannot be loaded,
                or the label file contains no valid class entries
            FileNotFoundError: If model_path or label file does not exist
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            # Inside this method the bare name ``load_model`` resolves to the
            # module-level Keras loader, not to this method.
            model = load_model(str(self.model_path))
            img_size = self._input_size_from_model(model)
        except Exception as e:
            raise RuntimeError(f"Failed to load Keras model from {self.model_path}: {e}") from e

        label_file = self.model_dir / self.LABEL_FILENAME
        if not label_file.exists():
            raise FileNotFoundError(f"Class label file not found: {label_file}")

        try:
            class_map = self._read_class_map(label_file)
        except Exception as e:
            raise RuntimeError(f"Failed to load class labels from {label_file}: {e}") from e

        if not class_map:
            # Fail here instead of with an opaque IndexError during inference.
            raise RuntimeError(
                f"Failed to load class labels from {label_file}: no valid class entries found"
            )

        # Commit everything at once.
        self.model = model
        self.img_size = img_size
        self.class_map = class_map
        # Class names in ALPHABETICAL order (not ordered by numeric ID).
        # get_classification() pairs position i of this list with model
        # output i.
        # NOTE(review): this assumes the model's output order is the
        # alphabetical order of the class names - confirm against training.
        self.class_ids_sorted = sorted(class_map.values())

    @staticmethod
    def _input_size_from_model(model) -> int:
        """Return the second entry of the model's input shape (the image side length).

        The model is expected to take square images, so a single number is
        used for both width and height when resizing.
        """
        first_layer_config = model.get_config()["layers"][0]["config"]
        # "batch_input_shape" is the key used by older Keras versions; newer
        # versions store the same information under "batch_shape".
        for key in ("batch_input_shape", "batch_shape"):
            shape = first_layer_config.get(key)
            if shape is not None:
                return shape[1]
        # Last resort: ask the model itself.
        return model.input_shape[1]

    @staticmethod
    def _read_class_map(label_file: Path) -> dict[str, str]:
        """Parse the label file into a ``{class id: species name}`` dict.

        Each line is split on double quotes; the 2nd piece is taken as the
        class id and the 4th piece as the species name. Lines with fewer than
        four pieces or a non-numeric id (e.g. a header row) are skipped.
        """
        class_map: dict[str, str] = {}
        # NOTE(review): the file is read with the platform default encoding -
        # confirm the label file is plain ASCII or pin an explicit encoding.
        with open(label_file, "r") as file:
            for line in file:
                parts = line.strip().split('"')
                if len(parts) < 4:
                    continue
                identifier = parts[1].strip()
                if identifier.isdigit():
                    class_map[identifier] = parts[3].strip()
        return class_map

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using SDZWA animl-py preprocessing.

        This cropping method follows the San Diego Zoo Wildlife Alliance's
        animl-py framework approach with minimal buffering (0 pixels by default).
        Based on: https://github.com/conservationtechlab/animl-py/blob/main/src/animl/generator.py

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image (not resized - resizing happens in get_classification)

        Raises:
            ValueError: If the bbox yields an empty region after clipping
        """
        buffer = self.CROP_BUFFER_PX
        width, height = image.size
        x, y, box_w, box_h = bbox

        # Denormalize to pixels (truncating towards zero), apply the buffer
        # and clip to the image boundaries.
        left = max(0, int(width * x) - buffer)
        top = max(0, int(height * y) - buffer)
        right = min(width, int(width * (x + box_w)) + buffer)
        bottom = min(height, int(height * (y + box_h)) + buffer)

        if left >= right or top >= bottom:
            raise ValueError(
                f"Invalid bbox dimensions after cropping: left={left}, top={top}, right={right}, bottom={bottom}"
            )

        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run TensorFlow/Keras classification on cropped image.

        Preprocessing follows SDZWA animl-py framework:
        - Convert to 3-channel RGB if the crop is in another PIL mode
        - Convert to numpy array
        - Resize to model input size (extracted from model config)
        - No normalization or augmentation (except potential horizontal flip during training)

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists, one per model output, in
            the alphabetical class-name order built by load_model().
            Example: [["Black-headed squirrel monkey", 0.001], ["Brazilian rabbit", 0.002], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Grayscale / RGBA / palette crops would otherwise reach the
            # network with the wrong number of channels.
            if isinstance(crop, Image.Image) and crop.mode != "RGB":
                crop = crop.convert("RGB")

            pixels = np.array(crop)
            pixels = cv2.resize(pixels, (self.img_size, self.img_size))
            batch = np.expand_dims(pixels, axis=0)  # add batch dimension

            # Note: According to animl-py, no special preprocessing is needed
            # except for horizontal flip augmentation during training
            scores = self.model.predict(batch, verbose=0)[0]

            if len(scores) > len(self.class_ids_sorted):
                raise ValueError(
                    f"model returned {len(scores)} scores but only "
                    f"{len(self.class_ids_sorted)} class labels are loaded"
                )

            return [
                [class_name, float(score)]
                for class_name, score in zip(self.class_ids_sorted, scores)
            ]
        except Exception as e:
            raise RuntimeError(f"Keras classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        Class IDs are 1-indexed and correspond to the sorted order of class names.

        Returns:
            Dict mapping class ID (1-indexed string) to species name
            Example: {"1": "Black-headed squirrel monkey", "2": "Brazilian rabbit", ...}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        return {
            str(class_id): class_name
            for class_id, class_name in enumerate(self.class_ids_sorted, start=1)
        }
|