"""
ModelInference for the addax-sppnet model family.

Architecture: SpeciesNet GraphModule backbone (frozen) + a thin nn.Linear
head fine-tuned per region. Originally written for AddaxAI's legacy
classify_detections.py (Peter van Lunteren, 13 May 2025); ported here to
the WebUI's class-based ModelInference interface.

Files expected in the model directory:
  - <model_fname>.pt   fine-tuned head checkpoint, e.g. final-20260317.pt
  - <backbone>.pt      frozen SpeciesNet backbone, one of:
      - always_crop_99710272_22x8_v12_epoch_00148.pt
      - full_image_88545560_22x8_v12_epoch_00153.pt
"""
|
|
| from __future__ import annotations |
|
|
| |
| import pathlib |
| import platform |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| from torchvision import transforms |
|
|
# On non-Windows hosts, alias pathlib.WindowsPath to pathlib.PosixPath.
# NOTE(review): presumably this lets checkpoints that were pickled on Windows
# (and therefore embed WindowsPath objects) be unpickled here -- confirm.
if platform.system() != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath
|
|
| |
| from PIL import ImageFile |
|
|
# Process-wide PIL setting: decode truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
# Backbone checkpoint filenames looked up in the model directory, in this
# order; the first one that exists on disk is the one that gets loaded.
_BACKBONE_FILENAMES = (
    "always_crop_99710272_22x8_v12_epoch_00148.pt",
    "full_image_88545560_22x8_v12_epoch_00153.pt",
)
|
|
|
|
def _load_fx_checkpoint(weights_path: Path, map_location: str = "cpu") -> nn.Module:
    """Load a SpeciesNet onnx2torch GraphModule from ``weights_path``.

    The backbone is shipped as a torch.fx GraphModule. PyTorch 2.4+
    requires `reduce_graph_module` to be in the safe-globals allowlist
    when loading with `weights_only=True`; older versions don't have
    this concept. Try both paths.

    Args:
        weights_path: Location of the backbone ``.pt`` file.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        The deserialized object, which exposes ``state_dict`` and ``forward``.

    Raises:
        FileNotFoundError: If ``weights_path`` does not exist.
        ValueError: If the file deserializes to something without
            ``state_dict``/``forward`` (e.g. a bare dict).
    """
    try:
        from torch.fx.graph_module import reduce_graph_module
        from torch.serialization import add_safe_globals
    except ImportError:
        # Older PyTorch: no allowlist API, so there is nothing to register.
        pass
    else:
        add_safe_globals([reduce_graph_module])

    try:
        obj = torch.load(weights_path, map_location=map_location, weights_only=True)
    except FileNotFoundError:
        # Retrying with the unsafe loader cannot help when the file is missing.
        raise
    except Exception:
        # SECURITY: weights_only=False runs the full pickle machinery and can
        # execute arbitrary code embedded in the file. It is kept because a
        # pickled GraphModule may not pass the restricted loader, but it must
        # only ever be pointed at backbone files from a trusted source.
        obj = torch.load(weights_path, map_location=map_location, weights_only=False)

    if hasattr(obj, "state_dict") and hasattr(obj, "forward"):
        return obj
    raise ValueError(f"{weights_path} is not a torch.nn.Module GraphModule")
|
|
|
|
class _FXClassifier(nn.Module):
    """Frozen SpeciesNet feature extractor followed by a linear classification head."""

    def __init__(
        self,
        backbone: nn.Module,
        num_classes: int,
        img_size: int = 480,
        input_layout: str = "nhwc",
    ) -> None:
        """Freeze ``backbone`` and size the linear head from a probe forward pass.

        Args:
            backbone: Feature extractor; its parameters get ``requires_grad``
                switched off and it is put in eval mode.
            num_classes: Output width of the linear head.
            img_size: Side length of the square all-zeros probe image.
            input_layout: Lower-cased and stored; ``"nhwc"`` means inputs are
                permuted from NCHW to NHWC before reaching the backbone.
        """
        super().__init__()
        self.backbone = backbone
        self.input_layout = input_layout.lower()

        for param in self.backbone.parameters():
            param.requires_grad_(False)
        self.backbone.eval()

        # Push one dummy NCHW image through the backbone to discover how many
        # features the pooled output has; that number is the head's input width.
        probe = torch.zeros(1, 3, img_size, img_size)
        with torch.no_grad():
            feature_dim = self._embed(probe).shape[1]

        self.head = nn.Linear(feature_dim, num_classes)

    @staticmethod
    def _pool(z: torch.Tensor) -> torch.Tensor:
        """Reduce a backbone output to a 2-D tensor, dispatching on its rank."""
        rank = z.ndim
        if rank == 3:
            return z.mean(dim=1)
        if rank != 4:
            return z.flatten(1)
        pooled = F.adaptive_avg_pool2d(z, 1)
        return pooled.flatten(1)

    def _embed(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layout permutation (if any), the backbone and the pooling."""
        if self.input_layout == "nhwc":
            x = x.permute(0, 2, 3, 1).contiguous()
        return self._pool(self.backbone(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's logits for a batch of NCHW images."""
        return self.head(self._embed(x))
|
|
|
|
class ModelInference:
    """ModelInference for the addax-sppnet family (SpeciesNet backbone + linear head).

    ``load_model()`` must be called before ``get_classification``,
    ``get_tensor`` or ``classify_batch``; those raise ``RuntimeError`` otherwise.
    """

    def __init__(self, model_dir: Path, model_path: Path) -> None:
        """Record where the weights live; nothing is loaded yet.

        Args:
            model_dir: Directory searched for the backbone checkpoint.
            model_path: Path of the fine-tuned head checkpoint.
        """
        self.model_dir = Path(model_dir)
        self.model_path = Path(model_path)
        self.model: _FXClassifier | None = None
        self.device: torch.device | None = None
        self._class_names: list[str] = []
        self._preprocess: transforms.Compose | None = None

    # ------------------------------------------------------------------
    # Device selection
    # ------------------------------------------------------------------

    @staticmethod
    def _mps_available() -> bool:
        """Return True when the MPS backend is both built and available."""
        mps = getattr(torch.backends, "mps", None)
        if mps is None:
            # This torch build has no MPS backend module at all.
            return False
        try:
            return bool(mps.is_built() and mps.is_available())
        except Exception:
            # Best-effort hardware probe: any failure means "don't use MPS".
            return False

    def check_gpu(self) -> bool:
        """Return True when either MPS or CUDA can be used."""
        return self._mps_available() or torch.cuda.is_available()

    def _select_device(self) -> torch.device:
        """Pick MPS, then CUDA, then CPU, honouring ``check_gpu``."""
        if not self.check_gpu():
            return torch.device("cpu")
        # Reuse the guarded probe so this never touches torch.backends.mps
        # directly and always agrees with check_gpu().
        return torch.device("mps" if self._mps_available() else "cuda")

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def _load_head_checkpoint(self) -> dict:
        """Read the head checkpoint, preferring the restricted loader."""
        try:
            return torch.load(
                self.model_path, map_location=self.device, weights_only=True
            )
        except FileNotFoundError:
            # The unsafe loader cannot help with a missing file.
            raise
        except Exception:
            # SECURITY: weights_only=False unpickles arbitrary objects and can
            # execute code from the file. Only use trusted checkpoints.
            return torch.load(
                self.model_path, map_location=self.device, weights_only=False
            )

    def _find_backbone(self) -> Path:
        """Return the first known backbone file present in ``model_dir``."""
        for name in _BACKBONE_FILENAMES:
            candidate = self.model_dir / name
            if candidate.exists():
                return candidate
        raise FileNotFoundError(
            "Backbone weights not found. Expected one of "
            f"{_BACKBONE_FILENAMES} in {self.model_dir}."
        )

    def load_model(self) -> None:
        """Load head + backbone, build the classifier and the preprocessing.

        Sets ``self.device``, ``self.model``, ``self._class_names`` and
        ``self._preprocess``. The head checkpoint is read as a mapping with
        the keys ``num_classes``, ``img_size``, ``input_layout``, ``model``,
        ``class_names`` and ``normalize`` (itself holding ``mean``/``std``).

        Raises:
            FileNotFoundError: If the head checkpoint or every candidate
                backbone file is missing.
            KeyError: If the checkpoint lacks one of the keys listed above.
        """
        self.device = self._select_device()
        checkpoint = self._load_head_checkpoint()
        backbone_path = self._find_backbone()

        # The classifier is assembled on CPU and moved to the target device
        # only once the fine-tuned weights are in place.
        backbone = _load_fx_checkpoint(backbone_path, map_location="cpu")
        img_size = checkpoint["img_size"]
        model = _FXClassifier(
            backbone=backbone,
            num_classes=checkpoint["num_classes"],
            img_size=img_size,
            input_layout=checkpoint["input_layout"],
        )
        model.load_state_dict(checkpoint["model"])
        self.model = model.to(self.device).eval()

        self._class_names = list(checkpoint["class_names"])

        norm = checkpoint["normalize"]
        self._preprocess = transforms.Compose([
            transforms.Resize((img_size, img_size), antialias=True),
            transforms.ToTensor(),
            transforms.Normalize(mean=norm["mean"], std=norm["std"]),
        ])

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _require_model(self) -> _FXClassifier:
        """Return the loaded model or raise if ``load_model`` was not called."""
        if self.model is None:
            raise RuntimeError("Model is not loaded; call load_model() first.")
        return self.model

    def _to_tensor(self, crop: Image.Image) -> torch.Tensor:
        """Convert ``crop`` to RGB if needed and run the preprocessing on it."""
        if self._preprocess is None:
            raise RuntimeError(
                "Preprocessing is not initialised; call load_model() first."
            )
        if crop.mode != "RGB":
            crop = crop.convert("RGB")
        return self._preprocess(crop)

    def _predict(self, tensor: torch.Tensor) -> list[list[list]]:
        """Run the model on a batch and pair each probability with its class name."""
        model = self._require_model()
        with torch.no_grad():
            logits = model(tensor.to(self.device))
            probs = F.softmax(logits, dim=1).cpu().numpy()
        return [
            [[self._class_names[idx], float(p)] for idx, p in enumerate(row)]
            for row in probs
        ]

    # ------------------------------------------------------------------
    # Public inference API
    # ------------------------------------------------------------------

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """Crop the bbox region. SpeciesNet head was trained on tight crops.

        Args:
            image: Source image.
            bbox: ``(x, y, w, h)`` expressed as fractions of the image width
                and height.

        Returns:
            The cropped region, clamped to the image bounds. When the clamped
            box has no area the original ``image`` is returned unchanged.
        """
        W, H = image.size
        x, y, w, h = bbox
        left = max(0, int(round(x * W)))
        top = max(0, int(round(y * H)))
        right = min(W, int(round((x + w) * W)))
        bottom = min(H, int(round((y + h) * H)))
        if right <= left or bottom <= top:
            return image
        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list]:
        """Per-image inference. Returns [[name, prob], ...] for all classes.

        Raises:
            RuntimeError: If ``load_model`` has not been called.
        """
        self._require_model()
        batch = self._to_tensor(crop).unsqueeze(0)
        return self._predict(batch)[0]

    def get_class_names(self) -> dict[str, str]:
        """1-indexed mapping {id: class_name} for the output JSON."""
        return {str(i): name for i, name in enumerate(self._class_names, start=1)}

    def get_tensor(self, crop: Image.Image) -> np.ndarray:
        """Preprocess one crop and return it as a NumPy array.

        Raises:
            RuntimeError: If ``load_model`` has not been called.
        """
        return self._to_tensor(crop).numpy()

    def classify_batch(self, batch: np.ndarray) -> list[list[list]]:
        """Classify a batch of preprocessed crops.

        Args:
            batch: Array whose first axis indexes the crops, e.g. a stack of
                ``get_tensor`` results.

        Returns:
            One ``[[name, prob], ...]`` list per crop; ``[]`` for an empty batch.

        Raises:
            RuntimeError: If ``load_model`` has not been called.
        """
        self._require_model()
        if batch.ndim >= 1 and batch.shape[0] == 0:
            # Nothing to classify; skip the device round trip.
            return []
        return self._predict(torch.from_numpy(batch))
|
|