Addax-Data-Science committed on
Commit
1588be4
·
verified ·
1 Parent(s): fd6b168

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +246 -0
inference.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference script for SPECIESNET-v4-0-1-A-v1 (SpeciesNet classifier)
3
+
4
+ SpeciesNet is an image classifier designed to accelerate the review of images
5
+ from camera traps. Trained at Google using a large dataset of camera trap images
6
+ and an EfficientNet V2 M architecture. Classifies images into one of 2,498 labels
7
+ covering diverse animal species, higher-level taxa, and non-animal classes.
8
+
9
+ Model: SpeciesNet v4.0.1a (always_crop variant)
10
+ Input: 480x480 RGB images (NHWC layout)
11
+ Framework: PyTorch (torch.fx GraphModule)
12
+ Classes: 2,498
13
+ Developer: Google Research
14
+ Citation: https://doi.org/10.1049/cvi2.12318
15
+ License: https://github.com/google/cameratrapai/blob/main/LICENSE
16
+ Info: https://github.com/google/cameratrapai
17
+
18
+ Author: Peter van Lunteren
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import pathlib
24
+ import platform
25
+ from pathlib import Path
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ import torchvision.transforms.functional as TF
30
+ from PIL import Image, ImageFile
31
+
32
# Fixed parameters of the SpeciesNet v4.0.1a checkpoint: the square input
# resolution in pixels and the name of the labels file shipped next to it.
IMG_SIZE = 480
LABELS_FILENAME = "always_crop_99710272_22x8_v12_epoch_00148.labels.txt"

# Let Pillow decode truncated image files instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# On non-Windows hosts, alias WindowsPath to PosixPath so that models saved on
# Windows can still be loaded here (presumably because the saved file may
# reference pathlib.WindowsPath objects -- confirm against the checkpoint).
if platform.system() != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath
42
+
43
+
44
class ModelInference:
    """SpeciesNet inference implementation using the raw backbone .pt file.

    The constructor only parses the labels file; the model itself is loaded
    lazily by ``load_model()``. Single crops go through ``get_crop()`` and
    ``get_classification()``; batched inference goes through ``get_tensor()``
    and ``classify_batch()``.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths and parse the labels file.

        Args:
            model_dir: Directory containing model files
            model_path: Path to always_crop_...pt file

        Raises:
            FileNotFoundError: If the labels file is missing from model_dir
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model = None
        self.device = None

        # Parse labels file to get class names
        labels_path = model_dir / LABELS_FILENAME
        if not labels_path.exists():
            raise FileNotFoundError(f"Labels file not found: {labels_path}")

        self.class_names = self._read_class_names(labels_path)

    @staticmethod
    def _read_class_names(labels_path: Path) -> list[str]:
        """Read one display name per non-blank line of the labels file.

        Each line has the format
        ``UUID;class;order;family;genus;species;common_name``. The position
        of a name in the returned list is its class index.
        """
        class_names: list[str] = []
        seen_names: set[str] = set()

        # Decode explicitly as UTF-8 so that non-ASCII names are read the
        # same way on every platform instead of depending on the locale.
        with open(labels_path, encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue

                parts = line.split(";")
                common_name = parts[6] if len(parts) >= 7 else parts[-1]

                # Empty or duplicate names cause ID collisions in the
                # pipeline's reverse mapping. Fall back to the most
                # specific taxonomy rank to create a unique label.
                if not common_name or common_name in seen_names:
                    taxonomy = [p for p in parts[1:6] if p]
                    if taxonomy:
                        common_name = taxonomy[-1]

                # If still duplicate, append the UUID prefix
                if common_name in seen_names:
                    common_name = f"{common_name} ({parts[0][:8]})"

                seen_names.add(common_name)
                class_names.append(common_name)

        return class_names

    @staticmethod
    def _detect_device() -> torch.device:
        """Pick the best available device: Apple MPS, then CUDA, then CPU.

        Each backend is probed independently, so a failing MPS probe (for
        example a torch build without ``torch.backends.mps``) no longer
        prevents CUDA from being selected.
        """
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return torch.device("mps")
        except Exception:
            # Deliberate best-effort: treat any probing failure as
            # "MPS not available" and keep looking.
            pass

        try:
            if torch.cuda.is_available():
                return torch.device("cuda")
        except Exception:
            # Same best-effort policy for the CUDA probe.
            pass

        return torch.device("cpu")

    def check_gpu(self) -> bool:
        """Check GPU availability (Apple MPS or NVIDIA CUDA)."""
        return self._detect_device().type != "cpu"

    def load_model(self) -> None:
        """
        Load SpeciesNet GraphModule into memory.

        The .pt file is a torch.fx GraphModule (EfficientNet V2 M backbone
        with classification head). It expects NHWC input layout and outputs
        logits directly with shape [batch, 2498].

        Raises:
            FileNotFoundError: If the model file does not exist
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        self.device = self._detect_device()

        # Load the GraphModule (requires weights_only=False for FX
        # deserialization).
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects and
        # can execute code embedded in the file. Only load checkpoints from
        # a trusted source.
        self.model = torch.load(
            self.model_path, map_location=self.device, weights_only=False
        )
        self.model.eval()

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using normalized bounding box coordinates.

        Matches SpeciesNet's preprocessing: crop using int() truncation
        (not rounding) to match torchvision.transforms.functional.crop().
        Coordinates are not clamped to the image bounds.

        Args:
            image: PIL Image (full resolution)
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image, or the unmodified input image when the box
            truncates to zero (or negative) width or height in pixels.
        """
        img_w, img_h = image.size
        x, y, w, h = bbox

        left = int(x * img_w)
        top = int(y * img_h)
        crop_w = int(w * img_w)
        crop_h = int(h * img_h)

        # A degenerate box cannot be cropped; fall back to the full image.
        if crop_w <= 0 or crop_h <= 0:
            return image

        return image.crop((left, top, left + crop_w, top + crop_h))

    def get_classification(
        self, crop: Image.Image
    ) -> list[list[str | float]]:
        """
        Run SpeciesNet classification on a cropped image.

        Args:
            crop: Cropped PIL Image (converted to RGB if necessary)

        Returns:
            List of [class_name, confidence] lists for ALL classes.
            Sorting by confidence is handled by classification_worker.py.

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")

        # Reuse the batch preprocessing so both paths stay identical, then
        # add the leading batch dimension.
        img_arr = self.get_tensor(crop)
        input_batch = torch.from_numpy(img_arr).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(input_batch)
            probabilities = F.softmax(logits, dim=1)

        probs_np = probabilities.cpu().numpy()[0]

        return [
            [self.class_names[i], float(prob)]
            for i, prob in enumerate(probs_np)
        ]

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to common names from the labels file.

        Returns:
            Dict mapping class ID (1-indexed string) to common name.
            Example: {"1": "white/crandall's saddleback tamarin", "2": "western polecat", ...}
        """
        return {
            str(i + 1): name for i, name in enumerate(self.class_names)
        }

    def get_tensor(self, crop: Image.Image):
        """
        Preprocess a crop into a numpy array for batch inference.

        Matches SpeciesNet's exact preprocessing pipeline:
        PIL -> CHW float32 [0,1] -> resize -> uint8 -> /255 -> HWC

        Args:
            crop: Cropped PIL Image (converted to RGB if necessary)

        Returns:
            float32 numpy array in HWC layout with shape
            (IMG_SIZE, IMG_SIZE, 3) and values in [0, 1].
        """
        if crop.mode != "RGB":
            crop = crop.convert("RGB")

        img_tensor = TF.pil_to_tensor(crop)
        img_tensor = TF.convert_image_dtype(img_tensor, torch.float32)
        img_tensor = TF.resize(
            img_tensor, [IMG_SIZE, IMG_SIZE], antialias=False
        )
        # Round-trip through uint8 before rescaling (matching speciesnet's
        # img.arr / 255).
        img_tensor = TF.convert_image_dtype(img_tensor, torch.uint8)
        return img_tensor.permute(1, 2, 0).numpy().astype("float32") / 255.0

    def classify_batch(self, batch):
        """
        Run inference on a batch of preprocessed numpy arrays.

        Args:
            batch: Numpy array accepted by torch.from_numpy; presumably
                stacked get_tensor() outputs with a leading batch
                dimension (confirm against the caller).

        Returns:
            One list per batch item, each a list of [class_name, confidence]
            lists covering every class name, in label order.

        Raises:
            RuntimeError: If the model has not been loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded, call load_model() first")

        tensor = torch.from_numpy(batch).to(self.device)
        with torch.no_grad():
            logits = self.model(tensor)
            probs = F.softmax(logits, dim=1).cpu().numpy()

        return [
            [[name, float(row[i])] for i, name in enumerate(self.class_names)]
            for row in probs
        ]