csiro-image2biomass / inference.py
notRaphael's picture
Upload inference.py with huggingface_hub
a1fccff verified
#!/usr/bin/env python3
"""
CSIRO Image2Biomass Prediction - Inference Pipeline
=====================================================
Generates submission.csv from trained model ensemble.
Supports:
- Single model or ensemble of fold models
- Test-Time Augmentation (TTA)
- Multi-backbone ensembling
- Post-processing (non-negative, consistency check)
For Kaggle submission notebook:
- All models must be saved in a Kaggle dataset
- No internet access during inference
"""
import os
import sys
import json
import glob
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
import timm
try:
import albumentations as A
from albumentations.pytorch import ToTensorV2
HAS_ALBUMENTATIONS = True
except ImportError:
HAS_ALBUMENTATIONS = False
from torchvision import transforms
from PIL import Image
# ============================================================
# Constants (must match training)
# ============================================================
# Regression targets. The order fixes the column layout of every prediction
# array handled in this file.
TARGET_COLS = [
    'Dry_Green_g',
    'Dry_Dead_g',
    'Dry_Clover_g',
    'GDM_g',
    'Dry_Total_g',
]

# One weight per entry of TARGET_COLS (same order).
# NOTE(review): not referenced by the inference code visible here; presumably
# the per-target weights of the competition metric — confirm against training.
TARGET_WEIGHTS = [0.1, 0.1, 0.1, 0.2, 0.5]

# Channel-wise normalisation statistics passed to the Normalize transforms.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# (short key, timm model name, feature dimension) for every supported backbone.
_BACKBONE_TABLE = (
    ('dinov2_small', 'vit_small_patch14_dinov2.lvd142m', 384),
    ('dinov2_base', 'vit_base_patch14_dinov2.lvd142m', 768),
    ('dinov2_large', 'vit_large_patch14_dinov2.lvd142m', 1024),
    ('dinov2_base_reg', 'vit_base_patch14_reg4_dinov2.lvd142m', 768),
    ('convnext_large', 'convnext_large.fb_in22k_ft_in1k', 1536),
    ('convnextv2_large', 'convnextv2_large.fcmae_ft_in22k_in1k', 1536),
    ('efficientnet_b4', 'efficientnet_b4.ra2_in1k', 1792),
    ('swin_large', 'swin_large_patch4_window7_224.ms_in22k_ft_in1k', 1536),
    ('eva02_large', 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', 1024),
)

# Maps the short backbone key stored in a checkpoint's args to its timm name.
# NOTE(review): no entry defines a 'default_size' key, so any
# `.get('default_size', 224)` lookup on these configs always yields 224.
BACKBONE_CONFIGS = {
    key: {'name': timm_name, 'feat_dim': feat_dim}
    for key, timm_name, feat_dim in _BACKBONE_TABLE
}
# ============================================================
# Model (same as training)
# ============================================================
class BiomassModel(nn.Module):
    """timm backbone followed by an MLP regression head.

    The backbone is created with ``num_classes=0`` so it returns pooled
    features. Those features (optionally concatenated with a 64-d NDVI
    embedding) feed either one shared head emitting ``num_targets`` values or
    one small head per target.

    Submodule names (``backbone``, ``ndvi_embed``, ``head``, ``heads``) and the
    layer order inside each ``nn.Sequential`` define the state-dict keys, so
    they must stay in sync with the checkpoints being loaded.

    Args:
        backbone_name: timm model name.
        num_targets: Number of regression outputs.
        hidden_dim: Width of the first hidden layer of each head.
        dropout: Base dropout rate; deeper layers use scaled-down rates.
        pretrained: Forwarded to ``timm.create_model``; defaults to False
            because weights are expected to come from a checkpoint.
        img_size: Forwarded to timm only when the model name contains
            ``'vit'`` or ``'dinov2'``.
        use_ndvi: Add the NDVI embedding branch and widen the head input by 64.
        separate_heads: Use one head per target instead of a shared head.
        grad_checkpointing: Accepted but not used by this class.
    """

    def __init__(
        self,
        backbone_name: str = 'vit_base_patch14_dinov2.lvd142m',
        num_targets: int = 5,
        hidden_dim: int = 512,
        dropout: float = 0.3,
        pretrained: bool = False,  # No pretrained for inference
        img_size: int = 224,
        use_ndvi: bool = False,
        separate_heads: bool = False,
        grad_checkpointing: bool = False,
    ):
        super().__init__()
        self.use_ndvi = use_ndvi
        self.separate_heads = separate_heads
        self.num_targets = num_targets

        kwargs = {'pretrained': pretrained, 'num_classes': 0}
        if 'vit' in backbone_name or 'dinov2' in backbone_name:
            kwargs['img_size'] = img_size
        self.backbone = timm.create_model(backbone_name, **kwargs)
        feat_dim = self.backbone.num_features

        if use_ndvi:
            self.ndvi_embed = nn.Sequential(
                nn.Linear(1, 32),
                nn.GELU(),
                nn.Linear(32, 64),
            )
            # The head consumes backbone features plus the 64-d NDVI embedding.
            feat_dim += 64

        if separate_heads:
            self.heads = nn.ModuleList([
                nn.Sequential(
                    nn.LayerNorm(feat_dim),
                    nn.Dropout(dropout),
                    nn.Linear(feat_dim, hidden_dim),
                    nn.GELU(),
                    nn.Dropout(dropout * 0.5),
                    nn.Linear(hidden_dim, 1),
                )
                for _ in range(num_targets)
            ])
        else:
            self.head = nn.Sequential(
                nn.LayerNorm(feat_dim),
                nn.Dropout(dropout),
                nn.Linear(feat_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout * 0.3),
                nn.Linear(hidden_dim // 2, num_targets),
            )

    def forward(self, x, ndvi=None):
        """Return predictions with one column per target.

        Args:
            x: Image batch passed straight to the backbone.
            ndvi: NDVI values, one per sample; a trailing dimension is added
                before embedding. Required when the model was built with
                ``use_ndvi=True``, ignored otherwise.

        Raises:
            ValueError: If the model uses NDVI but ``ndvi`` is None. The head
                input is sized for the concatenated features, so continuing
                without NDVI could only fail later with an opaque shape error.
        """
        features = self.backbone(x)

        if self.use_ndvi:
            if ndvi is None:
                raise ValueError(
                    "BiomassModel was built with use_ndvi=True but forward() "
                    "received ndvi=None; provide one NDVI value per sample."
                )
            ndvi_feats = self.ndvi_embed(ndvi.unsqueeze(-1))
            features = torch.cat([features, ndvi_feats], dim=-1)

        if self.separate_heads:
            outputs = [head(features) for head in self.heads]
            return torch.cat(outputs, dim=-1)
        return self.head(features)
# ============================================================
# Dataset
# ============================================================
class TestDataset(Dataset):
    """Dataset that loads test images by id and applies the given transform.

    Each item is a dict with ``'image'`` (tensor), ``'image_id'`` (str) and,
    when ``use_ndvi`` is set and the dataframe has an ``'NDVI'`` column,
    ``'ndvi'`` (float32 scalar tensor).

    Args:
        image_dir: Directory containing the images.
        df: Dataframe with one row per sample. The image id is read from the
            ``'image_id'`` column when present, otherwise from the row index.
        transform: Albumentations pipeline (when albumentations is available)
            or torchvision-style callable taking a PIL image. ``None`` means
            "convert to a [0, 1] float tensor only".
        img_size: Stored on the instance; not used by the dataset itself.
        use_ndvi: Whether to emit the ``'ndvi'`` entry.
    """

    def __init__(self, image_dir, df, transform=None, img_size=224, use_ndvi=False):
        self.image_dir = Path(image_dir)
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.img_size = img_size
        self.use_ndvi = use_ndvi

    def __len__(self):
        return len(self.df)

    def _resolve_image_path(self, img_id):
        """Return the image file for ``img_id``: .jpg, then .png, then any extension."""
        for ext in ('.jpg', '.png'):
            candidate = self.image_dir / f"{img_id}{ext}"
            if candidate.exists():
                return candidate
        # Sort so the chosen file does not depend on directory listing order.
        matches = sorted(self.image_dir.glob(f"{img_id}.*"))
        if matches:
            return matches[0]
        raise FileNotFoundError(
            f"No image file found for id '{img_id}' in {self.image_dir}"
        )

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_id = row['image_id'] if 'image_id' in row.index else row.name
        img_path = self._resolve_image_path(img_id)

        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded image.
        with Image.open(img_path) as handle:
            pil_img = handle.convert('RGB')

        if self.transform is None:
            # Same result as torchvision's ToTensor for 8-bit RGB input
            # (CHW, float32, scaled to [0, 1]) without requiring torchvision.
            img_tensor = (
                torch.from_numpy(np.array(pil_img))
                .permute(2, 0, 1)
                .contiguous()
                .float()
                .div(255)
            )
        elif HAS_ALBUMENTATIONS:
            # Albumentations pipelines take a numpy array via keyword argument.
            img_tensor = self.transform(image=np.array(pil_img))['image']
        else:
            img_tensor = self.transform(pil_img)

        result = {'image': img_tensor, 'image_id': str(img_id)}
        if self.use_ndvi and 'NDVI' in self.df.columns:
            result['ndvi'] = torch.tensor(row['NDVI'], dtype=torch.float32)
        return result
# ============================================================
# Transforms
# ============================================================
def get_val_transforms(img_size=224):
    """Build the deterministic evaluation pipeline.

    Resizes to ``int(img_size * 1.14)``, centre-crops to ``img_size``, then
    normalises with the ImageNet statistics and converts to a tensor.
    Returns an albumentations ``Compose`` when albumentations is available,
    otherwise a torchvision ``Compose``.
    """
    resize_to = int(img_size * 1.14)

    if not HAS_ALBUMENTATIONS:
        torchvision_steps = [
            transforms.Resize(resize_to),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        return transforms.Compose(torchvision_steps)

    albumentations_steps = [
        A.Resize(height=resize_to, width=resize_to),
        A.CenterCrop(height=img_size, width=img_size),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
    return A.Compose(albumentations_steps)
def get_tta_transforms(img_size=224):
    """Return the list of test-time-augmentation pipelines.

    The list always has four entries, in this order: the plain validation
    transform, horizontal flip, vertical flip, and both flips. Each flipped
    view uses the same resize / centre-crop / normalise steps as the
    validation transform, with the flip(s) applied after the crop.

    When albumentations is not installed, the flipped views are built from
    torchvision transforms (flips with ``p=1.0``) so that TTA does not
    silently collapse to a single view.
    """
    ttas = [get_val_transforms(img_size)]
    resize_to = int(img_size * 1.14)

    # (horizontal, vertical) flags for the three augmented views.
    flip_combos = ((True, False), (False, True), (True, True))

    for hflip, vflip in flip_combos:
        if HAS_ALBUMENTATIONS:
            steps = [
                A.Resize(height=resize_to, width=resize_to),
                A.CenterCrop(height=img_size, width=img_size),
            ]
            if hflip:
                steps.append(A.HorizontalFlip(p=1.0))
            if vflip:
                steps.append(A.VerticalFlip(p=1.0))
            steps.append(A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
            steps.append(ToTensorV2())
            ttas.append(A.Compose(steps))
        else:
            steps = [
                transforms.Resize(resize_to),
                transforms.CenterCrop(img_size),
            ]
            if hflip:
                steps.append(transforms.RandomHorizontalFlip(p=1.0))
            if vflip:
                steps.append(transforms.RandomVerticalFlip(p=1.0))
            steps.append(transforms.ToTensor())
            steps.append(transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
            ttas.append(transforms.Compose(steps))

    return ttas
# ============================================================
# Inference
# ============================================================
def load_model(checkpoint_path: str, device: torch.device) -> Tuple[BiomassModel, Dict]:
    """Load a trained model from a checkpoint and put it in eval mode.

    The checkpoint is expected to be a dict holding ``'model_state_dict'`` and,
    optionally, ``'args'`` (training hyper-parameters) and ``'weighted_r2'``.

    Args:
        checkpoint_path: Path to the ``.pth`` checkpoint file.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        ``(model, args)`` where ``args`` is the training-argument dict stored
        in the checkpoint (empty if absent).

    Raises:
        KeyError: If the checkpoint names a backbone that is not in
            ``BACKBONE_CONFIGS``, or has no ``'model_state_dict'`` entry.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    args = checkpoint.get('args') or {}
    if not isinstance(args, dict):
        # Training scripts often store an argparse.Namespace; expose it as a
        # dict so the .get() lookups here and in callers work.
        args = vars(args)

    backbone_key = args.get('backbone', 'dinov2_base')
    if backbone_key not in BACKBONE_CONFIGS:
        raise KeyError(
            f"Unknown backbone '{backbone_key}' in {checkpoint_path}; "
            f"known backbones: {sorted(BACKBONE_CONFIGS)}"
        )
    backbone_cfg = BACKBONE_CONFIGS[backbone_key]
    img_size = args.get('img_size', None) or backbone_cfg.get('default_size', 224)

    model = BiomassModel(
        backbone_name=backbone_cfg['name'],
        num_targets=len(TARGET_COLS),
        hidden_dim=args.get('hidden_dim', 512),
        dropout=args.get('dropout', 0.3),
        pretrained=False,  # weights come from the checkpoint below
        img_size=img_size,
        use_ndvi=args.get('use_ndvi', False),
        separate_heads=args.get('separate_heads', False),
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    print(f"Loaded model from {checkpoint_path}")
    print(f" Backbone: {backbone_key}, img_size: {img_size}")
    print(f" Best R²: {checkpoint.get('weighted_r2', 'N/A')}")
    return model, args
@torch.no_grad()
def predict_single(model, loader, device, log_transform=True):
    """Run one model over a loader and collect its predictions.

    Args:
        model: Callable as ``model(images, ndvi)``; switched to eval mode here.
        loader: Iterable of batches, each a dict with ``'image'``,
            ``'image_id'`` and optionally ``'ndvi'``.
        device: Device the batch tensors are moved to.
        log_transform: If True, outputs are treated as ``log1p`` values and
            mapped back with ``expm1``.

    Returns:
        ``(preds, ids)``: a float32 array with one row per sample, and the
        image ids in the same order. An empty loader yields a zero-row array.
    """
    model.eval()
    # fp16 autocast only makes sense on CUDA; elsewhere it is disabled.
    use_amp = torch.device(device).type == 'cuda'

    batch_preds = []
    all_ids = []
    for batch in loader:
        images = batch['image'].to(device)
        ndvi = batch.get('ndvi', None)
        if ndvi is not None:
            ndvi = ndvi.to(device)
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            preds = model(images, ndvi)
        # Under autocast the outputs are float16. Upcast before leaving torch
        # so that expm1 and later averaging do not run in half precision
        # (float16 expm1 overflows just above 11 and keeps ~3 digits).
        batch_preds.append(preds.float().cpu().numpy())
        all_ids.extend(batch['image_id'])

    if not batch_preds:
        n_cols = getattr(model, 'num_targets', 0)
        return np.zeros((0, n_cols), dtype=np.float32), all_ids

    all_preds = np.concatenate(batch_preds, axis=0)
    if log_transform:
        all_preds = np.expm1(all_preds)
    return all_preds, all_ids
def predict_with_tta(model, test_df, image_dir, device, img_size=224,
                     log_transform=True, batch_size=32, num_workers=4,
                     use_ndvi=False, n_tta=4):
    """Predict with Test-Time Augmentation and average the views.

    The first ``n_tta`` pipelines from ``get_tta_transforms`` are each run
    over the full test set and their predictions are averaged element-wise.

    Args:
        model: Model passed to ``predict_single``.
        test_df: Test dataframe handed to ``TestDataset``.
        image_dir: Directory with the test images.
        device: Device used for inference.
        img_size: Image size used to build the transforms.
        log_transform: Forwarded to ``predict_single``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        use_ndvi: Whether the dataset should emit NDVI values.
        n_tta: Number of TTA views to use; values below 1 are treated as 1 so
            there is always at least the un-augmented view.

    Returns:
        ``(avg_preds, image_ids)``.
    """
    n_views = max(1, int(n_tta))
    tta_transforms = get_tta_transforms(img_size)[:n_views]

    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = torch.device(device).type == 'cuda'

    all_tta_preds = []
    image_ids = None
    for tta_idx, tfm in enumerate(tta_transforms):
        dataset = TestDataset(image_dir, test_df, transform=tfm,
                              img_size=img_size, use_ndvi=use_ndvi)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
        preds, ids = predict_single(model, loader, device, log_transform)
        all_tta_preds.append(preds)
        if image_ids is None:
            image_ids = ids
        print(f" TTA {tta_idx}: mean pred = {preds.mean():.2f}")

    # Average TTA predictions in at least float32, even if a view came back
    # in half precision.
    stacked = np.stack(all_tta_preds)
    avg_preds = np.mean(stacked, axis=0, dtype=np.result_type(stacked.dtype, np.float32))
    return avg_preds, image_ids
def ensemble_predict(
    model_dirs: List[str],
    test_df: pd.DataFrame,
    image_dir: str,
    device: torch.device,
    batch_size: int = 32,
    num_workers: int = 4,
    n_tta: int = 4,
    model_weights: Optional[List[float]] = None,
):
    """
    Ensemble predictions from multiple model checkpoints.

    Each directory is searched for ``best_model.pth`` and, failing that, for
    the alphabetically first ``*.pth`` file. Directories without a checkpoint
    are skipped with a message. The per-model TTA predictions are combined as
    a weighted average.

    Args:
        model_dirs: List of model checkpoint directories (each containing best_model.pth)
        test_df: Test dataframe
        image_dir: Test image directory
        device: torch device
        batch_size: Batch size for inference
        num_workers: DataLoader workers
        n_tta: Number of TTA augmentations
        model_weights: Optional weights, one per entry of ``model_dirs``
            (default: equal weights). Weights of skipped directories are
            dropped and the rest are re-normalised to sum to 1.

    Returns:
        ``(ensemble_preds, image_ids)``; ``ensemble_preds`` is a float64 array.

    Raises:
        ValueError: If ``model_weights`` does not match ``model_dirs`` in
            length, or the weights of the loaded models do not sum to a
            positive number.
        FileNotFoundError: If no checkpoint could be found in any directory.
    """
    if model_weights is not None and len(model_weights) != len(model_dirs):
        raise ValueError(
            f"model_weights has {len(model_weights)} entries but "
            f"{len(model_dirs)} model directories were given"
        )

    all_preds = []
    used_weights = []  # weights of the models that were actually loaded
    image_ids = None

    for i, model_dir in enumerate(model_dirs):
        model_dir = Path(model_dir)

        # Find checkpoint
        ckpt_path = model_dir / 'best_model.pth'
        if not ckpt_path.exists():
            # Fall back to any .pth file; sorted so the choice is reproducible.
            ckpt_files = sorted(model_dir.glob('*.pth'))
            if not ckpt_files:
                print(f"No checkpoint found in {model_dir}, skipping")
                continue
            ckpt_path = ckpt_files[0]

        print(f"\nModel {i+1}/{len(model_dirs)}: {ckpt_path}")
        model, args = load_model(str(ckpt_path), device)

        backbone_key = args.get('backbone', 'dinov2_base')
        backbone_cfg = BACKBONE_CONFIGS[backbone_key]
        img_size = args.get('img_size', None) or backbone_cfg.get('default_size', 224)
        log_transform = args.get('log_transform', True)
        use_ndvi = args.get('use_ndvi', False)

        preds, ids = predict_with_tta(
            model, test_df, image_dir, device,
            img_size=img_size,
            log_transform=log_transform,
            batch_size=batch_size,
            num_workers=num_workers,
            use_ndvi=use_ndvi,
            n_tta=n_tta,
        )
        # Accumulate in float64 regardless of the precision the model ran in.
        all_preds.append(np.asarray(preds, dtype=np.float64))
        used_weights.append(1.0 if model_weights is None else float(model_weights[i]))
        if image_ids is None:
            image_ids = ids

        # Free memory before the next model is loaded
        del model
        torch.cuda.empty_cache()

    if not all_preds:
        raise FileNotFoundError(
            f"No model checkpoints could be loaded from: {list(model_dirs)}"
        )

    # Weighted ensemble over the models that were actually loaded.
    weights = np.asarray(used_weights, dtype=np.float64)
    weight_sum = weights.sum()
    if not weight_sum > 0:
        raise ValueError("model_weights of the loaded models must sum to a positive value")
    weights = weights / weight_sum

    ensemble_preds = np.zeros_like(all_preds[0])
    for pred, weight in zip(all_preds, weights):
        ensemble_preds += weight * pred
    return ensemble_preds, image_ids
# ============================================================
# Post-processing
# ============================================================
def postprocess_predictions(preds: np.ndarray) -> np.ndarray:
    """
    Clean up raw predictions; returns a new array, the input is not modified.

    1. Negative values are clipped to zero.
    2. Column 4 (total) is raised to the sum of columns 0-2 (components)
       wherever it is smaller than that sum.
    """
    cleaned = np.clip(preds, 0, None)

    # Columns 0, 1 and 2 are the individual components.
    first, second, third = cleaned[:, 0], cleaned[:, 1], cleaned[:, 2]
    parts = first + second + third

    # Keep the predicted total unless it falls below the component sum.
    current_total = cleaned[:, 4]
    cleaned[:, 4] = np.where(current_total < parts, parts, current_total)
    return cleaned
def create_submission(preds: np.ndarray, image_ids: List[str], output_path: str = 'submission.csv'):
    """Create submission CSV in competition format.

    Writes one row per (image, target) pair with columns ``sample_id``
    (``"<image_id>__<target_name>"``) and ``target`` (the prediction, floored
    at zero), then prints per-target summary statistics.

    Args:
        preds: Array of shape ``(len(image_ids), len(TARGET_COLS))``.
        image_ids: Image ids, in the same order as the rows of ``preds``.
        output_path: Destination CSV path.

    Returns:
        The submission dataframe that was written.

    Raises:
        ValueError: If ``preds`` does not have one row per image id and one
            column per target. Checked before anything is written.
    """
    preds = np.asarray(preds)
    n_images = len(image_ids)
    n_targets = len(TARGET_COLS)
    if preds.shape != (n_images, n_targets):
        raise ValueError(
            f"Expected predictions of shape ({n_images}, {n_targets}), "
            f"got {preds.shape}"
        )

    rows = []
    for i, img_id in enumerate(image_ids):
        for j, target_name in enumerate(TARGET_COLS):
            rows.append({
                'sample_id': f"{img_id}__{target_name}",
                'target': float(max(0, preds[i, j])),
            })

    # Explicit columns keep the CSV header even when there are no rows.
    sub_df = pd.DataFrame(rows, columns=['sample_id', 'target'])
    sub_df.to_csv(output_path, index=False)
    print(f"Submission saved: {output_path} ({len(sub_df)} rows)")

    # Print summary (min/max are undefined for an empty prediction set).
    if n_images:
        for j, name in enumerate(TARGET_COLS):
            col_preds = preds[:, j]
            print(f" {name}: mean={col_preds.mean():.2f}, std={col_preds.std():.2f}, "
                  f"min={col_preds.min():.2f}, max={col_preds.max():.2f}")
    return sub_df
# ============================================================
# Main inference entry point
# ============================================================
def main():
    """Command-line entry point: load test data, run the ensemble, write the CSV.

    Reads ``test.csv`` from ``--data_dir`` and takes images from
    ``test_images/`` (or ``test/`` if that directory is missing). Every
    ``fold_*`` entry under ``--model_dir`` is ensembled, or ``--model_dir``
    itself when there are none. Predictions are post-processed and written
    to ``--output``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--model_dir', type=str, required=True,
                        help='Directory with fold_0/best_model.pth, fold_1/best_model.pth, etc.')
    parser.add_argument('--output', type=str, default='submission.csv')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--n_tta', type=int, default=4)
    parser.add_argument('--no_tta', action='store_true')
    cli = parser.parse_args()

    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)
    print(f"Device: {device}")

    # Test data: the CSV plus the image folder (two accepted folder names).
    data_root = Path(cli.data_dir)
    test_df = pd.read_csv(data_root / 'test.csv')
    preferred_img_dir = data_root / 'test_images'
    test_img_dir = preferred_img_dir if preferred_img_dir.exists() else data_root / 'test'
    print(f"Test samples: {len(test_df)}")

    # One entry per fold; fall back to treating model_dir as a single model.
    model_root = Path(cli.model_dir)
    fold_dirs = sorted(model_root.glob('fold_*')) or [model_root]
    print(f"Found {len(fold_dirs)} fold model(s)")

    # --no_tta overrides --n_tta and keeps only the un-augmented view.
    tta_views = 1 if cli.no_tta else cli.n_tta
    raw_preds, image_ids = ensemble_predict(
        model_dirs=[str(fold_dir) for fold_dir in fold_dirs],
        test_df=test_df,
        image_dir=str(test_img_dir),
        device=device,
        batch_size=cli.batch_size,
        num_workers=cli.num_workers,
        n_tta=tta_views,
    )

    final_preds = postprocess_predictions(raw_preds)
    create_submission(final_preds, image_ids, cli.output)
    print("\nDone!")


if __name__ == '__main__':
    main()