|
|
""" |
|
|
Inference script for SWUSA-SDZWA-v3 (Southwest USA Species Classifier) |
|
|
|
|
|
This model distinguishes between 27 species native to the Southwest United States. |
|
|
Training data collected by SDZWA and California Mountain Lion Project, with examples |
|
|
from NACTI and CCT datasets. Trained on 91,662 images (70/20/10 split) achieving |
|
|
88% accuracy on test set. |
|
|
|
|
|
Model: Southwest USA v3 |
|
|
Input: 299x299 RGB images |
|
|
Framework: PyTorch (EfficientNet V2 Medium architecture) |
|
|
Classes: 27 species and categories |
|
|
Developer: San Diego Zoo Wildlife Alliance (Kyra Swanson) |
|
|
License: MIT |
|
|
Info: https://github.com/conservationtechlab |
|
|
|
|
|
Author: Peter van Lunteren |
|
|
Created: 2026-01-14 |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import pathlib |
|
|
import platform |
|
|
from pathlib import Path |
|
|
|
|
|
import pandas as pd |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image, ImageFile |
|
|
from torchvision import transforms |
|
|
from torchvision.models import efficientnet |
|
|
|
|
|
|
|
|
# Tolerate truncated/partially-written image files: PIL will decode what it
# can instead of raising an OSError mid-read.
ImageFile.LOAD_TRUNCATED_IMAGES = True


# NOTE: `plt` here is the platform name string, not matplotlib.pyplot.
plt = platform.system()
if plt != 'Windows':
    # Monkey-patch pathlib so WindowsPath objects can be materialized on
    # POSIX systems — presumably so checkpoints pickled on Windows can be
    # loaded via torch.load on macOS/Linux. NOTE(review): this patches
    # pathlib globally for the whole process; confirm no other code relies
    # on WindowsPath behaving as a pure path class.
    pathlib.WindowsPath = pathlib.PosixPath
|
|
|
|
|
|
|
|
class EfficientNetV2M(nn.Module):
    """EfficientNet V2 Medium architecture for SDZWA wildlife classification."""

    def __init__(
        self,
        num_classes: int,
        pretrained_weights_path: Path,
        device_str: str = 'cpu',
        tune: bool = True
    ):
        """
        Initialize EfficientNet V2 Medium model.

        Builds the torchvision EfficientNet V2 M architecture, loads locally
        stored ImageNet pretrained weights, replaces the classification head
        with a fresh Linear layer sized for ``num_classes``, and moves the
        model to the requested device.

        Args:
            num_classes: Number of output classes
            pretrained_weights_path: Path to ImageNet pretrained weights (.pth file)
            device_str: Device to load model on ('cpu', 'cuda', 'mps')
            tune: Whether to enable gradient updates (fine-tuning). When False,
                all backbone parameters are frozen (requires_grad=False); the
                replacement classifier head always remains trainable.
        """
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Instantiate the architecture without downloading weights, then load
        # the on-disk ImageNet checkpoint onto the target device.
        self.model = efficientnet.efficientnet_v2_m(weights=None)
        self.model.load_state_dict(
            torch.load(str(pretrained_weights_path), map_location=torch.device(device_str))
        )

        # BUGFIX: previously `if tune:` only set requires_grad=True, which is
        # already the default — so tune=False never actually froze anything.
        # Now the flag controls gradient updates as documented.
        for params in self.model.parameters():
            params.requires_grad = tune

        # Swap the final classifier layer for one matching the target class
        # count. Created after the freeze loop, so it is always trainable.
        num_ftrs = self.model.classifier[1].in_features
        self.model.classifier[1] = nn.Linear(in_features=num_ftrs, out_features=num_classes)
        self.model.to(torch.device(device_str))

    def forward(self, x):
        """Forward pass: features -> global average pool -> flatten -> classifier."""
        x = self.model.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        prediction = self.model.classifier(x)
        return prediction
|
|
|
|
|
|
|
|
class ModelInference:
    """PyTorch inference implementation for Southwest USA species classifier."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files (classes.csv, pretrained weights)
            model_path: Path to southwest_v3.pt checkpoint file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        # All of the following are populated by load_model(); None until then.
        self.model = None       # EfficientNetV2M instance
        self.device = None      # torch.device used for inference
        self.classes = None     # pandas DataFrame loaded from classes.csv
        self.preprocess = None  # torchvision transform pipeline

    def check_gpu(self) -> bool:
        """
        Check GPU availability for PyTorch inference.

        Checks both Apple Metal Performance Shaders (MPS) and CUDA availability.

        Returns:
            True if GPU available, False otherwise
        """
        # MPS probing can raise on torch builds without MPS support, so any
        # failure is treated as "MPS unavailable" and we fall through to CUDA.
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass

        return torch.cuda.is_available()

    def load_model(self, device_str: str = 'cpu') -> None:
        """
        Load PyTorch EfficientNet model and class labels.

        This SDZWA model uses EfficientNet V2 Medium architecture with ImageNet
        pretrained weights, fine-tuned on Southwest USA wildlife data.

        Args:
            device_str: Device to load model on ('cpu', 'cuda', or 'mps')

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If required files are missing
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        # Companion files expected alongside the checkpoint.
        classes_csv = self.model_dir / 'classes.csv'
        efficientnet_weights = self.model_dir / 'efficientnet_v2_m-dc08266a.pth'

        if not classes_csv.exists():
            raise FileNotFoundError(f"Classes file not found: {classes_csv}")
        if not efficientnet_weights.exists():
            raise FileNotFoundError(f"EfficientNet weights not found: {efficientnet_weights}")

        try:
            self.device = torch.device(device_str)

            # One row per class; column index 1 holds the species name
            # (relied on by get_classification and get_class_names).
            self.classes = pd.read_csv(str(classes_csv))

            # Build the backbone from ImageNet weights; those are then
            # overwritten by the fine-tuned Southwest USA checkpoint below.
            self.model = EfficientNetV2M(
                num_classes=len(self.classes),
                pretrained_weights_path=efficientnet_weights,
                device_str=device_str,
                tune=False
            )

            checkpoint = torch.load(str(self.model_path), map_location=self.device)
            self.model.load_state_dict(checkpoint['model'])
            self.model.to(self.device)
            self.model.eval()

            # animl-py preprocessing: resize to 299x299 and convert to
            # tensor — deliberately no normalization.
            self.preprocess = transforms.Compose([
                transforms.Resize((299, 299)),
                transforms.ToTensor(),
            ])

        except Exception as e:
            raise RuntimeError(f"Failed to load PyTorch model from {self.model_path}: {e}") from e

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using SDZWA animl-py preprocessing.

        This cropping method follows the San Diego Zoo Wildlife Alliance's animl-py
        framework approach with minimal buffering (0 pixels by default).

        Based on: https://github.com/conservationtechlab/animl-py/blob/main/src/animl/generator.py

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image (not resized - resizing happens in get_classification)

        Raises:
            ValueError: If bbox is invalid
        """
        buffer = 0  # animl-py default: no extra padding around the box
        width, height = image.size

        # Convert normalized (x, y, w, h) to absolute pixel corners.
        box_x, box_y, box_w, box_h = bbox
        left = width * box_x
        top = height * box_y
        right = width * (box_x + box_w)
        bottom = height * (box_y + box_h)

        # Apply buffer and clamp the corners to the image bounds.
        left = max(0, int(left) - buffer)
        top = max(0, int(top) - buffer)
        right = min(width, int(right) + buffer)
        bottom = min(height, int(bottom) + buffer)

        # Degenerate boxes (zero width/height after truncation) are invalid.
        if left >= right or top >= bottom:
            raise ValueError(
                f"Invalid bbox dimensions after cropping: "
                f"left={left}, top={top}, right={right}, bottom={bottom}"
            )

        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run PyTorch/EfficientNet classification on cropped image.

        Preprocessing follows SDZWA animl-py framework:
        - Resize to 299x299 (as per animl-py specifications)
        - Convert to tensor
        - No normalization

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes.
            Example: [["cougar", 0.85], ["bobcat", 0.10], ["coyote", 0.02], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Preprocess, add the batch dimension, move to the inference device.
            input_batch = self.preprocess(crop).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.model(input_batch)

            # Convert logits to per-class probabilities; no .detach() needed
            # since the forward pass ran under torch.no_grad().
            confidence_scores = F.softmax(output, dim=1)[0].cpu().numpy()

            # Column index 1 of classes.csv holds the species name.
            return [
                [self.classes.iloc[i].values[1], float(conf)]
                for i, conf in enumerate(confidence_scores)
            ]

        except Exception as e:
            raise RuntimeError(f"PyTorch classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names from CSV.

        Returns:
            Dict mapping class ID (1-indexed string) to species code
            Example: {"1": "badger", "2": "beaver", ..., "27": "weasel"}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None or self.classes is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Class IDs are 1-indexed strings; names come from column index 1.
            return {
                str(i + 1): self.classes.iloc[i].values[1]
                for i in range(len(self.classes))
            }

        except Exception as e:
            raise RuntimeError(f"Failed to extract class names: {e}") from e
|
|
|