"""
Inference script for SPECIESNET-v4-0-1-A-v1 (SpeciesNet classifier)
SpeciesNet is an image classifier designed to accelerate the review of images
from camera traps. Trained at Google using a large dataset of camera trap images
and an EfficientNet V2 M architecture. Classifies images into one of 2,498 labels
covering diverse animal species, higher-level taxa, and non-animal classes.
Model: SpeciesNet v4.0.1a (always_crop variant)
Input: 480x480 RGB images (NHWC layout)
Framework: PyTorch (torch.fx GraphModule)
Classes: 2,498
Developer: Google Research
Citation: https://doi.org/10.1049/cvi2.12318
License: https://github.com/google/cameratrapai/blob/main/LICENSE
Info: https://github.com/google/cameratrapai
Author: Peter van Lunteren
"""
from __future__ import annotations
import pathlib
import platform
from pathlib import Path
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image, ImageFile
# Allow PIL to decode images whose files were cut off mid-write (common with
# camera-trap SD cards); without this a truncated JPEG raises at load time.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Checkpoints pickled on Windows can embed WindowsPath objects; alias the
# class so torch.load can unpickle them on Unix hosts.
if platform.system() != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath
# Hardcoded model parameters for SpeciesNet v4.0.1a
LABELS_FILENAME = "always_crop_99710272_22x8_v12_epoch_00148.labels.txt"
# Classifier input resolution: 480x480 RGB (see module docstring).
IMG_SIZE = 480
class ModelInference:
    """SpeciesNet inference implementation using the raw backbone .pt file."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths and parse the labels file.

        Args:
            model_dir: Directory containing model files
            model_path: Path to always_crop_...pt file

        Raises:
            FileNotFoundError: If the labels file is missing from model_dir.
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None
        self.device = None
        labels_path = model_dir / LABELS_FILENAME
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")
        self.class_names = self._parse_labels(labels_path)

    @staticmethod
    def _parse_labels(labels_path: Path) -> list[str]:
        """Parse the labels file into a list of unique display names.

        Each non-empty line has the format:
            UUID;class;order;family;genus;species;common_name

        Empty or duplicate common names would cause ID collisions in the
        pipeline's reverse mapping, so such entries fall back to the most
        specific non-empty taxonomy rank, and — if still duplicated — get
        the first 8 characters of their UUID appended.
        """
        class_names: list[str] = []
        seen_names: set[str] = set()
        with open(labels_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split(";")
                common_name = parts[6] if len(parts) >= 7 else parts[-1]
                if not common_name or common_name in seen_names:
                    taxonomy = [p for p in parts[1:6] if p]
                    if taxonomy:
                        common_name = taxonomy[-1]
                    # If still duplicate, append the UUID prefix.
                    if common_name in seen_names:
                        common_name = f"{common_name} ({parts[0][:8]})"
                seen_names.add(common_name)
                class_names.append(common_name)
        return class_names

    def _detect_device(self) -> torch.device:
        """Pick the best available device: Apple MPS > NVIDIA CUDA > CPU."""
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return torch.device("mps")
        except Exception:
            # Older torch builds may not expose the mps backend at all.
            pass
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def check_gpu(self) -> bool:
        """Check GPU availability (Apple MPS or NVIDIA CUDA)."""
        return self._detect_device().type != "cpu"

    def load_model(self) -> None:
        """
        Load SpeciesNet GraphModule into memory.

        The .pt file is a torch.fx GraphModule (EfficientNet V2 M backbone
        with classification head). It expects NHWC input layout and outputs
        logits directly with shape [batch, 2498].

        Raises:
            FileNotFoundError: If the model file is missing.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        try:
            self.device = self._detect_device()
        except Exception:
            self.device = torch.device("cpu")
        # FX GraphModule deserialization requires full unpickling; the file is
        # a trusted local artifact, so weights_only=False is acceptable here.
        self.model = torch.load(
            self.model_path, map_location=self.device, weights_only=False
        )
        self.model.eval()

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using normalized bounding box coordinates.

        Matches SpeciesNet's preprocessing: crop using int() truncation
        (not rounding) to match torchvision.transforms.functional.crop().

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image, or the original image if the box degenerates
            to zero width/height after truncation.
        """
        W, H = image.size
        x, y, w, h = bbox
        left = int(x * W)
        top = int(y * H)
        crop_w = int(w * W)
        crop_h = int(h * H)
        if crop_w <= 0 or crop_h <= 0:
            return image
        return image.crop((left, top, left + crop_w, top + crop_h))

    def get_classification(
        self, crop: Image.Image
    ) -> list[list[str | float]]:
        """
        Run SpeciesNet classification on a cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes.
            Sorting by confidence is handled by classification_worker.py.

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")
        # Reuse the shared preprocessing and run it as a batch of one; this
        # keeps the single-image and batch paths byte-for-byte identical.
        return self.classify_batch(self.get_tensor(crop)[None, ...])[0]

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to common names from the labels file.

        Returns:
            Dict mapping class ID (1-indexed string) to common name.
            Example: {"1": "white/crandall's saddleback tamarin", "2": "western polecat", ...}
        """
        return {
            str(i + 1): name for i, name in enumerate(self.class_names)
        }

    def get_tensor(self, crop: Image.Image):
        """Preprocess a crop into an HWC float32 [0, 1] numpy array.

        Matches SpeciesNet's exact pipeline:
        PIL -> CHW float32 [0,1] -> resize (no antialias) -> uint8 -> HWC /255
        (the uint8 round-trip reproduces speciesnet's img.arr / 255).
        """
        if crop.mode != "RGB":
            crop = crop.convert("RGB")
        img_tensor = TF.pil_to_tensor(crop)
        img_tensor = TF.convert_image_dtype(img_tensor, torch.float32)
        img_tensor = TF.resize(
            img_tensor, [IMG_SIZE, IMG_SIZE], antialias=False
        )
        img_tensor = TF.convert_image_dtype(img_tensor, torch.uint8)
        return img_tensor.permute(1, 2, 0).numpy().astype("float32") / 255.0

    def classify_batch(self, batch):
        """Run inference on a batch of preprocessed numpy arrays.

        Args:
            batch: float32 array of shape [N, IMG_SIZE, IMG_SIZE, 3] (NHWC),
                as produced by get_tensor().

        Returns:
            One list of [class_name, confidence] lists per image, covering
            ALL classes (unsorted).

        Raises:
            RuntimeError: If the model has not been loaded yet.
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")
        tensor = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
            probs = F.softmax(logits, dim=1).cpu().numpy()
        return [
            [[name, float(p[i])] for i, name in enumerate(self.class_names)]
            for p in probs
        ]
|