GIF-JAP-v0-2 / inference.py
Addax-Data-Science's picture
Upload inference.py
3bf0df7 verified
"""
Inference script for GIF-JAP-v0-2 (Gifu Wildlife Classifier - Central Japan)
This model classifies 13 species found in the Kuraiyama Experimental Forest (KEF) of
Gifu University. Trained on ~23,000 camera trap images to support efficient monitoring
of key wildlife species in central Japan (sika deer, wild boar, Asian black bear, Japanese serow).
Model: Gifu Wildlife v0.2
Input: 224x224 RGB images
Framework: PyTorch (ResNet50 with ImageNet initialization)
Classes: 13 Japanese species and taxonomic groups
Developer: Gifu University (Masaki Ando)
Citation: https://jglobal.jst.go.jp/en/detail?JGLOBAL_ID=201902236803626745
License: MIT
Info: https://github.com/gifu-wildlife/TrainingMdetClassifire
Note: Prototype model trained on limited and imbalanced data from KEF region.
Author: Peter van Lunteren
Created: 2026-01-14
"""
from __future__ import annotations
import pathlib
import platform
import sys
from pathlib import Path
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image, ImageFile
from torchvision import transforms
from torchvision.models import resnet
# Let PIL decode truncated / partially written image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Checkpoints pickled on Windows may reference pathlib.WindowsPath, which cannot
# be instantiated on other operating systems. Aliasing it to PosixPath lets such
# files unpickle on Linux/macOS ("Windows-trained models work on Unix").
# NOTE: this rebinds pathlib.WindowsPath for the whole process, not just here.
plt = platform.system()
if plt != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath
class CustomResNet50(nn.Module):
    """
    ResNet50 classifier matching the original gifu-wildlife architecture.

    A torchvision ResNet50 is kept in ``self.model`` (so this module's
    state-dict keys carry a ``model.`` prefix) and its final fully connected
    layer is replaced by one producing ``num_classes`` outputs.
    """

    def __init__(self, num_classes: int, pretrained_path: Path | None = None, device_str: str = 'cpu'):
        """
        Build the network.

        Args:
            num_classes: Number of outputs of the replacement classification layer
            pretrained_path: Optional path to an ImageNet ResNet50 state dict. It
                is only read when the file exists, and is applied before the
                classification layer is swapped out.
            device_str: Device used as ``map_location`` when reading
                ``pretrained_path`` ('cpu', 'cuda', 'mps'). The module itself is
                not moved to this device here.
        """
        super().__init__()
        # Architecture only - no torchvision weight download.
        backbone = resnet.resnet50(weights=None)

        use_imagenet = pretrained_path is not None and pretrained_path.exists()
        if use_imagenet:
            imagenet_state = torch.load(pretrained_path, map_location=torch.device(device_str))
            backbone.load_state_dict(imagenet_state)

        # Swap the 1000-way ImageNet head for a num_classes-way head.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the raw class scores (no softmax) of the wrapped ResNet50 for ``x``."""
        return self.model(x)
class ModelInference:
    """
    Gifu Wildlife ResNet50 inference implementation for AddaxAI-WebUI.

    Construct with the model paths, call load_model() once, then use
    get_crop() / get_classification() per detection. get_class_names()
    exposes the 1-indexed class mapping read from classes.csv.
    """

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths. No files are read until load_model().

        Args:
            model_dir: Directory containing model files (must contain classes.csv)
            model_path: Path to gifu-wildlife_cls_resnet50_v0.2.1.pth file
        """
        # Path() also accepts plain strings; it is a no-op for Path objects.
        self.model_dir = Path(model_dir)
        self.model_path = Path(model_path)
        self.model: CustomResNet50 | None = None
        self.device: torch.device | None = None
        self.classes: pd.DataFrame | None = None
        # Gifu Wildlife preprocessing transforms
        # Simple resize to 224x224 + convert to tensor (no normalization)
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _mps_available() -> bool:
        """Return True if this torch build has a usable Apple MPS backend."""
        # Best-effort probe: torch builds without MPS support may not expose
        # torch.backends.mps at all, so any failure means "not available".
        try:
            return bool(torch.backends.mps.is_built() and torch.backends.mps.is_available())
        except Exception:
            return False

    @classmethod
    def _select_device_str(cls) -> str:
        """Return the preferred device name: 'cuda', else 'mps', else 'cpu'."""
        if torch.cuda.is_available():
            return 'cuda'
        if cls._mps_available():
            return 'mps'
        return 'cpu'

    @staticmethod
    def _read_classes(classes_path: Path) -> pd.DataFrame:
        """
        Read and validate classes.csv.

        Raises:
            FileNotFoundError: If the file does not exist
            RuntimeError: If it cannot be parsed, has no 'Code' column, or has no rows
        """
        if not classes_path.exists():
            raise FileNotFoundError(
                f"classes.csv not found: {classes_path}\n"
                f"Gifu Wildlife models require classes.csv in the model directory."
            )
        try:
            classes = pd.read_csv(classes_path)
        except Exception as e:
            raise RuntimeError(f"Failed to load classes.csv: {e}") from e
        # Validate here so a bad file fails at load time, not on the first crop.
        if 'Code' not in classes.columns:
            raise RuntimeError(
                f"classes.csv has no 'Code' column: {classes_path} "
                f"(columns found: {list(classes.columns)})"
            )
        if classes.empty:
            raise RuntimeError(f"classes.csv contains no classes: {classes_path}")
        return classes

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the state dict from a checkpoint that is either wrapped ({'state_dict': ...}) or bare."""
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            return checkpoint['state_dict']
        return checkpoint

    def _class_codes(self) -> list:
        """Return the 'Code' column of classes.csv as a list; index i names model output i."""
        return self.classes['Code'].tolist()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def check_gpu(self) -> bool:
        """
        Check GPU availability for Gifu Wildlife (PyTorch).

        Returns:
            True if MPS (Apple Silicon) or CUDA available, False otherwise
        """
        return self._select_device_str() != 'cpu'

    def load_model(self) -> None:
        """
        Load Gifu Wildlife ResNet50 model into memory.

        Reads classes.csv, builds a ResNet50 with one output per class, loads the
        trained checkpoint and moves the network to the selected device (CUDA,
        then MPS, then CPU). self.model is assigned only after loading fully
        succeeds, so a failed load never leaves a randomly initialised network
        that get_classification() would silently use.

        Raises:
            FileNotFoundError: If classes.csv or model_path does not exist
            RuntimeError: If classes.csv is unreadable, lacks a 'Code' column or
                has no rows, or if the checkpoint cannot be loaded into the network
        """
        device_str = self._select_device_str()
        self.device = torch.device(device_str)
        print(f"[GifuWildlife] Loading model on device: {self.device}", file=sys.stderr, flush=True)

        self.classes = self._read_classes(self.model_dir / 'classes.csv')

        # Fail fast, before spending time constructing the network.
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")

        try:
            # The optional ImageNet weights (resnet50-11ad3fa6.pth) are not
            # loaded: load_state_dict() below is strict, so the trained
            # checkpoint overwrites every parameter and buffer anyway and
            # reading the ImageNet file would be wasted I/O.
            model = CustomResNet50(
                num_classes=len(self.classes),
                pretrained_path=None,
                device_str=device_str,
            )
            # NOTE(review): torch.load unpickles the file - only load trusted
            # checkpoints. PyTorch >= 2.6 defaults to weights_only=True, which
            # rejects checkpoints embedding arbitrary Python objects; confirm this
            # checkpoint loads under the torch version in use.
            checkpoint = torch.load(self.model_path, map_location=self.device)
            model.load_state_dict(self._extract_state_dict(checkpoint))
            model.to(self.device)
            model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load Gifu Wildlife model: {e}") from e

        self.model = model
        print(
            f"[GifuWildlife] Model loaded: ResNet50 with {len(self.classes)} classes, "
            f"resolution 224x224",
            file=sys.stderr, flush=True
        )

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using Gifu Wildlife preprocessing.

        Simple direct crop with no padding or squaring:
        1. Denormalize bbox coordinates
        2. Clip to image boundaries
        3. Crop directly
        Based on classify_detections.py get_crop function.

        Args:
            image: Full-resolution PIL Image
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image ready for classification

        Raises:
            ValueError: If bbox is invalid (empty after clipping, or non-finite)
        """
        buffer = 0  # No buffer/padding
        width, height = image.size
        x, y, box_w, box_h = bbox
        # Denormalize; int() truncates toward zero, then clip to the image.
        left = max(0, int(width * x) - buffer)
        top = max(0, int(height * y) - buffer)
        right = min(width, int(width * (x + box_w)) + buffer)
        bottom = min(height, int(height * (y + box_h)) + buffer)
        if right <= left or bottom <= top:
            raise ValueError(f"Invalid crop dimensions: ({left},{top}) to ({right},{bottom})")
        return image.crop((left, top, right, bottom))

    def get_classification(self, crop: Image.Image) -> list[list[str, float]]:
        """
        Run Gifu Wildlife classification on cropped image.

        Workflow:
        1. Convert crop to RGB if needed, then preprocess (resize + to tensor)
        2. Run ResNet50 forward pass
        3. Apply softmax to get probabilities
        4. Return all class probabilities (unsorted, in classes.csv order)

        Args:
            crop: Cropped PIL Image (any mode; non-RGB modes are converted)

        Returns:
            List of [class_name, confidence] lists for ALL classes.
            Example: [["bear", 0.01], ["bird", 0.02], ["deer", 0.89], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None or self.device is None or self.classes is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        try:
            # ResNet50 expects 3-channel input; ToTensor on grayscale ('L'),
            # palette, RGBA or CMYK crops would yield 1 or 4 channels instead.
            if crop.mode != 'RGB':
                crop = crop.convert('RGB')
            input_batch = self.preprocess(crop).unsqueeze(0).to(self.device)
            with torch.no_grad():
                probabilities = F.softmax(self.model(input_batch), dim=1)
            confidence_scores = probabilities[0].cpu().tolist()

            class_codes = self._class_codes()
            if len(class_codes) != len(confidence_scores):
                raise RuntimeError(
                    f"model returned {len(confidence_scores)} scores but classes.csv "
                    f"lists {len(class_codes)} classes"
                )
            # NOTE: Sorting by confidence is handled by classification_worker.py
            return [[code, float(score)] for code, score in zip(class_codes, confidence_scores)]
        except Exception as e:
            raise RuntimeError(f"Gifu Wildlife classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to class names.

        Classes are taken in row order from the 'Code' column of classes.csv and
        numbered from 1 for JSON compatibility.

        Returns:
            Dict mapping class ID (1-indexed string) to class name
            Example: {"1": "bear", "2": "bird", ..., "13": "squirrel"}

        Raises:
            RuntimeError: If classes not loaded
        """
        if self.classes is None:
            raise RuntimeError("Classes not loaded - call load_model() first")
        return {str(idx): code for idx, code in enumerate(self._class_codes(), start=1)}