# EUR-DF-v1-3 / inference.py
# Uploaded by Addax-Data-Science (commit 55e53df, verified)
"""
Inference script for EUR-DF-v1-3 (DeepFaune v1.3 European Wildlife Classifier)
The DeepFaune initiative develops AI models to automatically classify species in camera-trap
images and videos. Led by CNRS (France) in collaboration with 50+ European partners.
Model: DeepFaune v1.3
Input: 182x182 RGB images
Framework: PyTorch (Vision Transformer - DINOv2)
Classes: 34 European species and taxonomic groups
Developer: The DeepFaune initiative (CNRS)
Citation: https://doi.org/10.1007/s10344-023-01742-7
License: CC BY-SA 4.0
Info: https://www.deepfaune.cnrs.fr/en/
Author: Peter van Lunteren
Created: 2026-01-14
"""
from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
import timm
import torch
import torch.nn as nn
from PIL import Image, ImageFile
from torch import tensor
from torchvision.transforms import InterpolationMode, transforms
# Don't freak out over truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True
# DeepFaune model constants
CROP_SIZE: int = 182  # input side length in pixels; crops are resized to CROP_SIZE x CROP_SIZE
BACKBONE: str = "vit_large_patch14_dinov2.lvd142m"  # timm identifier of the ViT-Large DINOv2 backbone
# DeepFaune class names (English)
# Source: https://plmlab.math.cnrs.fr/deepfaune/software/-/blob/master/classifTools.py
# NOTE: order is significant -- index i corresponds to model output logit i,
# and the 1-indexed class IDs exposed elsewhere are derived from this order.
CLASS_NAMES_EN: list[str] = [
    'bison', 'badger', 'ibex', 'beaver', 'red deer', 'chamois', 'cat', 'goat',
    'roe deer', 'dog', 'fallow deer', 'squirrel', 'moose', 'equid', 'genet',
    'wolverine', 'hedgehog', 'lagomorph', 'wolf', 'otter', 'lynx', 'marmot',
    'micromammal', 'mouflon', 'sheep', 'mustelid', 'bird', 'bear', 'nutria',
    'raccoon', 'fox', 'reindeer', 'wild boar', 'cow'
]
class DeepFauneModel(nn.Module):
    """
    DeepFaune model wrapper around a timm ViT-Large DINOv2 backbone.

    Based on the original DeepFaune classifTools.py Model class.
    License: see file header.
    """

    def __init__(self, model_path: Path):
        """
        Build the (untrained) DeepFaune ViT architecture.

        Args:
            model_path: Path to the DeepFaune .pt checkpoint; read later by
                load_weights(), not here.
        """
        super().__init__()
        self.model_path = model_path
        self.backbone = BACKBONE
        self.nbclasses = len(CLASS_NAMES_EN)
        # pretrained=False: the trained weights come from the DeepFaune
        # checkpoint (see load_weights), not from timm's pretrained hub.
        self.base_model = timm.create_model(
            BACKBONE,
            pretrained=False,
            num_classes=self.nbclasses,
            dynamic_img_size=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the backbone; returns raw (pre-softmax) logits."""
        return self.base_model(x)

    def predict(self, data: torch.Tensor, device: torch.device) -> np.ndarray:
        """
        Run prediction with softmax.

        Args:
            data: Preprocessed image tensor (batched; first item is returned)
            device: torch.device (cpu, cuda, or mps)

        Returns:
            Numpy array of softmax probabilities [num_classes]
        """
        self.eval()
        # Idempotent after the first call; kept here so predict() works
        # without a separate device-placement step.
        self.to(device)
        with torch.no_grad():
            probs = self.forward(data.to(device)).softmax(dim=1)
        return probs.cpu().numpy()[0]  # first (and only) batch item

    def load_weights(self, device: torch.device) -> None:
        """
        Load model weights from the .pt checkpoint.

        Based on original DeepFaune classifTools.py loadWeights method.

        Args:
            device: torch.device to load weights onto

        Raises:
            FileNotFoundError: If the checkpoint file does not exist
            RuntimeError: If the checkpoint is unreadable, declares an
                unexpected number of classes, or its state_dict does not
                match the model architecture
        """
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        try:
            # NOTE(review): torch.load unpickles arbitrary objects -- only
            # load checkpoints obtained from a trusted source.
            params = torch.load(self.model_path, map_location=device)
            args = params['args']
            state_dict = params['state_dict']
        except Exception as e:
            raise RuntimeError(f"Failed to load DeepFaune model weights: {e}") from e
        # Validate outside the try above so this deliberate error is not
        # re-wrapped with a misleading "Failed to load ..." prefix.
        if self.nbclasses != args['num_classes']:
            raise RuntimeError(
                f"Model has {args['num_classes']} classes but expected {self.nbclasses}"
            )
        self.backbone = args['backbone']
        self.nbclasses = args['num_classes']
        try:
            self.load_state_dict(state_dict)
        except Exception as e:
            raise RuntimeError(f"Failed to load DeepFaune model weights: {e}") from e
class ModelInference:
    """DeepFaune v1.3 inference implementation for AddaxAI-WebUI."""

    def __init__(self, model_dir: Path, model_path: Path):
        """
        Initialize with model paths.

        Args:
            model_dir: Directory containing model files
            model_path: Path to deepfaune-vit_large_patch14_dinov2.lvd142m.v3.pt file
        """
        self.model_dir = model_dir
        self.model_path = model_path
        self.model: DeepFauneModel | None = None  # populated by load_model()
        self.device: torch.device | None = None  # populated by load_model()
        # DeepFaune preprocessing transforms
        # Based on classifTools.py Classifier.__init__
        self.transforms = transforms.Compose([
            transforms.Resize(
                size=(CROP_SIZE, CROP_SIZE),
                interpolation=InterpolationMode.BICUBIC,
                max_size=None,
                antialias=None
            ),
            transforms.ToTensor(),
            # ImageNet mean/std, matching the original DeepFaune pipeline
            transforms.Normalize(
                mean=tensor([0.4850, 0.4560, 0.4060]),
                std=tensor([0.2290, 0.2240, 0.2250])
            )
        ])

    @staticmethod
    def _select_device() -> torch.device:
        """Return the best available torch device: CUDA > MPS > CPU."""
        if torch.cuda.is_available():
            return torch.device('cuda')
        # hasattr guard: older torch builds may lack the mps backend entirely
        if (hasattr(torch.backends, 'mps')
                and torch.backends.mps.is_built()
                and torch.backends.mps.is_available()):
            return torch.device('mps')
        return torch.device('cpu')

    def check_gpu(self) -> bool:
        """
        Check GPU availability for DeepFaune (PyTorch).

        Returns:
            True if MPS (Apple Silicon) or CUDA available, False otherwise
        """
        # Check Apple MPS (Apple Silicon); try/except because the mps
        # backend may be missing or broken on some installs.
        try:
            if torch.backends.mps.is_built() and torch.backends.mps.is_available():
                return True
        except Exception:
            pass
        # Check CUDA (NVIDIA)
        return torch.cuda.is_available()

    def load_model(self) -> None:
        """
        Load DeepFaune model into memory.

        Creates the ViT-Large DINOv2 model and loads the trained weights.
        The model is stored in self.model and reused for all subsequent
        classifications.

        Raises:
            RuntimeError: If model loading fails
            FileNotFoundError: If model_path is invalid
        """
        self.device = self._select_device()
        print(f"[DeepFaune] Loading model on device: {self.device}", file=sys.stderr, flush=True)
        # Create and load model
        self.model = DeepFauneModel(self.model_path)
        self.model.load_weights(self.device)
        print(
            f"[DeepFaune] Model loaded: {BACKBONE} with {len(CLASS_NAMES_EN)} classes, "
            f"resolution {CROP_SIZE}x{CROP_SIZE}",
            file=sys.stderr, flush=True
        )

    def get_crop(
        self, image: Image.Image, bbox: tuple[float, float, float, float]
    ) -> Image.Image:
        """
        Crop image using DeepFaune preprocessing.

        DeepFaune uses a squared crop approach:
        1. Denormalize bbox coordinates
        2. Square the crop (max of width/height)
        3. Center the detection within the square
        4. Clip to image boundaries

        Based on classify_detections.py get_crop function.

        Args:
            image: Full-resolution PIL Image
            bbox: Normalized bounding box (x, y, width, height) in range [0.0, 1.0]

        Returns:
            Cropped PIL Image (RGB) ready for classification

        Raises:
            ValueError: If bbox has non-positive width or height
        """
        width, height = image.size
        # Denormalize bbox (x, y, w, h) to pixel corner coordinates
        xmin = int(round(bbox[0] * width))
        ymin = int(round(bbox[1] * height))
        xmax = int(round(bbox[2] * width)) + xmin
        ymax = int(round(bbox[3] * height)) + ymin
        xsize = xmax - xmin
        ysize = ymax - ymin
        if xsize <= 0 or ysize <= 0:
            raise ValueError(f"Invalid bbox size: {xsize}x{ysize}")
        # Square the crop by expanding the smaller dimension symmetrically.
        # NOTE: an odd size difference loses one pixel to int() truncation,
        # matching the original DeepFaune implementation.
        if xsize > ysize:
            expand = int((xsize - ysize) / 2)
            ymin -= expand
            ymax += expand
        elif ysize > xsize:
            expand = int((ysize - xsize) / 2)
            xmin -= expand
            xmax += expand
        # Clip to image boundaries and crop
        box = (
            max(0, xmin),
            max(0, ymin),
            min(xmax, width),
            min(ymax, height),
        )
        image_cropped = image.crop(box)
        # DeepFaune requires RGB input
        if image_cropped.mode != 'RGB':
            image_cropped = image_cropped.convert('RGB')
        return image_cropped

    def get_classification(self, crop: Image.Image) -> list[list[str | float]]:
        """
        Run DeepFaune classification on cropped image.

        Workflow:
        1. Preprocess crop with transforms (resize, normalize)
        2. Run model prediction with softmax
        3. Return all class probabilities (unsorted)

        Args:
            crop: Cropped PIL Image

        Returns:
            List of [class_name, confidence] pairs for ALL classes.
            Example: [["bison", 0.00001], ["badger", 0.00002], ["red deer", 0.99985], ...]
            NOTE: Sorting by confidence is handled by classification_worker.py

        Raises:
            RuntimeError: If model not loaded or inference fails
        """
        if self.model is None or self.device is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        try:
            # Resize + normalize, then add batch dimension
            batch = self.transforms(crop).unsqueeze(dim=0)
            confs = self.model.predict(batch, self.device)
            # Pair each class name with its softmax probability (unsorted)
            return [[name, float(conf)] for name, conf in zip(CLASS_NAMES_EN, confs)]
        except Exception as e:
            raise RuntimeError(f"DeepFaune classification failed: {e}") from e

    def get_class_names(self) -> dict[str, str]:
        """
        Get mapping of class IDs to species names.

        DeepFaune has 34 classes in a fixed order; a 1-indexed mapping is
        built for JSON compatibility.

        Returns:
            Dict mapping class ID (1-indexed string) to species name
            Example: {"1": "bison", "2": "badger", ..., "34": "cow"}

        Raises:
            RuntimeError: If model not loaded
        """
        if self.model is None:
            raise RuntimeError("Model not loaded - call load_model() first")
        # 1-indexed string keys for JSON compatibility
        return {str(i + 1): name for i, name in enumerate(CLASS_NAMES_EN)}