Addax-Data-Science committed on
Commit
2ba3552
·
verified ·
1 Parent(s): a92ebb3

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +292 -0
inference.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Peter van Lunteren, January 2026
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import pathlib
8
+ import platform
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ import pandas as pd
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from PIL import Image, ImageFile
17
+ from torchvision import transforms
18
+ from torchvision.models import resnet
19
+
20
# Let Pillow load truncated image files instead of raising on them
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Make sure Windows-trained models work on Unix: on any non-Windows system,
# alias pathlib.WindowsPath to pathlib.PosixPath.
# NOTE(review): presumably needed because checkpoints saved on Windows contain
# pickled WindowsPath objects - confirm against the actual checkpoint files.
plt = platform.system()
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
27
+
28
+
29
class CustomResNet50(nn.Module):
    """
    Custom ResNet50 model for Gifu Wildlife classification.

    Wraps a torchvision ResNet50 whose final fully connected layer is replaced
    by a linear layer with ``num_classes`` outputs. The wrapped network is
    stored in ``self.model``, so keys in this module's state dict carry a
    ``model.`` prefix.
    """

    def __init__(self, num_classes: int, pretrained_path: Path | None = None, device_str: str = 'cpu'):
        """
        Initialize ResNet50 model.

        Args:
            num_classes: Number of output classes
            pretrained_path: Optional path to a state dict for the unmodified
                ResNet50 backbone; ignored when None or when the file does not exist
            device_str: Device string used as map_location when reading
                pretrained_path ('cpu', 'cuda', 'mps')
        """
        super().__init__()

        # Build the backbone without downloading or initialising pretrained weights
        self.model = resnet.resnet50(weights=None)

        # Seed the backbone from a local weights file when one is available.
        # This must happen before the head is replaced, because the file is
        # loaded into the stock ResNet50 (including its original fc layer).
        if pretrained_path is not None and pretrained_path.exists():
            # NOTE(review): torch.load unpickles the file - only point this at
            # trusted weight files.
            state_dict = torch.load(pretrained_path, map_location=torch.device(device_str))
            self.model.load_state_dict(state_dict)

        # Replace final classification layer with custom number of classes
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x):
        """Forward pass: delegate to the wrapped ResNet50 and return its output."""
        return self.model(x)
61
+
62
+
63
class ModelInference:
    """Gifu Wildlife ResNet50 inference implementation for AddaxAI-WebUI."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Nothing is loaded here; call load_model() before classifying.

        Args:
            model_dir: Directory containing model files (classes.csv and,
                optionally, resnet50-11ad3fa6.pth)
            model_path: Path to the trained checkpoint file
                (e.g. gifu-wildlife_cls_resnet50_v0.2.1.pth)
        """
        self.model_dir = model_dir
        self.model_path = model_path
        # All three stay None until load_model() has completed successfully
        self.model: CustomResNet50 | None = None
        self.device: torch.device | None = None
        self.classes: pd.DataFrame | None = None

        # Gifu Wildlife preprocessing transforms
        # Simple resize to 224x224 + convert to tensor (no normalization)
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _mps_available() -> bool:
        """Return True when the Apple MPS backend is built and available."""
        mps = getattr(torch.backends, 'mps', None)
        if mps is None:
            return False
        try:
            return bool(mps.is_built() and mps.is_available())
        except Exception:
            # Deliberate best-effort probe: any failure while querying the
            # backend is treated as "MPS not available".
            return False

    def check_gpu(self) -> bool:
        """
        Check GPU availability for Gifu Wildlife (PyTorch).

        Returns:
            True if MPS (Apple Silicon) or CUDA available, False otherwise
        """
        return self._mps_available() or torch.cuda.is_available()

    def load_model(self) -> None:
        """
        Load Gifu Wildlife ResNet50 model into memory.

        Reads classes.csv, builds a CustomResNet50 with one output per row,
        loads the trained weights from the checkpoint's 'state_dict' entry and
        puts the model in eval mode on the selected device (CUDA, then MPS,
        then CPU). self.model, self.device and self.classes are assigned only
        after every step has succeeded, so a failed load never leaves a
        partially initialised (untrained) model behind.

        Raises:
            RuntimeError: If classes.csv cannot be parsed or has no 'Code'
                column, or if loading the checkpoint fails
            FileNotFoundError: If model_path or classes.csv does not exist
        """
        # Determine device
        if torch.cuda.is_available():
            device_str = 'cuda'
        elif self._mps_available():
            device_str = 'mps'
        else:
            device_str = 'cpu'
        device = torch.device(device_str)

        print(f"[GifuWildlife] Loading model on device: {device}", file=sys.stderr, flush=True)

        # Load classes.csv
        classes_path = self.model_dir / 'classes.csv'
        if not classes_path.exists():
            raise FileNotFoundError(
                f"classes.csv not found: {classes_path}\n"
                "Gifu Wildlife models require classes.csv in the model directory."
            )

        try:
            classes = pd.read_csv(classes_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load classes.csv: {e}") from e

        # Class names are read from the 'Code' column later on; fail here
        # rather than on the first classification.
        if 'Code' not in classes.columns:
            raise RuntimeError(f"classes.csv has no 'Code' column: {classes_path}")

        # Check the checkpoint before doing the expensive model construction
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        # Load ImageNet pretrained weights (optional)
        pretrained_weights_path = self.model_dir / 'resnet50-11ad3fa6.pth'

        # Create model
        model = CustomResNet50(
            num_classes=len(classes),
            pretrained_path=pretrained_weights_path if pretrained_weights_path.exists() else None,
            device_str=device_str,
        )

        # Load trained model checkpoint
        try:
            checkpoint = torch.load(self.model_path, map_location=device)
            model.load_state_dict(checkpoint['state_dict'])
            model.to(device)
            model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load Gifu Wildlife model: {e}") from e

        # Commit state only now that everything loaded successfully
        self.device = device
        self.classes = classes
        self.model = model

        print(
            f"[GifuWildlife] Model loaded: ResNet50 with {len(self.classes)} classes, "
            f"resolution 224x224",
            file=sys.stderr, flush=True
        )

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using Gifu Wildlife preprocessing.

        Simple direct crop with no padding or squaring:
        1. Denormalize bbox coordinates
        2. Truncate to whole pixels and clip to image boundaries
        3. Crop directly

        Based on classify_detections.py get_crop function.

        Args:
            image: Full-resolution PIL Image
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image ready for classification

        Raises:
            ValueError: If the clipped box has no area (right <= left or
                bottom <= top), or bbox cannot be unpacked into four values
        """
        width, height = image.size

        # Denormalize bbox coordinates
        x, y, box_w, box_h = bbox
        left = width * x
        top = height * y
        right = width * (x + box_w)
        bottom = height * (y + box_h)

        # Truncate to integer pixels and clip to image boundaries
        left = max(0, int(left))
        top = max(0, int(top))
        right = min(width, int(right))
        bottom = min(height, int(bottom))

        # Validate crop dimensions
        if right <= left or bottom <= top:
            raise ValueError(f"Invalid crop dimensions: ({left},{top}) to ({right},{bottom})")

        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run Gifu Wildlife classification on cropped image.

        Workflow:
        1. Convert crop to RGB if needed (the network takes 3-channel input)
        2. Preprocess crop (resize + to tensor)
        3. Run ResNet50 forward pass
        4. Apply softmax to get probabilities
        5. Return all class probabilities (unsorted)

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] lists for ALL classes, in
            classes.csv row order.
            Example: [["bear", 0.01], ["bird", 0.02], ["deer", 0.89], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None or self.device is None or self.classes is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Grayscale / RGBA / palette crops would give the wrong channel count
            if crop.mode != 'RGB':
                crop = crop.convert('RGB')

            # Preprocess image and add batch dimension
            input_batch = self.preprocess(crop).unsqueeze(0).to(self.device)

            # Run inference
            with torch.no_grad():
                output = self.model(input_batch)
                probabilities = F.softmax(output, dim=1)
            confidence_scores = probabilities[0].cpu().tolist()

            # Class names come from classes.csv (column 'Code' - common names)
            class_names = self.classes['Code'].tolist()
            if len(class_names) != len(confidence_scores):
                raise ValueError(
                    f"classes.csv has {len(class_names)} classes but model "
                    f"returned {len(confidence_scores)} scores"
                )

            # NOTE: Sorting by confidence is handled by classification_worker.py
            return [[name, score] for name, score in zip(class_names, confidence_scores)]

        except Exception as e:
            raise RuntimeError(f"Gifu Wildlife classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to class names.

        One entry is created per row of classes.csv, in file order, using a
        1-indexed string ID for JSON compatibility.

        Returns:
            Dict mapping class ID (1-indexed string) to class name
            Example: {"1": "bear", "2": "bird", ..., "13": "squirrel"}

        Raises:
            RuntimeError: If classes not loaded
        """
        if self.classes is None:
            raise RuntimeError("Classes not loaded - call load_model() first")

        # Use 'Code' column (common names like "bear", "deer", "boar")
        return {
            str(class_id): class_name
            for class_id, class_name in enumerate(self.classes['Code'].tolist(), start=1)
        }