# NOTE(review): removed code-viewer scrape artifacts (a "File size" banner,
# git commit hashes, and a rendered line-number gutter 1-247) that preceded
# the module docstring and made this file invalid Python.
"""
Inference script for SPECIESNET-v4-0-1-A-v1 (SpeciesNet classifier)

SpeciesNet is an image classifier designed to accelerate the review of images
from camera traps. Trained at Google using a large dataset of camera trap images
and an EfficientNet V2 M architecture. Classifies images into one of 2,498 labels
covering diverse animal species, higher-level taxa, and non-animal classes.

Model: SpeciesNet v4.0.1a (always_crop variant)
Input: 480x480 RGB images (NHWC layout)
Framework: PyTorch (torch.fx GraphModule)
Classes: 2,498
Developer: Google Research
Citation: https://doi.org/10.1049/cvi2.12318
License: https://github.com/google/cameratrapai/blob/main/LICENSE
Info: https://github.com/google/cameratrapai

Author: Peter van Lunteren
"""

from __future__ import annotations

import pathlib
import platform
from pathlib import Path

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image, ImageFile

# Allow PIL to decode partially-written/corrupt images (common with camera
# traps) instead of raising mid-inference.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# HACK: checkpoints pickled on Windows embed WindowsPath objects, which cannot
# be unpickled on POSIX. Aliasing WindowsPath to PosixPath lets torch.load
# resolve them. This is a process-wide monkey-patch of pathlib.
if platform.system() != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath

# Hardcoded model parameters for SpeciesNet v4.0.1a
LABELS_FILENAME = "always_crop_99710272_22x8_v12_epoch_00148.labels.txt"
IMG_SIZE = 480  # model expects 480x480 input (see get_tensor / module docstring)


class ModelInference:
    """SpeciesNet inference implementation using the raw backbone .pt file."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths and parse the labels file.

        Args:
            model_dir: Directory containing model files
            model_path: Path to always_crop_...pt file

        Raises:
            FileNotFoundError: If the labels file is missing from model_dir
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None   # set by load_model()
        self.device = None  # set by load_model()

        labels_path = model_dir / LABELS_FILENAME
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")
        self.class_names = self._parse_labels(labels_path)

    @staticmethod
    def _parse_labels(labels_path: Path) -> list[str]:
        """
        Parse the SpeciesNet labels file into a unique, ordered name list.

        Each non-empty line has the format:
        UUID;class;order;family;genus;species;common_name

        Empty or duplicate common names would cause ID collisions in the
        pipeline's reverse mapping, so we fall back to the most specific
        taxonomy rank, and finally to a UUID-prefix suffix, to keep every
        label unique.

        Args:
            labels_path: Path to the semicolon-delimited labels file

        Returns:
            List of unique display names, in file (class-index) order.
        """
        class_names: list[str] = []
        seen_names: set[str] = set()
        with open(labels_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split(";")
                # Lines with fewer than 7 fields: take the last field as name.
                common_name = parts[6] if len(parts) >= 7 else parts[-1]

                # Fall back to the most specific non-empty taxonomy rank.
                if not common_name or common_name in seen_names:
                    taxonomy = [p for p in parts[1:6] if p]
                    if taxonomy:
                        common_name = taxonomy[-1]

                # If still duplicate, append the UUID prefix to disambiguate.
                if common_name in seen_names:
                    common_name = f"{common_name} ({parts[0][:8]})"

                seen_names.add(common_name)
                class_names.append(common_name)
        return class_names

    def _detect_device(self) -> torch.device:
        """Pick the best available torch device: MPS > CUDA > CPU."""
        try:
            # MPS backend queries can raise on some torch builds; treat any
            # failure as "MPS unavailable" and keep probing CUDA.
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return torch.device("mps")
        except Exception:
            pass
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def check_gpu(self) -> bool:
        """Check GPU availability (Apple MPS or NVIDIA CUDA)."""
        return self._detect_device().type != "cpu"

    def load_model(self) -> None:
        """
        Load SpeciesNet GraphModule into memory.

        The .pt file is a torch.fx GraphModule (EfficientNet V2 M backbone
        with classification head). It expects NHWC input layout and outputs
        logits directly with shape [batch, 2498].

        Raises:
            FileNotFoundError: If the model file does not exist
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        self.device = self._detect_device()

        # FX GraphModule deserialization needs full unpickling, hence
        # weights_only=False. NOTE: this executes pickled code — only load
        # model files from a trusted source.
        self.model = torch.load(
            self.model_path, map_location=self.device, weights_only=False
        )
        self.model.eval()

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using normalized bounding box coordinates.

        Matches SpeciesNet's preprocessing: crop using int() truncation
        (not rounding) to match torchvision.transforms.functional.crop().

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image, or the original image if the box is degenerate
            (zero-pixel width or height after truncation).
        """
        W, H = image.size
        x, y, w, h = bbox

        left = int(x * W)
        top = int(y * H)
        crop_w = int(w * W)
        crop_h = int(h * H)

        # A degenerate box would produce an empty crop; classify the full frame.
        if crop_w <= 0 or crop_h <= 0:
            return image

        return image.crop((left, top, left + crop_w, top + crop_h))

    def get_classification(
        self, crop: Image.Image
    ) -> list[list[str | float]]:
        """
        Run SpeciesNet classification on a cropped image.

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes.
            Sorting by confidence is handled by classification_worker.py.

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")

        # Reuse the shared preprocessing pipeline so single-image and batch
        # paths cannot drift apart.
        img_arr = self.get_tensor(crop)
        input_batch = torch.from_numpy(img_arr).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(input_batch)
            probabilities = F.softmax(logits, dim=1)

        probs_np = probabilities.cpu().numpy()[0]
        return [
            [name, float(prob)]
            for name, prob in zip(self.class_names, probs_np)
        ]

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to common names from the labels file.

        Returns:
            Dict mapping class ID (1-indexed string) to common name.
            Example: {"1": "white/crandall's saddleback tamarin", "2": "western polecat", ...}
        """
        return {
            str(i + 1): name for i, name in enumerate(self.class_names)
        }

    def get_tensor(self, crop: Image.Image):
        """
        Preprocess a crop into a numpy array for inference.

        Matches SpeciesNet's exact pipeline:
        PIL -> CHW float32 [0,1] -> resize -> uint8 -> /255 -> HWC float32.

        Args:
            crop: PIL Image (converted to RGB if needed)

        Returns:
            HWC float32 numpy array in [0, 1], shape (IMG_SIZE, IMG_SIZE, 3).
        """
        if crop.mode != "RGB":
            crop = crop.convert("RGB")

        img_tensor = TF.pil_to_tensor(crop)
        img_tensor = TF.convert_image_dtype(img_tensor, torch.float32)
        # antialias=False matches the reference implementation's resize.
        img_tensor = TF.resize(
            img_tensor, [IMG_SIZE, IMG_SIZE], antialias=False
        )
        img_tensor = TF.convert_image_dtype(img_tensor, torch.uint8)
        # HWC float32 [0, 1] (matching speciesnet's img.arr / 255)
        return img_tensor.permute(1, 2, 0).numpy().astype("float32") / 255.0

    def classify_batch(self, batch):
        """
        Run inference on a batch of preprocessed numpy arrays.

        Args:
            batch: numpy array of stacked get_tensor() outputs (NHWC)

        Returns:
            One [class_name, confidence] list per image, covering all classes.

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")

        tensor = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            probs = F.softmax(self.model(tensor), dim=1).cpu().numpy()

        return [
            [[name, float(p)] for name, p in zip(self.class_names, row)]
            for row in probs
        ]