Addax-Data-Science committed on
Commit
9190f93
·
verified ·
1 Parent(s): 68ad8f8

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +326 -0
inference.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Peter van Lunteren, January 2026
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import sys
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ import timm
12
+ import torch
13
+ import torch.nn as nn
14
+ from PIL import Image, ImageFile
15
+ from torch import tensor
16
+ from torchvision.transforms import InterpolationMode, transforms
17
+
18
+ # Let Pillow load truncated image files instead of raising an error
19
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
20
+
21
# --- DeepFaune model constants ---------------------------------------------

# Side length, in pixels, of the square input the classifier is fed.
CROP_SIZE = 182

# Model identifier handed to timm when the network is instantiated.
BACKBONE = "vit_large_patch14_dinov2.lvd142m"

# English DeepFaune class names.
# Source: https://plmlab.math.cnrs.fr/deepfaune/software/-/blob/master/classifTools.py
# The order is significant: position i in this list labels output i of the
# classifier, so entries must never be reordered or removed.
CLASS_NAMES_EN = [
    "bison", "badger", "ibex", "beaver",
    "red deer", "chamois", "cat", "goat",
    "roe deer", "dog", "fallow deer", "squirrel",
    "moose", "equid", "genet", "wolverine",
    "hedgehog", "lagomorph", "wolf", "otter",
    "lynx", "marmot", "micromammal", "mouflon",
    "sheep", "mustelid", "bird", "bear",
    "nutria", "raccoon", "fox", "reindeer",
    "wild boar", "cow",
]
34
+
35
+
36
class DeepFauneModel(nn.Module):
    """
    DeepFaune model wrapper around a timm backbone.

    Based on original DeepFaune classifTools.py Model class.
    License: CeCILL (see header)

    Attributes:
        model_path: Location of the ``.pt`` checkpoint read by ``load_weights``.
        backbone: Backbone name. Starts as ``BACKBONE`` and is replaced by the
            name stored in the checkpoint once the weights loaded successfully.
        nbclasses: Number of output classes (``len(CLASS_NAMES_EN)``).
        base_model: The timm network that performs the forward pass.
    """

    def __init__(self, model_path: Path):
        """
        Build the (untrained) network; weights are loaded by ``load_weights``.

        Args:
            model_path: Path to the checkpoint file to load later
        """
        super().__init__()
        self.model_path = model_path
        self.backbone = BACKBONE
        self.nbclasses = len(CLASS_NAMES_EN)

        # Create timm model with ViT-Large DINOv2 backbone.
        # pretrained=False: the weights come from the checkpoint, not from timm.
        self.base_model = timm.create_model(
            BACKBONE,
            pretrained=False,
            num_classes=self.nbclasses,
            dynamic_img_size=True
        )

    def forward(self, input):
        """Forward pass through the wrapped timm model (returns raw logits)."""
        return self.base_model(input)

    def predict(self, data: torch.Tensor, device: torch.device) -> np.ndarray:
        """
        Run prediction with softmax.

        The module is switched to eval mode and moved to ``device`` on every
        call, so callers do not have to prepare it beforehand.

        Args:
            data: Preprocessed image tensor with a leading batch dimension
            device: torch.device (cpu, cuda, or mps)

        Returns:
            Numpy array of softmax probabilities [num_classes] for the first
            item of the batch; any further batch items are discarded.
        """
        self.eval()
        self.to(device)

        with torch.no_grad():
            # Invoke the module itself rather than .forward() so that any
            # registered forward hooks are honoured.
            probabilities = self(data.to(device)).softmax(dim=1)

        return probabilities.cpu().numpy()[0]  # Return first (and only) batch item

    def load_weights(self, device: torch.device) -> None:
        """
        Load model weights from .pt file.

        Based on original DeepFaune classifTools.py loadWeights method.
        The checkpoint is expected to be a mapping with an ``'args'`` entry
        (holding ``'backbone'`` and ``'num_classes'``) and a ``'state_dict'``
        entry.

        Args:
            device: torch.device to load weights onto

        Raises:
            FileNotFoundError: If model file not found
            RuntimeError: If loading fails (unreadable file, missing keys,
                class-count mismatch, or incompatible state dict)
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            # NOTE(security): torch.load unpickles the file. Only load
            # checkpoints from a trusted source; whether arbitrary objects can
            # be unpickled depends on the installed torch version's
            # `weights_only` default.
            params = torch.load(self.model_path, map_location=device)
            args = params['args']

            # Validate number of classes matches
            if self.nbclasses != args['num_classes']:
                raise RuntimeError(
                    f"Model has {args['num_classes']} classes but expected {self.nbclasses}"
                )

            # Load the tensors first: if this fails, the object's metadata
            # must not claim a checkpoint that was never applied.
            self.load_state_dict(params['state_dict'])
            self.backbone = args['backbone']

        except Exception as e:
            raise RuntimeError(f"Failed to load DeepFaune model weights: {e}") from e
114
+
115
+
116
class ModelInference:
    """
    DeepFaune v1.3 inference implementation for AddaxAI-WebUI.

    Typical use: construct, call ``load_model()`` once, then for every
    detection call ``get_crop()`` followed by ``get_classification()``.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        No model is loaded here; ``self.model`` and ``self.device`` stay
        ``None`` until ``load_model()`` succeeds.

        Args:
            model_dir: Directory containing model files
            model_path: Path to deepfaune-vit_large_patch14_dinov2.lvd142m.v3.pt file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model: DeepFauneModel | None = None
        self.device: torch.device | None = None

        # DeepFaune preprocessing transforms
        # Based on classifTools.py Classifier.__init__
        self.transforms = transforms.Compose([
            transforms.Resize(
                size=(CROP_SIZE, CROP_SIZE),
                interpolation=InterpolationMode.BICUBIC,
                max_size=None,
                antialias=None
            ),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=tensor([0.4850, 0.4560, 0.4060]),
                std=tensor([0.2290, 0.2240, 0.2250])
            )
        ])

    @staticmethod
    def _select_device() -> torch.device:
        """Pick the preferred torch device: CUDA, then Apple MPS, then CPU."""
        if torch.cuda.is_available():
            return torch.device('cuda')

        # Older torch builds have no `backends.mps` module at all.
        mps = getattr(torch.backends, 'mps', None)
        if mps is not None and mps.is_built() and mps.is_available():
            return torch.device('mps')

        return torch.device('cpu')

    def check_gpu(self) -> bool:
        """
        Check GPU availability for DeepFaune (PyTorch).

        Uses the same device probe as ``load_model`` so both always agree.

        Returns:
            True if MPS (Apple Silicon) or CUDA available, False otherwise
        """
        return self._select_device().type != 'cpu'

    def load_model(self) -> None:
        """
        Load DeepFaune model into memory.

        This creates the ViT-Large DINOv2 model and loads the trained weights.
        Model is stored in self.model and reused for all subsequent classifications.
        ``self.model`` is only assigned once the weights have loaded, so a
        failed load never leaves an untrained network behind.

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If model_path is invalid
        """
        self.device = self._select_device()

        print(f"[DeepFaune] Loading model on device: {self.device}", file=sys.stderr, flush=True)

        # Build into a local first; publish on self only after success.
        model = DeepFauneModel(self.model_path)
        model.load_weights(self.device)
        self.model = model

        print(
            f"[DeepFaune] Model loaded: {BACKBONE} with {len(CLASS_NAMES_EN)} classes, "
            f"resolution {CROP_SIZE}x{CROP_SIZE}",
            file=sys.stderr, flush=True
        )

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using DeepFaune preprocessing.

        DeepFaune uses a squared crop approach:
        1. Denormalize bbox coordinates
        2. Square the crop (max of width/height)
        3. Center the detection within the square
        4. Clip to image boundaries

        Based on classify_detections.py get_crop function.

        Args:
            image: Full-resolution PIL Image
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image (RGB) ready for classification. Because of
            integer halving and boundary clipping the crop is not guaranteed
            to be exactly square.

        Raises:
            ValueError: If bbox has no area or lies entirely outside the image
        """
        width, height = image.size

        # Denormalize bbox coordinates
        xmin = int(round(bbox[0] * width))
        ymin = int(round(bbox[1] * height))
        xmax = int(round(bbox[2] * width)) + xmin
        ymax = int(round(bbox[3] * height)) + ymin

        xsize = xmax - xmin
        ysize = ymax - ymin

        if xsize <= 0 or ysize <= 0:
            raise ValueError(f"Invalid bbox size: {xsize}x{ysize}")

        # Square the crop by expanding the smaller dimension equally on both
        # sides. The truncating halving is kept as-is to stay identical to the
        # preprocessing this port is based on.
        if xsize > ysize:
            expand = int((xsize - ysize) / 2)
            ymin -= expand
            ymax += expand
        elif ysize > xsize:
            expand = int((ysize - xsize) / 2)
            xmin -= expand
            xmax += expand

        # Clip to image boundaries
        crop_box = (max(0, xmin), max(0, ymin), min(xmax, width), min(ymax, height))
        left, upper, right, lower = crop_box

        # A bbox positioned outside the image clips down to nothing; fail
        # here with a clear message instead of producing an empty crop.
        if left >= right or upper >= lower:
            raise ValueError(
                f"Invalid bbox: {bbox} lies outside image bounds {width}x{height}"
            )

        image_cropped = image.crop(crop_box)

        # Convert to RGB (DeepFaune requires RGB)
        if image_cropped.mode != 'RGB':
            image_cropped = image_cropped.convert('RGB')

        return image_cropped

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run DeepFaune classification on cropped image.

        Workflow:
        1. Preprocess crop with transforms (resize, normalize)
        2. Run model prediction with softmax
        3. Return all class probabilities (unsorted)

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, in
            CLASS_NAMES_EN order.
            Example: [["bison", 0.00001], ["badger", 0.00002], ["red deer", 0.99985], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None or self.device is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Preprocess image (resize + normalize) and add batch dimension
            tensor_cropped = self.transforms(crop).unsqueeze(dim=0)

            # Run prediction
            confs = self.model.predict(tensor_cropped, self.device)

            # Pair every class name with its probability.
            # NOTE: Sorting by confidence is handled by classification_worker.py
            return [
                [class_name, float(confs[i])]
                for i, class_name in enumerate(CLASS_NAMES_EN)
            ]

        except Exception as e:
            raise RuntimeError(f"DeepFaune classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        DeepFaune has 34 classes in a fixed order. We create a 1-indexed mapping
        for JSON compatibility.

        Returns:
            Dict mapping class ID (1-indexed string) to species name
            Example: {"1": "bison", "2": "badger", ..., "34": "cow"}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        # Build 1-indexed mapping
        return {
            str(class_id): class_name
            for class_id, class_name in enumerate(CLASS_NAMES_EN, start=1)
        }