""" Inference script for GIF-JAP-v0-2 (Gifu Wildlife Classifier - Central Japan) This model classifies 13 species found in the Kuraiyama Experimental Forest (KEF) of Gifu University. Trained on ~23,000 camera trap images to support efficient monitoring of key wildlife species in central Japan (sika deer, wild boar, Asian black bear, Japanese serow). Model: Gifu Wildlife v0.2 Input: 224x224 RGB images Framework: PyTorch (ResNet50 with ImageNet initialization) Classes: 13 Japanese species and taxonomic groups Developer: Gifu University (Masaki Ando) Citation: https://jglobal.jst.go.jp/en/detail?JGLOBAL_ID=201902236803626745 License: MIT Info: https://github.com/gifu-wildlife/TrainingMdetClassifire Note: Prototype model trained on limited and imbalanced data from KEF region. Author: Peter van Lunteren Created: 2026-01-14 """ from __future__ import annotations import pathlib import platform import sys from pathlib import Path import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image, ImageFile from torchvision import transforms from torchvision.models import resnet # Don't freak out over truncated images ImageFile.LOAD_TRUNCATED_IMAGES = True # Make sure Windows-trained models work on Unix plt = platform.system() if plt != 'Windows': pathlib.WindowsPath = pathlib.PosixPath class CustomResNet50(nn.Module): """ Custom ResNet50 model for Gifu Wildlife classification. Based on original gifu-wildlife classifier architecture. """ def __init__(self, num_classes: int, pretrained_path: Path | None = None, device_str: str = 'cpu'): """ Initialize ResNet50 model. Args: num_classes: Number of output classes pretrained_path: Optional path to ImageNet pretrained weights device_str: Device to load model on ('cpu', 'cuda', 'mps') """ super(CustomResNet50, self).__init__() # Load ResNet50 without pretrained weights self.model = resnet.resnet50(weights=None) # If ImageNet pretrained weights provided, load them if pretrained_path is not None and pretrained_path.exists(): state_dict = torch.load(pretrained_path, map_location=torch.device(device_str)) self.model.load_state_dict(state_dict) # Replace final classification layer with custom number of classes self.model.fc = nn.Linear(self.model.fc.in_features, num_classes) def forward(self, x): """Forward pass through ResNet50.""" return self.model(x) class ModelInference: """Gifu Wildlife ResNet50 inference implementation for AddaxAI-WebUI.""" def __init__(self, model_dir: Path, model_path: Path): """ Initialize with model paths. Args: model_dir: Directory containing model files model_path: Path to gifu-wildlife_cls_resnet50_v0.2.1.pth file """ self.model_dir = model_dir self.model_path = model_path self.model: CustomResNet50 | None = None self.device: torch.device | None = None self.classes: pd.DataFrame | None = None # Gifu Wildlife preprocessing transforms # Simple resize to 224x224 + convert to tensor (no normalization) self.preprocess = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), ]) def check_gpu(self) -> bool: """ Check GPU availability for Gifu Wildlife (PyTorch). Returns: True if MPS (Apple Silicon) or CUDA available, False otherwise """ # Check Apple MPS (Apple Silicon) try: if torch.backends.mps.is_built() and torch.backends.mps.is_available(): return True except Exception: pass # Check CUDA (NVIDIA) return torch.cuda.is_available() def load_model(self) -> None: """ Load Gifu Wildlife ResNet50 model into memory. This creates the ResNet50 model and loads the trained weights. Model is stored in self.model and reused for all subsequent classifications. Raises: RuntimeError: If model loading fails FileNotFoundError: If model_path or classes.csv is invalid """ # Determine device if torch.cuda.is_available(): device_str = 'cuda' elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_built() and torch.backends.mps.is_available(): device_str = 'mps' else: device_str = 'cpu' self.device = torch.device(device_str) print(f"[GifuWildlife] Loading model on device: {self.device}", file=sys.stderr, flush=True) # Load classes.csv classes_path = self.model_dir / 'classes.csv' if not classes_path.exists(): raise FileNotFoundError( f"classes.csv not found: {classes_path}\n" f"Gifu Wildlife models require classes.csv in the model directory." ) try: self.classes = pd.read_csv(classes_path) except Exception as e: raise RuntimeError(f"Failed to load classes.csv: {e}") from e # Load ImageNet pretrained weights (optional) pretrained_weights_path = self.model_dir / 'resnet50-11ad3fa6.pth' # Create model self.model = CustomResNet50( num_classes=len(self.classes), pretrained_path=pretrained_weights_path if pretrained_weights_path.exists() else None, device_str=device_str ) # Load trained model checkpoint if not self.model_path.exists(): raise FileNotFoundError(f"Model file not found: {self.model_path}") try: checkpoint = torch.load(self.model_path, map_location=self.device) self.model.load_state_dict(checkpoint['state_dict']) self.model.to(self.device) self.model.eval() except Exception as e: raise RuntimeError(f"Failed to load Gifu Wildlife model: {e}") from e print( f"[GifuWildlife] Model loaded: ResNet50 with {len(self.classes)} classes, " f"resolution 224x224", file=sys.stderr, flush=True ) def get_crop( self, image: Image.Image, bbox: tuple[float, float, float, float] ) -> Image.Image: """ Crop image using Gifu Wildlife preprocessing. Simple direct crop with no padding or squaring: 1. Denormalize bbox coordinates 2. Clip to image boundaries 3. Crop directly Based on classify_detections.py get_crop function. Args: image: Full-resolution PIL Image bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0] Returns: Cropped PIL Image ready for classification Raises: ValueError: If bbox is invalid """ buffer = 0 # No buffer/padding width, height = image.size # Denormalize bbox coordinates bbox1, bbox2, bbox3, bbox4 = bbox left = width * bbox1 top = height * bbox2 right = width * (bbox1 + bbox3) bottom = height * (bbox2 + bbox4) # Apply buffer and clip to image boundaries left = max(0, int(left) - buffer) top = max(0, int(top) - buffer) right = min(width, int(right) + buffer) bottom = min(height, int(bottom) + buffer) # Validate crop dimensions if right <= left or bottom <= top: raise ValueError(f"Invalid crop dimensions: ({left},{top}) to ({right},{bottom})") # Crop image image_cropped = image.crop((left, top, right, bottom)) return image_cropped def get_classification(self, crop: Image.Image) -> list[list[str, float]]: """ Run Gifu Wildlife classification on cropped image. Workflow: 1. Preprocess crop (resize + to tensor) 2. Run ResNet50 forward pass 3. Apply softmax to get probabilities 4. Return all class probabilities (unsorted) Args: crop: Cropped PIL Image Returns: List of [class_name, confidence] lists for ALL classes. Example: [["bear", 0.01], ["bird", 0.02], ["deer", 0.89], ...] NOTE: Sorting by confidence is handled by classification_worker.py Raises: RuntimeError: If model not loaded or inference fails """ if self.model is None or self.device is None or self.classes is None: raise RuntimeError("Model not loaded - call load_model() first") try: # Preprocess image input_tensor = self.preprocess(crop) input_batch = input_tensor.unsqueeze(0) # Add batch dimension input_batch = input_batch.to(self.device) # Run inference with torch.no_grad(): output = self.model(input_batch) probabilities = F.softmax(output, dim=1) probabilities_np = probabilities.cpu().detach().numpy() confidence_scores = probabilities_np[0] # Build list of [class_name, confidence] pairs classifications = [] for i in range(len(confidence_scores)): # Get class name from classes.csv (column 'Code' - common names) pred_class = self.classes.iloc[i]['Code'] pred_conf = float(confidence_scores[i]) classifications.append([pred_class, pred_conf]) # NOTE: Sorting by confidence is handled by classification_worker.py return classifications except Exception as e: raise RuntimeError(f"Gifu Wildlife classification failed: {e}") from e def get_class_names(self) -> dict[str, str]: """ Get mapping of class IDs to class names. Gifu Wildlife has 13 classes in order from classes.csv. We create a 1-indexed mapping for JSON compatibility. Returns: Dict mapping class ID (1-indexed string) to class name Example: {"1": "bear", "2": "bird", ..., "13": "squirrel"} Raises: RuntimeError: If classes not loaded """ if self.classes is None: raise RuntimeError("Classes not loaded - call load_model() first") # Build 1-indexed mapping from classes.csv class_names = {} for i in range(len(self.classes)): class_id_str = str(i + 1) # 1-indexed # Use 'Code' column (common names like "bear", "deer", "boar") class_name = self.classes.iloc[i]['Code'] class_names[class_id_str] = class_name return class_names