| """ |
| Inference script for SPECIESNET-v4-0-1-A-v1 (SpeciesNet classifier) |
| |
| SpeciesNet is an image classifier designed to accelerate the review of images |
| from camera traps. Trained at Google using a large dataset of camera trap images |
| and an EfficientNet V2 M architecture. Classifies images into one of 2,498 labels |
| covering diverse animal species, higher-level taxa, and non-animal classes. |
| |
| Model: SpeciesNet v4.0.1a (always_crop variant) |
| Input: 480x480 RGB images (NHWC layout) |
| Framework: PyTorch (torch.fx GraphModule) |
| Classes: 2,498 |
| Developer: Google Research |
| Citation: https://doi.org/10.1049/cvi2.12318 |
| License: https://github.com/google/cameratrapai/blob/main/LICENSE |
| Info: https://github.com/google/cameratrapai |
| |
| Author: Peter van Lunteren |
| """ |
|
|
| from __future__ import annotations |
|
|
| import pathlib |
| import platform |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| import torchvision.transforms.functional as TF |
| from PIL import Image, ImageFile |
|
|
| |
| ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
| |
| if platform.system() != "Windows": |
| pathlib.WindowsPath = pathlib.PosixPath |
|
|
| |
| LABELS_FILENAME = "always_crop_99710272_22x8_v12_epoch_00148.labels.txt" |
| IMG_SIZE = 480 |
|
|
|
|
class ModelInference:
    """SpeciesNet inference implementation using the raw backbone .pt file."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths and parse the class-label list.

        Args:
            model_dir: Directory containing model files
            model_path: Path to always_crop_...pt file

        Raises:
            FileNotFoundError: If the labels file is missing from model_dir.
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None   # set by load_model()
        self.device = None  # set by load_model()

        labels_path = model_dir / LABELS_FILENAME
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")

        # Parse the semicolon-separated labels file. Field 0 is a class ID;
        # fields 1-5 look like taxonomy levels and field 6 the common name —
        # TODO confirm against the labels file format in the SpeciesNet repo.
        # We keep exactly one human-readable, unique name per class.
        self.class_names: list[str] = []
        seen_names: set[str] = set()
        with open(labels_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                parts = line.split(";")
                # Prefer the common name (7th field); otherwise the last field.
                common_name = parts[6] if len(parts) >= 7 else parts[-1]

                # Empty or duplicate common name: fall back to the most
                # specific non-empty taxonomy level.
                if not common_name or common_name in seen_names:
                    taxonomy = [p for p in parts[1:6] if p]
                    if taxonomy:
                        common_name = taxonomy[-1]

                # Still a duplicate: disambiguate with a class-ID prefix.
                if common_name in seen_names:
                    common_name = f"{common_name} ({parts[0][:8]})"

                seen_names.add(common_name)
                self.class_names.append(common_name)

    def _select_device(self) -> torch.device:
        """Pick the best available device: Apple MPS > NVIDIA CUDA > CPU."""
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return torch.device("mps")
        except Exception:
            # Older torch builds may not expose torch.backends.mps at all;
            # fall through and still consider CUDA before settling on CPU.
            pass
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def check_gpu(self) -> bool:
        """Check GPU availability (Apple MPS or NVIDIA CUDA)."""
        return self._select_device().type != "cpu"

    def load_model(self) -> None:
        """
        Load SpeciesNet GraphModule into memory.

        The .pt file is a torch.fx GraphModule (EfficientNet V2 M backbone
        with classification head). It expects NHWC input layout and outputs
        logits directly with shape [batch, 2498].

        Raises:
            FileNotFoundError: If the model file is missing.
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        self.device = self._select_device()

        # weights_only=False is required because the file stores a full
        # GraphModule, not a state dict. This deserializes arbitrary pickled
        # objects — only load trusted model files.
        self.model = torch.load(
            self.model_path, map_location=self.device, weights_only=False
        )
        self.model.eval()

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using normalized bounding box coordinates.

        Matches SpeciesNet's preprocessing: crop using int() truncation
        (not rounding) to match torchvision.transforms.functional.crop().

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image; the original image if the box degenerates to
            zero width or height.
        """
        W, H = image.size
        x, y, w, h = bbox

        # int() truncates toward zero, matching torchvision's crop semantics.
        left = int(x * W)
        top = int(y * H)
        crop_w = int(w * W)
        crop_h = int(h * H)

        if crop_w <= 0 or crop_h <= 0:
            return image

        return image.crop((left, top, left + crop_w, top + crop_h))

    def get_classification(
        self, crop: Image.Image
    ) -> list[list[str | float]]:
        """
        Run SpeciesNet classification on a cropped image.

        Args:
            crop: Cropped and preprocessed PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes.
            Sorting by confidence is handled by classification_worker.py.

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")

        # Reuse the shared preprocessing so single-image and batch paths
        # cannot drift apart.
        img_arr = self.get_tensor(crop)
        input_batch = torch.from_numpy(img_arr).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(input_batch)
            probabilities = F.softmax(logits, dim=1)

        probs_np = probabilities.cpu().numpy()[0]

        return [
            [name, float(prob)]
            for name, prob in zip(self.class_names, probs_np)
        ]

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to common names from the labels file.

        Returns:
            Dict mapping class ID (1-indexed string) to common name.
            Example: {"1": "white/crandall's saddleback tamarin", "2": "western polecat", ...}
        """
        return {
            str(i + 1): name for i, name in enumerate(self.class_names)
        }

    def get_tensor(self, crop: Image.Image):
        """
        Preprocess a crop into a float32 numpy array for inference.

        Resizes in float32 without antialiasing, quantizes back to uint8,
        then rescales to [0, 1] — presumably to replicate SpeciesNet's
        reference preprocessing exactly (TODO confirm against upstream).
        Output layout is HWC (the model expects NHWC batches).
        """
        if crop.mode != "RGB":
            crop = crop.convert("RGB")

        img_tensor = TF.pil_to_tensor(crop)
        img_tensor = TF.convert_image_dtype(img_tensor, torch.float32)
        img_tensor = TF.resize(
            img_tensor, [IMG_SIZE, IMG_SIZE], antialias=False
        )
        img_tensor = TF.convert_image_dtype(img_tensor, torch.uint8)
        return img_tensor.permute(1, 2, 0).numpy().astype("float32") / 255.0

    def classify_batch(self, batch):
        """
        Run inference on a batch of preprocessed numpy arrays.

        Args:
            batch: Stacked numpy array of get_tensor() outputs (NHWC).

        Returns:
            One list of [class_name, confidence] pairs per batch item.

        Raises:
            RuntimeError: If model not loaded
        """
        # Guard matches get_classification(); without it a None model raises
        # an opaque TypeError instead of the documented RuntimeError.
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")

        tensor = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
            probs = F.softmax(logits, dim=1).cpu().numpy()

        results = []
        for p in probs:
            classifications = [
                [self.class_names[i], float(p[i])]
                for i in range(len(self.class_names))
            ]
            results.append(classifications)
        return results
|
|