# SPECIESNET-v4-0-1-A-v1 / inference.py
# Uploaded by Addax-Data-Science (commit 93eeaa9, verified)
"""
Inference script for SPECIESNET-v4-0-1-A-v1 (SpeciesNet classifier)
SpeciesNet is an image classifier designed to accelerate the review of images
from camera traps. Trained at Google using a large dataset of camera trap images
and an EfficientNet V2 M architecture. Classifies images into one of 2,498 labels
covering diverse animal species, higher-level taxa, and non-animal classes.
Model: SpeciesNet v4.0.1a (always_crop variant)
Input: 480x480 RGB images (NHWC layout)
Framework: PyTorch (torch.fx GraphModule)
Classes: 2,498
Developer: Google Research
Citation: https://doi.org/10.1049/cvi2.12318
License: https://github.com/google/cameratrapai/blob/main/LICENSE
Info: https://github.com/google/cameratrapai
Author: Peter van Lunteren
"""
from __future__ import annotations
import pathlib
import platform
from pathlib import Path
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image, ImageFile
# Tolerate truncated image files: return partial data instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Models pickled on Windows may embed WindowsPath objects; alias it to
# PosixPath so they can be unpickled on Unix systems.
if platform.system() != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath

# Hardcoded model parameters for SpeciesNet v4.0.1a
LABELS_FILENAME = "always_crop_99710272_22x8_v12_epoch_00148.labels.txt"  # one label record per line
IMG_SIZE = 480  # model input is 480x480 RGB
class ModelInference:
    """SpeciesNet inference implementation using the raw backbone .pt file."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths and parse the labels file.

        Args:
            model_dir: Directory containing model files
            model_path: Path to always_crop_...pt file

        Raises:
            FileNotFoundError: If the labels file is missing from model_dir
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None   # set by load_model()
        self.device = None  # set by load_model()

        # Parse labels file to get class names
        labels_path = model_dir / LABELS_FILENAME
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")

        self.class_names: list[str] = []
        seen_names: set[str] = set()
        with open(labels_path) as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                # Format: UUID;class;order;family;genus;species;common_name
                parts = line.split(";")
                common_name = parts[6] if len(parts) >= 7 else parts[-1]
                # Empty or duplicate names cause ID collisions in the
                # pipeline's reverse mapping. Fall back to the most
                # specific taxonomy rank to create a unique label.
                if not common_name or common_name in seen_names:
                    taxonomy = [p for p in parts[1:6] if p]
                    if taxonomy:
                        common_name = taxonomy[-1]
                    # If still duplicate, append the UUID prefix
                    if common_name in seen_names:
                        common_name = f"{common_name} ({parts[0][:8]})"
                seen_names.add(common_name)
                self.class_names.append(common_name)

    def _select_device(self) -> torch.device:
        """Pick the best available device: Apple MPS > NVIDIA CUDA > CPU."""
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return torch.device("mps")
        except Exception:
            # Older torch builds may not expose the MPS backend at all.
            pass
        try:
            if torch.cuda.is_available():
                return torch.device("cuda")
        except Exception:
            pass
        return torch.device("cpu")

    def check_gpu(self) -> bool:
        """Check GPU availability (Apple MPS or NVIDIA CUDA)."""
        return self._select_device().type != "cpu"

    def load_model(self) -> None:
        """
        Load SpeciesNet GraphModule into memory.

        The .pt file is a torch.fx GraphModule (EfficientNet V2 M backbone
        with classification head). It expects NHWC input layout and outputs
        logits directly with shape [batch, 2498].

        Raises:
            FileNotFoundError: If the model file is missing
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        self.device = self._select_device()
        # FX deserialization requires full unpickling (weights_only=False),
        # which can execute arbitrary code — only load trusted model files.
        self.model = torch.load(
            self.model_path, map_location=self.device, weights_only=False
        )
        self.model.eval()

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using normalized bounding box coordinates.

        Matches SpeciesNet's preprocessing: crop using int() truncation
        (not rounding) to match torchvision.transforms.functional.crop().

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image, or the original image if the box is degenerate.
        """
        W, H = image.size
        x, y, w, h = bbox
        left = int(x * W)
        top = int(y * H)
        crop_w = int(w * W)
        crop_h = int(h * H)
        # A zero-area box would produce an empty crop; fall back to full frame.
        if crop_w <= 0 or crop_h <= 0:
            return image
        return image.crop((left, top, left + crop_w, top + crop_h))

    def get_classification(
        self, crop: Image.Image
    ) -> list[list[str | float]]:
        """
        Run SpeciesNet classification on a cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes.
            Sorting by confidence is handled by classification_worker.py.

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")
        # Shared preprocessing lives in get_tensor() so the single-image and
        # batch paths cannot drift apart.
        img_arr = self.get_tensor(crop)
        input_batch = torch.from_numpy(img_arr).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(input_batch)
            probabilities = F.softmax(logits, dim=1)
        probs_np = probabilities.cpu().numpy()[0]
        return [
            [self.class_names[i], float(prob)]
            for i, prob in enumerate(probs_np)
        ]

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to common names from the labels file.

        Returns:
            Dict mapping class ID (1-indexed string) to common name.
            Example: {"1": "white/crandall's saddleback tamarin", "2": "western polecat", ...}
        """
        return {
            str(i + 1): name for i, name in enumerate(self.class_names)
        }

    def get_tensor(self, crop: Image.Image):
        """
        Preprocess a crop into a numpy array for inference.

        Matches SpeciesNet's exact preprocessing pipeline:
        PIL -> CHW float32 [0,1] -> resize -> uint8 -> /255 -> HWC

        Returns:
            HWC float32 numpy array in [0, 1] of shape (IMG_SIZE, IMG_SIZE, 3).
        """
        if crop.mode != "RGB":
            crop = crop.convert("RGB")
        img_tensor = TF.pil_to_tensor(crop)
        img_tensor = TF.convert_image_dtype(img_tensor, torch.float32)
        # antialias=False mirrors the original SpeciesNet resize behavior.
        img_tensor = TF.resize(
            img_tensor, [IMG_SIZE, IMG_SIZE], antialias=False
        )
        img_tensor = TF.convert_image_dtype(img_tensor, torch.uint8)
        # HWC float32 [0, 1] (matching speciesnet's img.arr / 255)
        return img_tensor.permute(1, 2, 0).numpy().astype("float32") / 255.0

    def classify_batch(self, batch):
        """
        Run inference on a batch of preprocessed numpy arrays.

        Args:
            batch: numpy array of stacked get_tensor() outputs (NHWC layout)

        Returns:
            One list of [class_name, confidence] pairs (ALL classes) per image.

        Raises:
            RuntimeError: If model not loaded
        """
        # Same guard as get_classification(): fail with a clear message
        # instead of an opaque TypeError on a None model.
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")
        tensor = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
            probs = F.softmax(logits, dim=1).cpu().numpy()
        return [
            [[name, float(p[i])] for i, name in enumerate(self.class_names)]
            for p in probs
        ]