|
|
""" |
|
|
Inference script for GIF-JAP-v0-2 (Gifu Wildlife Classifier - Central Japan) |
|
|
|
|
|
This model classifies 13 species found in the Kuraiyama Experimental Forest (KEF) of |
|
|
Gifu University. Trained on ~23,000 camera trap images to support efficient monitoring |
|
|
of key wildlife species in central Japan (sika deer, wild boar, Asian black bear, Japanese serow). |
|
|
|
|
|
Model: Gifu Wildlife v0.2 |
|
|
Input: 224x224 RGB images |
|
|
Framework: PyTorch (ResNet50 with ImageNet initialization) |
|
|
Classes: 13 Japanese species and taxonomic groups |
|
|
Developer: Gifu University (Masaki Ando) |
|
|
Citation: https://jglobal.jst.go.jp/en/detail?JGLOBAL_ID=201902236803626745 |
|
|
License: MIT |
|
|
Info: https://github.com/gifu-wildlife/TrainingMdetClassifire |
|
|
|
|
|
Note: Prototype model trained on limited and imbalanced data from KEF region. |
|
|
|
|
|
Author: Peter van Lunteren |
|
|
Created: 2026-01-14 |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import pathlib |
|
|
import platform |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
import pandas as pd |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image, ImageFile |
|
|
from torchvision import transforms |
|
|
from torchvision.models import resnet |
|
|
|
|
|
|
|
|
# Camera trap images are sometimes cut off mid-write (power loss, SD-card
# faults); let PIL decode what it can instead of raising OSError mid-batch.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
|
|
|
|
# Cross-platform checkpoint compatibility patch.
# NOTE(review): presumably the .pth checkpoint embeds pickled WindowsPath
# objects (saved on Windows); unpickling those on macOS/Linux raises
# NotImplementedError, so WindowsPath is aliased to PosixPath there.
# Confirm against the checkpoint's provenance.
plt = platform.system()
if plt != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
|
|
|
|
|
|
|
|
class CustomResNet50(nn.Module):
    """
    Custom ResNet50 model for Gifu Wildlife classification.

    Wraps torchvision's ResNet50, optionally seeds it with locally stored
    ImageNet weights, and swaps the final fully-connected layer for a
    ``num_classes``-way head.

    Based on original gifu-wildlife classifier architecture.
    """

    def __init__(self, num_classes: int, pretrained_path: Path | None = None, device_str: str = 'cpu'):
        """
        Initialize ResNet50 model.

        Args:
            num_classes: Number of output classes
            pretrained_path: Optional path to ImageNet pretrained weights
            device_str: Device to load model on ('cpu', 'cuda', 'mps')
        """
        super().__init__()  # zero-arg super (PEP 3135) instead of legacy two-arg form

        # Build the backbone without triggering a torchvision weight download;
        # ImageNet weights, if any, come from the local file below.
        self.model = resnet.resnet50(weights=None)

        if pretrained_path is not None and pretrained_path.exists():
            # NOTE(review): torch.load unpickles arbitrary objects — fine for
            # the weights file shipped in the model directory, but never point
            # this at untrusted input.
            state_dict = torch.load(pretrained_path, map_location=torch.device(device_str))
            self.model.load_state_dict(state_dict)

        # Replace the 1000-way ImageNet head AFTER loading the pretrained
        # state_dict, so the checkpoint's fc shape still matches during load.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through ResNet50; returns raw (pre-softmax) logits."""
        return self.model(x)
|
|
|
|
|
|
|
|
class ModelInference:
    """Gifu Wildlife ResNet50 inference implementation for AddaxAI-WebUI."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files
            model_path: Path to gifu-wildlife_cls_resnet50_v0.2.1.pth file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model: CustomResNet50 | None = None
        self.device: torch.device | None = None
        self.classes: pd.DataFrame | None = None

        # Plain resize + tensor conversion only — no ImageNet mean/std
        # normalization, matching the original gifu-wildlife pipeline.
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def check_gpu(self) -> bool:
        """
        Check GPU availability for Gifu Wildlife (PyTorch).

        Returns:
            True if MPS (Apple Silicon) or CUDA available, False otherwise
        """
        # MPS probing can raise on older torch builds; treat failure as "no MPS".
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass

        return torch.cuda.is_available()

    def load_model(self) -> None:
        """
        Load Gifu Wildlife ResNet50 model into memory.

        This creates the ResNet50 model and loads the trained weights.
        Model is stored in self.model and reused for all subsequent classifications.

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If model_path or classes.csv is invalid
        """
        # Device priority: CUDA > MPS > CPU.
        if torch.cuda.is_available():
            device_str = 'cuda'
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_built() and torch.backends.mps.is_available():
            device_str = 'mps'
        else:
            device_str = 'cpu'

        self.device = torch.device(device_str)

        print(f"[GifuWildlife] Loading model on device: {self.device}", file=sys.stderr, flush=True)

        # Class list drives both the head size and the ID->name mapping.
        classes_path = self.model_dir / 'classes.csv'
        if not classes_path.exists():
            raise FileNotFoundError(
                f"classes.csv not found: {classes_path}\n"
                f"Gifu Wildlife models require classes.csv in the model directory."
            )

        try:
            self.classes = pd.read_csv(classes_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load classes.csv: {e}") from e

        # Optional local ImageNet weights; harmless to skip since the trained
        # checkpoint below overwrites every parameter anyway.
        pretrained_weights_path = self.model_dir / 'resnet50-11ad3fa6.pth'

        self.model = CustomResNet50(
            num_classes=len(self.classes),
            pretrained_path=pretrained_weights_path if pretrained_weights_path.exists() else None,
            device_str=device_str
        )

        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            # Checkpoint is a dict with the trained weights under 'state_dict'.
            checkpoint = torch.load(self.model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['state_dict'])
            self.model.to(self.device)
            self.model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load Gifu Wildlife model: {e}") from e

        print(
            f"[GifuWildlife] Model loaded: ResNet50 with {len(self.classes)} classes, "
            f"resolution 224x224",
            file=sys.stderr, flush=True
        )

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float],
        buffer: int = 0,
    ) -> Image.Image:
        """
        Crop image using Gifu Wildlife preprocessing.

        Simple direct crop with no squaring:
        1. Denormalize bbox coordinates
        2. Expand by `buffer` pixels on each side, clipped to image boundaries
        3. Crop directly

        Based on classify_detections.py get_crop function.

        Args:
            image: Full-resolution PIL Image
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]
            buffer: Optional pixel margin added around the box (default 0,
                preserving the original no-padding behavior)

        Returns:
            Cropped PIL Image ready for classification

        Raises:
            ValueError: If bbox is invalid
        """
        width, height = image.size

        # Convert normalized (x, y, w, h) to absolute pixel corners.
        x_norm, y_norm, w_norm, h_norm = bbox
        left = width * x_norm
        top = height * y_norm
        right = width * (x_norm + w_norm)
        bottom = height * (y_norm + h_norm)

        # Apply buffer and clip to image bounds.
        left = max(0, int(left) - buffer)
        top = max(0, int(top) - buffer)
        right = min(width, int(right) + buffer)
        bottom = min(height, int(bottom) + buffer)

        # Degenerate boxes (zero/negative area after clipping) are caller errors.
        if right <= left or bottom <= top:
            raise ValueError(f"Invalid crop dimensions: ({left},{top}) to ({right},{bottom})")

        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run Gifu Wildlife classification on cropped image.

        Workflow:
        1. Preprocess crop (resize + to tensor)
        2. Run ResNet50 forward pass
        3. Apply softmax to get probabilities
        4. Return all class probabilities (unsorted)

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] pairs for ALL classes.
            Example: [["bear", 0.01], ["bird", 0.02], ["deer", 0.89], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None or self.device is None or self.classes is None:
            raise RuntimeError("Model not loaded - call load_model() first")

        try:
            # Batch of one: (1, 3, 224, 224) on the inference device.
            input_batch = self.preprocess(crop).unsqueeze(0).to(self.device)

            with torch.no_grad():
                output = self.model(input_batch)
                probabilities = F.softmax(output, dim=1)
                confidence_scores = probabilities.cpu().detach().numpy()[0]

            # Pair each probability with its class code, preserving
            # classes.csv row order (index i == model output index i).
            classifications = []
            for i, score in enumerate(confidence_scores):
                pred_class = self.classes.iloc[i]['Code']
                classifications.append([pred_class, float(score)])

            return classifications

        except Exception as e:
            raise RuntimeError(f"Gifu Wildlife classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to class names.

        Gifu Wildlife has 13 classes in order from classes.csv.
        We create a 1-indexed mapping for JSON compatibility.

        Returns:
            Dict mapping class ID (1-indexed string) to class name
            Example: {"1": "bear", "2": "bird", ..., "13": "squirrel"}

        Raises:
            RuntimeError: If classes not loaded
        """
        if self.classes is None:
            raise RuntimeError("Classes not loaded - call load_model() first")

        # 1-indexed string keys, values from the 'Code' column in row order.
        return {
            str(i + 1): self.classes.iloc[i]['Code']
            for i in range(len(self.classes))
        }
|
|
|