"""
Inference script for SPECIESNET-v4-0-1-A-v1 (SpeciesNet classifier).

SpeciesNet is an image classifier designed to accelerate the review of
images from camera traps. It was trained at Google on a large dataset of
camera trap images using an EfficientNet V2 M architecture, and classifies
images into one of 2,498 labels covering diverse animal species,
higher-level taxa, and non-animal classes.

Model:      SpeciesNet v4.0.1a (always_crop variant)
Input:      480x480 RGB images (NHWC layout)
Framework:  PyTorch (torch.fx GraphModule)
Classes:    2,498
Developer:  Google Research
Citation:   https://doi.org/10.1049/cvi2.12318
License:    https://github.com/google/cameratrapai/blob/main/LICENSE
Info:       https://github.com/google/cameratrapai
Author:     Peter van Lunteren
"""

from __future__ import annotations

import pathlib
import platform
from pathlib import Path

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image, ImageFile

# Tolerate truncated image files instead of raising while decoding them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Make sure Windows-trained models work on Unix by aliasing WindowsPath to
# PosixPath. NOTE: this patch is process-wide and affects every pathlib user.
if platform.system() != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath

# Hardcoded model parameters for SpeciesNet v4.0.1a
LABELS_FILENAME = "always_crop_99710272_22x8_v12_epoch_00148.labels.txt"
IMG_SIZE = 480


class ModelInference:
    """SpeciesNet inference implementation using the raw backbone .pt file."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths and parse the labels file.

        Args:
            model_dir: Directory containing model files, including the
                labels file named by LABELS_FILENAME
            model_path: Path to always_crop_...pt file

        Raises:
            FileNotFoundError: If the labels file does not exist
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None
        self.device = None

        # Parse labels file to get class names
        labels_path = model_dir / LABELS_FILENAME
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")
        self.class_names: list[str] = self._load_class_names(labels_path)

    @staticmethod
    def _load_class_names(labels_path: Path) -> list[str]:
        """
        Parse the labels file into a list of unique, non-empty class names.

        Each non-blank line has the format
        ``UUID;class;order;family;genus;species;common_name``. The position
        of a name in the returned list is the index it is paired with in the
        model's output (see _predict).

        Args:
            labels_path: Path to the ';'-separated labels file

        Returns:
            Class names in file order, one per non-blank line
        """
        class_names: list[str] = []
        seen_names: set[str] = set()
        # Explicit UTF-8 so non-ASCII common names decode identically on
        # every platform (the locale default is not UTF-8 on e.g. Windows).
        with open(labels_path, encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue

                parts = line.split(";")
                label_uuid = parts[0]
                common_name = parts[6] if len(parts) >= 7 else parts[-1]

                # Empty or duplicate names cause ID collisions in the
                # pipeline's reverse mapping. Fall back to the most
                # specific taxonomy rank to create a unique label.
                if not common_name or common_name in seen_names:
                    taxonomy = [p for p in parts[1:6] if p]
                    if taxonomy:
                        common_name = taxonomy[-1]

                # No common name and no taxonomy at all: use the UUID rather
                # than emitting an empty label.
                if not common_name:
                    common_name = label_uuid

                # If still duplicate, append the UUID prefix
                if common_name in seen_names:
                    common_name = f"{common_name} ({label_uuid[:8]})"

                seen_names.add(common_name)
                class_names.append(common_name)
        return class_names

    @staticmethod
    def _mps_available() -> bool:
        """Return True if Apple MPS is built into this torch and usable."""
        try:
            return bool(
                torch.backends.mps.is_built()
                and torch.backends.mps.is_available()
            )
        except Exception:
            # Deliberate best effort: torch builds without the MPS backend,
            # or platforms where probing it fails, simply mean "no MPS".
            return False

    def check_gpu(self) -> bool:
        """Check GPU availability (Apple MPS or NVIDIA CUDA)."""
        return self._mps_available() or torch.cuda.is_available()

    def _select_device(self) -> torch.device:
        """Pick the inference device: MPS first, then CUDA, then CPU."""
        if self._mps_available():
            return torch.device("mps")
        # Checked independently of the MPS probe, so a failing MPS probe no
        # longer forces CPU when CUDA is present (matches check_gpu).
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def load_model(self) -> None:
        """
        Load SpeciesNet GraphModule into memory.

        The .pt file is a torch.fx GraphModule (EfficientNet V2 M backbone
        with classification head). It expects NHWC input layout and outputs
        logits directly with shape [batch, 2498].

        Raises:
            FileNotFoundError: If the model file does not exist
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        self.device = self._select_device()

        # Load the GraphModule (requires weights_only=False for FX
        # deserialization). SECURITY: this unpickles arbitrary objects, so
        # only load model files from a trusted source.
        self.model = torch.load(
            self.model_path, map_location=self.device, weights_only=False
        )
        self.model.eval()

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using normalized bounding box coordinates.

        Matches SpeciesNet's preprocessing: crop using int() truncation (not
        rounding) to match torchvision.transforms.functional.crop().

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range
                [0.0, 1.0]

        Returns:
            Cropped PIL Image, or the original image unchanged if the box
            has zero (or negative) width or height after truncation
        """
        img_w, img_h = image.size
        x, y, w, h = bbox

        left = int(x * img_w)
        top = int(y * img_h)
        crop_w = int(w * img_w)
        crop_h = int(h * img_h)

        # Degenerate box: fall back to the full image rather than an empty crop.
        if crop_w <= 0 or crop_h <= 0:
            return image
        return image.crop((left, top, left + crop_w, top + crop_h))

    def get_classification(
        self, crop: Image.Image
    ) -> list[list[str | float]]:
        """
        Run SpeciesNet classification on a cropped image.

        Args:
            crop: Cropped PIL Image (e.g. from get_crop); resizing and
                normalization are applied here via get_tensor

        Returns:
            List of [class_name, confidence] lists for ALL classes, in
            labels-file order. Sorting by confidence is handled by
            classification_worker.py.

        Raises:
            RuntimeError: If model not loaded, or the model's output width
                does not match the number of class names
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")

        # Single-image batch: add the leading batch dimension (NHWC).
        input_batch = torch.from_numpy(self.get_tensor(crop)).unsqueeze(0)
        return self._predict(input_batch)[0]

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to common names from the labels file.

        Returns:
            Dict mapping class ID (1-indexed string) to common name.
            Example: {"1": "white/crandall's saddleback tamarin",
                      "2": "western polecat", ...}
        """
        return {str(i + 1): name for i, name in enumerate(self.class_names)}

    def get_tensor(self, crop: Image.Image):
        """
        Preprocess a crop into a numpy array for batch inference.

        Matches SpeciesNet's exact preprocessing pipeline:
        PIL -> CHW float32 [0, 1] -> resize to IMG_SIZE x IMG_SIZE (no
        antialiasing) -> uint8 -> HWC float32 / 255.

        Args:
            crop: PIL Image; converted to RGB if needed

        Returns:
            float32 numpy array of shape (IMG_SIZE, IMG_SIZE, 3) in [0, 1]
        """
        if crop.mode != "RGB":
            crop = crop.convert("RGB")
        img_tensor = TF.pil_to_tensor(crop)
        img_tensor = TF.convert_image_dtype(img_tensor, torch.float32)
        img_tensor = TF.resize(
            img_tensor, [IMG_SIZE, IMG_SIZE], antialias=False
        )
        # Round-trip through uint8 to reproduce SpeciesNet's quantization.
        img_tensor = TF.convert_image_dtype(img_tensor, torch.uint8)
        return img_tensor.permute(1, 2, 0).numpy().astype("float32") / 255.0

    def classify_batch(self, batch):
        """
        Run inference on a batch of preprocessed numpy arrays.

        Args:
            batch: float32 numpy array of stacked get_tensor() outputs,
                presumably shape (N, IMG_SIZE, IMG_SIZE, 3)

        Returns:
            One list per batch item, each holding [class_name, confidence]
            pairs for ALL classes in labels-file order.

        Raises:
            RuntimeError: If model not loaded, or the model's output width
                does not match the number of class names
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")
        return self._predict(torch.from_numpy(batch))

    def _predict(
        self, input_batch: torch.Tensor
    ) -> list[list[list[str | float]]]:
        """
        Run the model on an NHWC float batch and pair softmax scores with names.

        Args:
            input_batch: Tensor of preprocessed images (any device)

        Returns:
            Per batch item, a list of [class_name, probability] pairs.

        Raises:
            RuntimeError: If model not loaded, or the number of scores per
                item differs from the number of class names
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")

        with torch.no_grad():
            logits = self.model(input_batch.to(self.device))
        probs = F.softmax(logits, dim=1).cpu().numpy()

        # Fail loudly instead of silently truncating or raising IndexError
        # when the labels file does not match the model head.
        if probs.shape[1] != len(self.class_names):
            raise RuntimeError(
                f"Model produced {probs.shape[1]} scores but "
                f"{len(self.class_names)} class names were loaded"
            )

        return [
            [[name, float(prob)] for name, prob in zip(self.class_names, row)]
            for row in probs
        ]