""" ModelInference for the addax-sppnet model family. Architecture: SpeciesNet GraphModule backbone (frozen) + a thin nn.Linear head fine-tuned per region. Originally written for AddaxAI's legacy classify_detections.py (Peter van Lunteren, 13 May 2025); ported here to the WebUI's class-based ModelInference interface. Files expected in the model directory: - .pt fine-tuned head checkpoint, e.g. final-20260317.pt - .pt frozen SpeciesNet backbone, one of: - always_crop_99710272_22x8_v12_epoch_00148.pt - full_image_88545560_22x8_v12_epoch_00153.pt """ from __future__ import annotations # Allow loading checkpoints saved on a Windows runner on a POSIX machine. import pathlib import platform from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from torchvision import transforms if platform.system() != "Windows": pathlib.WindowsPath = pathlib.PosixPath # type: ignore[assignment] # Don't fail on truncated images during inference. from PIL import ImageFile ImageFile.LOAD_TRUNCATED_IMAGES = True _BACKBONE_FILENAMES = ( "always_crop_99710272_22x8_v12_epoch_00148.pt", "full_image_88545560_22x8_v12_epoch_00153.pt", ) def _load_fx_checkpoint(weights_path: Path, map_location: str = "cpu") -> nn.Module: """Load a SpeciesNet onnx2torch GraphModule. The backbone is shipped as a torch.fx GraphModule. PyTorch 2.4+ requires `reduce_graph_module` to be in the safe-globals allowlist when loading with `weights_only=True`; older versions don't have this concept. Try both paths. """ try: from torch.fx.graph_module import reduce_graph_module from torch.serialization import add_safe_globals add_safe_globals([reduce_graph_module]) except Exception: pass try: obj = torch.load(weights_path, map_location=map_location, weights_only=True) except Exception: obj = torch.load(weights_path, map_location=map_location, weights_only=False) if hasattr(obj, "state_dict") and hasattr(obj, "forward"): return obj raise ValueError(f"{weights_path} is not a torch.nn.Module GraphModule") class _FXClassifier(nn.Module): """SpeciesNet backbone (frozen) + linear head.""" def __init__( self, backbone: nn.Module, num_classes: int, img_size: int = 480, input_layout: str = "nhwc", ) -> None: super().__init__() self.backbone = backbone self.input_layout = input_layout.lower() for p in self.backbone.parameters(): p.requires_grad = False self.backbone.eval() # Probe the backbone to discover output feature size at this # img_size + layout combo, so the head matches exactly. with torch.no_grad(): x = torch.zeros(1, 3, img_size, img_size) if self.input_layout == "nhwc": x = x.permute(0, 2, 3, 1).contiguous() z = self.backbone(x) z = self._pool(z) in_features = z.shape[1] self.head = nn.Linear(in_features, num_classes) @staticmethod def _pool(z: torch.Tensor) -> torch.Tensor: if z.ndim == 4: return F.adaptive_avg_pool2d(z, 1).flatten(1) if z.ndim == 3: return z.mean(dim=1) return z.flatten(1) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.input_layout == "nhwc": x = x.permute(0, 2, 3, 1).contiguous() z = self.backbone(x) z = self._pool(z) return self.head(z) class ModelInference: """ModelInference for the addax-sppnet family (SpeciesNet backbone + linear head).""" def __init__(self, model_dir: Path, model_path: Path) -> None: self.model_dir = Path(model_dir) self.model_path = Path(model_path) self.model: _FXClassifier | None = None self.device: torch.device | None = None self._class_names: list[str] = [] self._preprocess: transforms.Compose | None = None # ------------------------------------------------------------------ # Required interface # ------------------------------------------------------------------ def check_gpu(self) -> bool: try: if torch.backends.mps.is_built() and torch.backends.mps.is_available(): return True except Exception: pass return torch.cuda.is_available() def load_model(self) -> None: if self.check_gpu(): self.device = torch.device( "mps" if torch.backends.mps.is_available() else "cuda" ) else: self.device = torch.device("cpu") # Load fine-tuned head checkpoint. try: checkpoint = torch.load( self.model_path, map_location=self.device, weights_only=True ) except Exception: checkpoint = torch.load( self.model_path, map_location=self.device, weights_only=False ) # Resolve backbone path. The fine-tuned model ships alongside # one of two known backbone files, depending on the recipe. backbone_path: Path | None = None for name in _BACKBONE_FILENAMES: candidate = self.model_dir / name if candidate.exists(): backbone_path = candidate break if backbone_path is None: raise FileNotFoundError( "Backbone weights not found. Expected one of " f"{_BACKBONE_FILENAMES} in {self.model_dir}." ) backbone = _load_fx_checkpoint(backbone_path, map_location="cpu") model = _FXClassifier( backbone=backbone, num_classes=checkpoint["num_classes"], img_size=checkpoint["img_size"], input_layout=checkpoint["input_layout"], ) model.load_state_dict(checkpoint["model"]) self.model = model.to(self.device).eval() self._class_names = list(checkpoint["class_names"]) norm = checkpoint["normalize"] img_size = checkpoint["img_size"] self._preprocess = transforms.Compose([ transforms.Resize((img_size, img_size), antialias=True), transforms.ToTensor(), transforms.Normalize(mean=norm["mean"], std=norm["std"]), ]) def get_crop( self, image: Image.Image, bbox: tuple[float, float, float, float] ) -> Image.Image: """Crop the bbox region. SpeciesNet head was trained on tight crops.""" W, H = image.size x, y, w, h = bbox left = max(0, int(round(x * W))) top = max(0, int(round(y * H))) right = min(W, int(round((x + w) * W))) bottom = min(H, int(round((y + h) * H))) if right <= left or bottom <= top: return image return image.crop((left, top, right, bottom)) def get_classification(self, crop: Image.Image) -> list[list]: """Per-image inference. Returns [[name, prob], ...] for all classes.""" assert self.model is not None and self._preprocess is not None if crop.mode != "RGB": crop = crop.convert("RGB") tensor = self._preprocess(crop).unsqueeze(0).to(self.device) with torch.no_grad(): probs = F.softmax(self.model(tensor), dim=1).cpu().numpy()[0] return [[self._class_names[i], float(probs[i])] for i in range(len(probs))] def get_class_names(self) -> dict[str, str]: """1-indexed mapping {id: class_name} for the output JSON.""" return {str(i + 1): name for i, name in enumerate(self._class_names)} # ------------------------------------------------------------------ # Optional batch interface (5-15x GPU speedup vs per-crop calls) # ------------------------------------------------------------------ def get_tensor(self, crop: Image.Image) -> np.ndarray: assert self._preprocess is not None if crop.mode != "RGB": crop = crop.convert("RGB") return self._preprocess(crop).numpy() def classify_batch(self, batch: np.ndarray) -> list[list[list]]: assert self.model is not None tensor = torch.from_numpy(batch).to(self.device) with torch.no_grad(): probs = F.softmax(self.model(tensor), dim=1).cpu().numpy() return [ [[self._class_names[j], float(p[j])] for j in range(len(p))] for p in probs ]