notRaphael's picture
Upload train.py with huggingface_hub
17bc682 verified
#!/usr/bin/env python3
"""
CSIRO Image2Biomass Prediction - Training Pipeline
====================================================
Multi-output regression: predicting 5 biomass targets from pasture images.
Targets: Dry_Green_g, Dry_Dead_g, Dry_Clover_g, GDM_g, Dry_Total_g
Metric: Weighted R² (weights: 0.1, 0.1, 0.1, 0.2, 0.5)
Architecture:
- Backbone: DINOv2 / ConvNeXt / EfficientNet (via timm)
- Head: MLP with LayerNorm, GELU, Dropout
- Loss: SmoothL1 + optional weighted R² + consistency regularizer
- Training: Mixed precision, gradient checkpointing, differential LR
Usage:
python train.py --data_dir /path/to/competition/data --backbone dinov2_base --epochs 50
"""
import os
import sys
import json
import time
import random
import argparse
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.cuda.amp import GradScaler, autocast
import timm
try:
import albumentations as A
from albumentations.pytorch import ToTensorV2
HAS_ALBUMENTATIONS = True
except ImportError:
HAS_ALBUMENTATIONS = False
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import KFold, StratifiedKFold
# Configure logging once at import time: INFO level, each record prefixed with
# a timestamp and the level name. All functions below log through `logger`.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ============================================================
# Constants
# ============================================================
# Regression targets. This order is the column order of model outputs,
# metric inputs and submission rows throughout the file.
TARGET_COLS = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']
# Per-target weights, aligned index-for-index with TARGET_COLS. Used both by
# the weighted losses and by the weighted R² metric.
TARGET_WEIGHTS = [0.1, 0.1, 0.1, 0.2, 0.5]
# Per-channel mean/std passed to the Normalize step of every transform pipeline.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
# Backbone configurations
# Keys are the values accepted by the --backbone command-line option.
# Per-entry fields:
#   name         - model identifier handed to BiomassModel (and from there to
#                  timm.create_model).
#   default_size - input resolution used when --img_size is not given.
#   feat_dim     - informational in the visible code: BiomassModel reads the
#                  feature width from the created backbone instead.
#   native_size  - not read anywhere in the visible part of this file;
#                  presumably the backbone's pre-training resolution — confirm.
BACKBONE_CONFIGS = {
    'dinov2_small': {
        'name': 'vit_small_patch14_dinov2.lvd142m',
        'feat_dim': 384, 'native_size': 518, 'default_size': 224,
    },
    'dinov2_base': {
        'name': 'vit_base_patch14_dinov2.lvd142m',
        'feat_dim': 768, 'native_size': 518, 'default_size': 224,
    },
    'dinov2_large': {
        'name': 'vit_large_patch14_dinov2.lvd142m',
        'feat_dim': 1024, 'native_size': 518, 'default_size': 224,
    },
    'dinov2_base_reg': {
        'name': 'vit_base_patch14_reg4_dinov2.lvd142m',
        'feat_dim': 768, 'native_size': 518, 'default_size': 224,
    },
    'convnext_large': {
        'name': 'convnext_large.fb_in22k_ft_in1k',
        'feat_dim': 1536, 'native_size': 224, 'default_size': 224,
    },
    'convnextv2_large': {
        'name': 'convnextv2_large.fcmae_ft_in22k_in1k',
        'feat_dim': 1536, 'native_size': 224, 'default_size': 224,
    },
    'efficientnet_b4': {
        'name': 'efficientnet_b4.ra2_in1k',
        'feat_dim': 1792, 'native_size': 380, 'default_size': 320,
    },
    'swin_large': {
        'name': 'swin_large_patch4_window7_224.ms_in22k_ft_in1k',
        'feat_dim': 1536, 'native_size': 224, 'default_size': 224,
    },
    'eva02_large': {
        'name': 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k',
        'feat_dim': 1024, 'native_size': 448, 'default_size': 448,
    },
}
# ============================================================
# Dataset
# ============================================================
class BiomassDataset(Dataset):
    """Dataset for pasture biomass regression from images.

    Each item is a dict with:
        'image'    - transformed image tensor
        'image_id' - the image ID as a string
        'ndvi'     - scalar float tensor (only when ``use_ndvi`` is set and the
                     frame has an 'NDVI' column)
        'targets'  - float32 tensor of the TARGET_COLS values, log1p-transformed
                     when ``log_transform`` is set (only when ``targets`` is given)

    Args:
        image_dir: Directory containing ``<image_id>.jpg`` / ``.png`` files.
        df: Metadata frame. IDs come from its 'image_id' column, or from its
            index when that column is absent.
        targets: Optional frame holding TARGET_COLS, row-aligned (by position)
            with ``df``.
        transform: Albumentations pipeline when albumentations is installed,
            otherwise a torchvision transform. ``None`` means plain ToTensor.
        img_size: Stored for reference; not used by this class.
        use_ndvi: Emit the 'ndvi' entry when the column is available.
        log_transform: Apply ``log1p`` to the targets.
        is_test: Stored for reference; not used by this class.
    """

    def __init__(
        self,
        image_dir: str,
        df: pd.DataFrame,
        targets: Optional[pd.DataFrame] = None,
        transform=None,
        img_size: int = 224,
        use_ndvi: bool = False,
        log_transform: bool = True,
        is_test: bool = False,
    ):
        self.image_dir = Path(image_dir)
        # Capture the IDs BEFORE the index is reset: once reset_index(drop=True)
        # has run, the index is just 0..n-1 and index-based IDs would be lost.
        if 'image_id' in df.columns:
            self.image_ids = df['image_id'].values
        else:
            self.image_ids = df.index.values
        self.df = df.reset_index(drop=True)
        self.targets = targets
        self.transform = transform
        self.img_size = img_size
        self.use_ndvi = use_ndvi
        self.log_transform = log_transform
        self.is_test = is_test

    def __len__(self):
        return len(self.df)

    def _resolve_image_path(self, img_id) -> Path:
        """Return the file for ``img_id``: .jpg, then .png, then any extension."""
        for suffix in ('.jpg', '.png'):
            candidate = self.image_dir / f"{img_id}{suffix}"
            if candidate.exists():
                return candidate
        # Sorted so the choice is deterministic when several extensions exist.
        candidates = sorted(self.image_dir.glob(f"{img_id}.*"))
        if candidates:
            return candidates[0]
        raise FileNotFoundError(f"Image not found: {img_id}")

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]

        # Load image; the context manager guarantees the file handle is closed.
        with Image.open(self._resolve_image_path(img_id)) as handle:
            pil_img = handle.convert('RGB')

        # Apply transforms
        if self.transform is None:
            img_tensor = transforms.ToTensor()(pil_img)
        elif HAS_ALBUMENTATIONS:
            # Albumentations pipelines take a numpy array via the `image` keyword.
            img_tensor = self.transform(image=np.array(pil_img))['image']
        else:
            img_tensor = self.transform(pil_img)

        result = {'image': img_tensor, 'image_id': str(img_id)}

        # Add NDVI if available
        if self.use_ndvi and 'NDVI' in self.df.columns:
            result['ndvi'] = torch.tensor(self.df['NDVI'].iloc[idx], dtype=torch.float32)

        # Add targets if training
        if self.targets is not None:
            target_values = self.targets.iloc[idx][TARGET_COLS].values.astype(np.float32)
            if self.log_transform:
                target_values = np.log1p(target_values)
            result['targets'] = torch.tensor(target_values, dtype=torch.float32)
        return result
# ============================================================
# Augmentations
# ============================================================
def get_train_transforms(img_size: int = 224, aug_strength: str = 'medium'):
    """Build the training-time augmentation pipeline.

    With albumentations installed, ``aug_strength`` selects a recipe:
    ``'light'``, ``'medium'``, or any other value, which is treated as the
    heavy recipe. Without albumentations a single torchvision recipe is
    returned and ``aug_strength`` is ignored.

    Every recipe ends with ImageNet normalisation and conversion to a tensor.
    """
    if not HAS_ALBUMENTATIONS:
        # torchvision fallback: one fixed recipe.
        return transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(0.5),
            transforms.RandomVerticalFlip(0.5),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

    out_size = (img_size, img_size)
    # Pieces shared by all three albumentations recipes.
    flips = [A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)]
    finish = [A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ToTensorV2()]
    hole_min = int(img_size * 0.05)  # smallest CoarseDropout hole edge

    if aug_strength == 'light':
        steps = [A.RandomResizedCrop(size=out_size, scale=(0.7, 1.0)), *flips]
    elif aug_strength == 'medium':
        hole_max = int(img_size * 0.15)
        steps = [
            A.RandomResizedCrop(size=out_size, scale=(0.5, 1.0)),
            *flips,
            A.RandomRotate90(p=0.5),
            A.Transpose(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
            A.HueSaturationValue(hue_shift_limit=15, sat_shift_limit=25, val_shift_limit=15, p=0.4),
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 5)),
                A.MotionBlur(blur_limit=5),
            ], p=0.15),
            A.CoarseDropout(
                num_holes_range=(1, 4),
                hole_height_range=(hole_min, hole_max),
                hole_width_range=(hole_min, hole_max),
                p=0.2,
            ),
        ]
    else:  # heavy
        hole_max = int(img_size * 0.2)
        steps = [
            A.RandomResizedCrop(size=out_size, scale=(0.4, 1.0)),
            *flips,
            A.RandomRotate90(p=0.5),
            A.Transpose(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
            A.RandomGamma(gamma_limit=(80, 120), p=0.3),
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 7)),
                A.MotionBlur(blur_limit=7),
            ], p=0.2),
            A.OneOf([
                A.GaussNoise(p=1.0),
                A.ISONoise(p=1.0),
            ], p=0.2),
            A.CoarseDropout(
                num_holes_range=(1, 8),
                hole_height_range=(hole_min, hole_max),
                hole_width_range=(hole_min, hole_max),
                p=0.3,
            ),
        ]
    return A.Compose(steps + finish)
def get_val_transforms(img_size: int = 224):
    """Build the deterministic validation pipeline.

    The image is resized to 1.14x ``img_size``, centre-cropped to
    ``img_size``, normalised with the ImageNet statistics and converted to a
    tensor. Uses albumentations when available, torchvision otherwise.
    """
    enlarged = int(img_size * 1.14)
    if not HAS_ALBUMENTATIONS:
        return transforms.Compose([
            transforms.Resize(enlarged),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])
    return A.Compose([
        A.Resize(height=enlarged, width=enlarged),
        A.CenterCrop(height=img_size, width=img_size),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ])
def get_tta_transforms(img_size: int = 224, n_augments: int = 8):
    """Return a list of test-time-augmentation pipelines.

    The first entry is always the plain validation pipeline. With
    albumentations installed it is followed by horizontal-flip, vertical-flip
    and both-flips views, then three centre crops taken at resize scales
    0.9, 1.0 and 1.2. The list is truncated to ``n_augments`` entries.

    NOTE(review): the scale-1.0 crop is identical to the first (plain) view,
    so that view is effectively counted twice when all views are averaged.
    NOTE(review): at most 7 views exist, so the default ``n_augments=8`` never
    truncates; without albumentations only the plain view is returned.
    """
    views = [get_val_transforms(img_size)]  # Original
    if HAS_ALBUMENTATIONS:
        def centered_view(resize_side, extra_ops=()):
            # Resize -> centre crop -> optional flips -> normalise -> tensor.
            return A.Compose([
                A.Resize(height=resize_side, width=resize_side),
                A.CenterCrop(height=img_size, width=img_size),
                *extra_ops,
                A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
                ToTensorV2(),
            ])

        # Flipped versions at the standard validation scale.
        standard_side = int(img_size * 1.14)
        flip_combos = [
            [A.HorizontalFlip(p=1.0)],
            [A.VerticalFlip(p=1.0)],
            [A.HorizontalFlip(p=1.0), A.VerticalFlip(p=1.0)],
        ]
        views.extend(centered_view(standard_side, combo) for combo in flip_combos)

        # Slightly different crops
        views.extend(
            centered_view(int(img_size * scale * 1.14))
            for scale in [0.9, 1.0, 1.2]
        )
    return views[:n_augments]
# ============================================================
# Model
# ============================================================
class BiomassModel(nn.Module):
    """
    Multi-output regression model for biomass prediction.

    Architecture:
    - timm backbone (DINOv2, ConvNeXt, etc.) created with ``num_classes=0`` so
      it returns pooled features
    - Optional auxiliary features (NDVI), embedded to 64 dims and concatenated
      to the backbone features
    - MLP regression head with LayerNorm + GELU + Dropout
    - Optional: separate heads per target for better specialization

    Args:
        backbone_name: timm model identifier.
        num_targets: Number of regression outputs.
        hidden_dim: Width of the first hidden layer of the head(s).
        dropout: Base dropout rate; deeper layers use scaled-down fractions.
        pretrained: Load pretrained backbone weights.
        img_size: Forwarded to timm only for names containing 'vit'/'dinov2'.
        use_ndvi: Expect an ``ndvi`` tensor in ``forward`` and fuse it.
        separate_heads: One small MLP per target instead of a shared MLP.
        grad_checkpointing: Ask the backbone to enable gradient checkpointing.
    """

    def __init__(
        self,
        backbone_name: str = 'vit_base_patch14_dinov2.lvd142m',
        num_targets: int = 5,
        hidden_dim: int = 512,
        dropout: float = 0.3,
        pretrained: bool = True,
        img_size: int = 224,
        use_ndvi: bool = False,
        separate_heads: bool = False,
        grad_checkpointing: bool = False,
    ):
        super().__init__()
        self.use_ndvi = use_ndvi
        self.separate_heads = separate_heads
        self.num_targets = num_targets

        # Create backbone
        kwargs = {'pretrained': pretrained, 'num_classes': 0}
        if 'vit' in backbone_name or 'dinov2' in backbone_name:
            kwargs['img_size'] = img_size
        self.backbone = timm.create_model(backbone_name, **kwargs)
        feat_dim = self.backbone.num_features

        # Enable gradient checkpointing for memory efficiency
        if grad_checkpointing:
            if hasattr(self.backbone, 'set_grad_checkpointing'):
                self.backbone.set_grad_checkpointing(True)
                logger.info("Gradient checkpointing enabled")
            else:
                # Do not fail, but make the ignored request visible.
                logger.warning(
                    "Gradient checkpointing requested but backbone %s does not support it",
                    backbone_name,
                )

        # NDVI embedding
        if use_ndvi:
            self.ndvi_embed = nn.Sequential(
                nn.Linear(1, 32),
                nn.GELU(),
                nn.Linear(32, 64),
            )
            feat_dim += 64

        # Regression head(s)
        if separate_heads:
            # Separate MLP head per target - better specialization
            self.heads = nn.ModuleList([
                nn.Sequential(
                    nn.LayerNorm(feat_dim),
                    nn.Dropout(dropout),
                    nn.Linear(feat_dim, hidden_dim),
                    nn.GELU(),
                    nn.Dropout(dropout * 0.5),
                    nn.Linear(hidden_dim, 1),
                )
                for _ in range(num_targets)
            ])
        else:
            # Shared head - better when data is limited
            self.head = nn.Sequential(
                nn.LayerNorm(feat_dim),
                nn.Dropout(dropout),
                nn.Linear(feat_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout * 0.3),
                nn.Linear(hidden_dim // 2, num_targets),
            )

    def forward(self, x, ndvi=None):
        """Return predictions of shape ``[batch, num_targets]``.

        Args:
            x: Image batch accepted by the backbone.
            ndvi: One NDVI value per sample (shape ``[batch]``). Required when
                the model was built with ``use_ndvi=True``; ignored otherwise.

        Raises:
            ValueError: If the model expects NDVI but none was supplied. The
                head is sized for backbone features + 64 NDVI dims, so skipping
                the NDVI branch could only end in an opaque shape mismatch.
        """
        features = self.backbone(x)
        if self.use_ndvi:
            if ndvi is None:
                raise ValueError(
                    "Model was built with use_ndvi=True but forward() received ndvi=None"
                )
            ndvi_feats = self.ndvi_embed(ndvi.unsqueeze(-1))
            features = torch.cat([features, ndvi_feats], dim=-1)
        if self.separate_heads:
            outputs = [head(features) for head in self.heads]
            return torch.cat(outputs, dim=-1)
        return self.head(features)

    def get_param_groups(self, backbone_lr: float = 5e-5, head_lr: float = 1e-3):
        """Get parameter groups with differential learning rates.

        Returns two optimizer groups: the backbone parameters at
        ``backbone_lr`` and every other parameter (heads, NDVI embedding) at
        ``head_lr``.
        """
        backbone_params = list(self.backbone.parameters())
        head_params = [p for n, p in self.named_parameters() if 'backbone' not in n]
        return [
            {'params': backbone_params, 'lr': backbone_lr},
            {'params': head_params, 'lr': head_lr},
        ]
# ============================================================
# Loss Functions
# ============================================================
class WeightedSmoothL1Loss(nn.Module):
    """SmoothL1 (Huber-style) loss with a fixed weight per target column.

    Args:
        target_weights: One weight per output column; defaults to TARGET_WEIGHTS.
        beta: Transition point between the quadratic and linear regimes.
    """

    def __init__(self, target_weights=None, beta=1.0):
        super().__init__()
        self.beta = beta
        chosen = TARGET_WEIGHTS if target_weights is None else target_weights
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('weights', torch.tensor(chosen, dtype=torch.float32))

    def forward(self, pred, target):
        """Return the mean over all elements of the column-weighted loss."""
        elementwise = F.smooth_l1_loss(pred, target, beta=self.beta, reduction='none')  # [B, T]
        return (elementwise * self.weights.unsqueeze(0)).mean()
class WeightedMSELoss(nn.Module):
    """Squared-error loss with a fixed weight per target column.

    Args:
        target_weights: One weight per output column; defaults to TARGET_WEIGHTS.
    """

    def __init__(self, target_weights=None):
        super().__init__()
        chosen = TARGET_WEIGHTS if target_weights is None else target_weights
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer('weights', torch.tensor(chosen, dtype=torch.float32))

    def forward(self, pred, target):
        """Return the mean over all elements of the column-weighted squared error."""
        residual = pred - target
        squared = residual ** 2  # [B, T]
        return (squared * self.weights.unsqueeze(0)).mean()
class ConsistencyLoss(nn.Module):
    """
    Enforce structural constraint: Dry_Total_g ≈ Dry_Green_g + Dry_Dead_g + Dry_Clover_g

    The penalty is ``weight * MSE(component_sum, total)`` computed on the
    predictions only (no ground truth involved).

    Args:
        weight: Multiplier applied to the penalty.
        log_space: Set to True when the model predicts ``log1p`` of the
            targets. The additive identity only holds in grams, so in that
            mode each component is mapped back with ``expm1``, summed, and the
            sum is mapped forward with ``log1p`` before it is compared with
            the predicted (log-space) total. With the default ``False`` the
            raw predictions are summed directly.
    """

    def __init__(self, weight=0.1, log_space=False):
        super().__init__()
        self.weight = weight
        self.log_space = log_space

    def forward(self, pred):
        # pred columns: [Green, Dead, Clover, GDM, Total]
        total = pred[:, 4]
        if self.log_space:
            # float32 keeps expm1 away from the fp16 range limit under autocast;
            # clamping keeps log1p defined if a component is predicted below zero.
            grams = torch.expm1(pred[:, :3].float()).clamp_min(0.0)
            component_sum = torch.log1p(grams.sum(dim=1))
            total = total.float()
        else:
            component_sum = pred[:, 0] + pred[:, 1] + pred[:, 2]
        return self.weight * F.mse_loss(component_sum, total)


class CombinedLoss(nn.Module):
    """Combined loss: weighted SmoothL1 + optional weighted MSE + optional consistency term.

    Args:
        smoothl1_weight: Multiplier of the SmoothL1 term (always present).
        mse_weight: Multiplier of the MSE term; the term is skipped when <= 0.
        consistency_weight: Weight of the consistency regularizer; skipped when <= 0.
        target_weights: Per-target weights forwarded to the regression terms.
        log_space: Forwarded to ``ConsistencyLoss``; set to True when the
            predictions are log1p-transformed so the additive constraint is
            evaluated in grams rather than on log values.
    """

    def __init__(self, smoothl1_weight=1.0, mse_weight=0.0, consistency_weight=0.1,
                 target_weights=None, log_space=False):
        super().__init__()
        self.smoothl1 = WeightedSmoothL1Loss(target_weights)
        self.mse = WeightedMSELoss(target_weights) if mse_weight > 0 else None
        self.consistency = (
            ConsistencyLoss(consistency_weight, log_space=log_space)
            if consistency_weight > 0 else None
        )
        self.smoothl1_weight = smoothl1_weight
        self.mse_weight = mse_weight

    def forward(self, pred, target):
        """Return the scalar sum of all enabled terms."""
        loss = self.smoothl1_weight * self.smoothl1(pred, target)
        if self.mse is not None:
            loss = loss + self.mse_weight * self.mse(pred, target)
        if self.consistency is not None:
            loss = loss + self.consistency(pred)
        return loss
# ============================================================
# Label Distribution Smoothing (LDS)
# ============================================================
def get_lds_weights(labels: np.ndarray, bins: int = 100, kernel_size: int = 5, sigma: float = 2.0):
    """
    Compute Label Distribution Smoothing (LDS) sample weights.
    From "Delving into Deep Imbalanced Regression" (ICML 2021).

    The label histogram is smoothed with a Gaussian-shaped window; each sample
    is weighted by the inverse of the smoothed density at its label, and the
    weights are rescaled to have mean 1.

    Args:
        labels: 1-D label array, or 2-D array whose LAST column is used.
        bins: Number of histogram bins.
        kernel_size: Number of taps in the smoothing window.
        sigma: Width parameter of the smoothing window.

    Returns:
        One weight per sample.
    """
    from scipy.ndimage import convolve1d

    # Use the most important target (Dry_Total_g, the last column) for weighting
    values = labels[:, -1] if labels.ndim > 1 else labels

    counts, edges = np.histogram(values, bins=bins)

    # Gaussian-shaped window sampled on [-3, 3], normalised to sum to 1.
    taps = np.linspace(-3, 3, kernel_size)
    window = np.exp(-(taps ** 2) / (2 * sigma ** 2))
    window = window / window.sum()

    density = convolve1d(counts.astype(float), window, mode='reflect')
    centers = (edges[:-1] + edges[1:]) / 2

    # The epsilon guards against empty regions of the histogram.
    inverse_density = 1.0 / (np.interp(values, centers, density) + 1e-8)
    return inverse_density / inverse_density.mean()  # Normalize to mean=1
# ============================================================
# Metrics
# ============================================================
def compute_weighted_r2(preds: np.ndarray, targets: np.ndarray,
                        target_weights: Optional[List[float]] = None) -> float:
    """
    Compute the globally weighted R² (competition metric).

    All (sample, target) cells are treated as one long vector in which each
    cell carries the weight of its target column; a single weighted mean and a
    single SS_res / SS_tot pair are computed over that vector.

    Args:
        preds: [N, T] predictions
        targets: [N, T] ground truth
        target_weights: per-target weights (default: competition weights);
            only the first T entries are used

    Returns:
        Weighted R² score

    Raises:
        ValueError: If ``preds`` and ``targets`` differ in shape or fewer than
            T weights are supplied.
    """
    if target_weights is None:
        target_weights = TARGET_WEIGHTS

    # float64 throughout so float16/float32 model outputs do not lose precision
    # in the squared sums.
    preds = np.asarray(preds, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    if preds.shape != targets.shape:
        raise ValueError(
            f"preds and targets must have the same shape, got {preds.shape} and {targets.shape}"
        )
    n_samples, n_targets = preds.shape

    weights = np.asarray(target_weights, dtype=np.float64)
    if weights.shape[0] < n_targets:
        raise ValueError(f"Expected at least {n_targets} target weights, got {weights.shape[0]}")
    weights = weights[:n_targets][np.newaxis, :]  # [1, T], broadcasts over rows

    # Weighted mean over every (sample, target) cell
    weighted_mean = np.sum(weights * targets) / (weights.sum() * n_samples)

    # SS_res and SS_tot
    ss_res = np.sum(weights * (targets - preds) ** 2)
    ss_tot = np.sum(weights * (targets - weighted_mean) ** 2)
    return float(1.0 - ss_res / (ss_tot + 1e-8))
def compute_per_target_r2(preds: np.ndarray, targets: np.ndarray) -> Dict[str, float]:
    """Return the plain (unweighted) R² of each target column.

    Column ``i`` of both arrays is taken to be ``TARGET_COLS[i]``; the result
    maps each target name to its score.
    """
    scores = {}
    for col_idx, col_name in enumerate(TARGET_COLS):
        truth = targets[:, col_idx]
        residual_ss = np.sum((truth - preds[:, col_idx]) ** 2)
        total_ss = np.sum((truth - truth.mean()) ** 2)
        # The epsilon keeps the ratio finite for a constant target column.
        scores[col_name] = 1.0 - residual_ss / (total_ss + 1e-8)
    return scores
# ============================================================
# Training Engine
# ============================================================
class Trainer:
    """Training engine with mixed precision, gradient accumulation, and k-fold.

    ``args`` is the parsed command-line namespace. Mixed precision (fp16
    autocast + ``GradScaler``) is used only when CUDA is available; on CPU
    everything runs in full precision.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.scaler = GradScaler() if self.device.type == 'cuda' else None
        logger.info(f"Device: {self.device}")
        if self.device.type == 'cuda':
            logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
            # The device-properties attribute is `total_memory` (bytes).
            total_bytes = torch.cuda.get_device_properties(0).total_memory
            logger.info(f"GPU Memory: {total_bytes / 1e9:.1f} GB")

    def _amp_context(self):
        """Return fp16 autocast when a GradScaler is active, else a no-op context."""
        from contextlib import nullcontext
        if self.scaler is not None:
            return autocast(dtype=torch.float16)
        return nullcontext()

    def _optimizer_step(self, model, optimizer):
        """Clip gradients, apply one optimizer update and clear the gradients."""
        if self.scaler is not None:
            # Gradients must be unscaled before clipping so the norm is meaningful.
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
            optimizer.step()
        optimizer.zero_grad()

    def build_model(self):
        """Build model from args."""
        backbone_cfg = BACKBONE_CONFIGS[self.args.backbone]
        img_size = self.args.img_size or backbone_cfg['default_size']
        model = BiomassModel(
            backbone_name=backbone_cfg['name'],
            num_targets=5,
            hidden_dim=self.args.hidden_dim,
            dropout=self.args.dropout,
            pretrained=True,
            img_size=img_size,
            use_ndvi=self.args.use_ndvi,
            separate_heads=self.args.separate_heads,
            grad_checkpointing=self.args.grad_checkpointing,
        )
        return model.to(self.device)

    def build_optimizer(self, model):
        """Build optimizer with differential learning rates.

        Raises:
            ValueError: For an optimizer name other than 'adamw' or 'sgd'.
        """
        param_groups = model.get_param_groups(
            backbone_lr=self.args.backbone_lr,
            head_lr=self.args.head_lr,
        )
        if self.args.optimizer == 'adamw':
            optimizer = torch.optim.AdamW(param_groups, weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(param_groups, momentum=0.9, weight_decay=self.args.weight_decay)
        else:
            raise ValueError(f"Unknown optimizer: {self.args.optimizer}")
        return optimizer

    def build_scheduler(self, optimizer, num_training_steps):
        """Build learning rate scheduler.

        'cosine'  - linear warmup followed by cosine annealing; stepped once
                    per optimizer update.
        'plateau' - ReduceLROnPlateau on the validation R²; stepped once per epoch.
        other     - no scheduler (returns None).

        Args:
            optimizer: Optimizer whose learning rates are scheduled.
            num_training_steps: Total number of optimizer updates planned.
        """
        if self.args.scheduler == 'cosine':
            from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
            num_training_steps = max(1, int(num_training_steps))
            # Keep at least one cosine step so T_max can never reach zero.
            warmup_steps = int(num_training_steps * self.args.warmup_ratio)
            warmup_steps = max(0, min(warmup_steps, num_training_steps - 1))
            if warmup_steps == 0:
                # A zero-length LinearLR inside SequentialLR would leave the LR
                # stuck at start_factor * base_lr, so skip warmup entirely.
                return CosineAnnealingLR(optimizer, T_max=num_training_steps,
                                         eta_min=self.args.min_lr)
            warmup = LinearLR(optimizer, start_factor=0.01, total_iters=warmup_steps)
            cosine = CosineAnnealingLR(optimizer, T_max=num_training_steps - warmup_steps,
                                       eta_min=self.args.min_lr)
            return SequentialLR(optimizer, [warmup, cosine], milestones=[warmup_steps])
        if self.args.scheduler == 'plateau':
            # `verbose` is deprecated in recent PyTorch; LR changes are logged
            # explicitly in train_fold instead.
            return torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='max', factor=0.5, patience=5)
        return None

    def train_one_epoch(self, model, loader, optimizer, scheduler, loss_fn, epoch):
        """Train for one epoch and return the sample-averaged loss.

        Gradients are accumulated over ``args.grad_accum_steps`` batches. The
        last window of the epoch is stepped even when it is shorter, so no
        gradient is discarded or carried over into the next epoch. Per-step
        schedulers (everything except ReduceLROnPlateau) advance once per
        optimizer update.
        """
        model.train()
        accum_steps = max(1, self.args.grad_accum_steps)
        steps_scheduler = (
            scheduler is not None
            and not isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
        )
        num_batches = len(loader)
        running_loss = 0.0
        num_samples = 0
        optimizer.zero_grad()  # never start an epoch with stale gradients

        for batch_idx, batch in enumerate(loader):
            images = batch['image'].to(self.device)
            targets = batch['targets'].to(self.device)
            ndvi = batch.get('ndvi', None)
            if ndvi is not None:
                ndvi = ndvi.to(self.device)

            # Forward pass (mixed precision when available)
            with self._amp_context():
                preds = model(images, ndvi)
                loss = loss_fn(preds, targets)

            # Gradient accumulation: each batch contributes 1/accum_steps.
            scaled_loss = loss / accum_steps
            if self.scaler is not None:
                self.scaler.scale(scaled_loss).backward()
            else:
                scaled_loss.backward()

            window_complete = (batch_idx + 1) % accum_steps == 0
            last_batch = (batch_idx + 1) == num_batches
            if window_complete or last_batch:
                self._optimizer_step(model, optimizer)
                if steps_scheduler:
                    scheduler.step()

            running_loss += loss.item() * images.size(0)
            num_samples += images.size(0)
            if (batch_idx + 1) % self.args.log_interval == 0:
                avg_loss = running_loss / num_samples
                lr = optimizer.param_groups[0]['lr']
                logger.info(f"Epoch {epoch} [{batch_idx+1}/{len(loader)}] loss={avg_loss:.4f} lr={lr:.2e}")
        return running_loss / max(num_samples, 1)

    @torch.no_grad()
    def validate(self, model, loader, loss_fn, log_transform=True):
        """Validate and compute metrics.

        Returns a dict with the average loss, the weighted R², the per-target
        R², and the predictions/targets in the original (non-log) scale.
        Predictions are clipped at zero before the metrics are computed.
        """
        model.eval()
        all_preds = []
        all_targets = []
        running_loss = 0.0
        num_samples = 0
        for batch in loader:
            images = batch['image'].to(self.device)
            targets = batch['targets'].to(self.device)
            ndvi = batch.get('ndvi', None)
            if ndvi is not None:
                ndvi = ndvi.to(self.device)
            with self._amp_context():
                preds = model(images, ndvi)
                loss = loss_fn(preds, targets)
            running_loss += loss.item() * images.size(0)
            num_samples += images.size(0)
            # Autocast yields fp16 outputs; upcast before NumPy so expm1 and the
            # metrics are not evaluated in half precision.
            all_preds.append(preds.float().cpu().numpy())
            all_targets.append(targets.float().cpu().numpy())
        all_preds = np.concatenate(all_preds, axis=0)
        all_targets = np.concatenate(all_targets, axis=0)

        # Inverse log transform for metric computation
        if log_transform:
            all_preds_orig = np.expm1(all_preds)
            all_targets_orig = np.expm1(all_targets)
        else:
            all_preds_orig = all_preds
            all_targets_orig = all_targets

        # Clip negative predictions
        all_preds_orig = np.clip(all_preds_orig, 0, None)

        # Compute metrics
        weighted_r2 = compute_weighted_r2(all_preds_orig, all_targets_orig)
        per_target_r2 = compute_per_target_r2(all_preds_orig, all_targets_orig)
        avg_loss = running_loss / max(num_samples, 1)
        return {
            'loss': avg_loss,
            'weighted_r2': weighted_r2,
            'per_target_r2': per_target_r2,
            'preds': all_preds_orig,
            'targets': all_targets_orig,
        }

    @torch.no_grad()
    def predict(self, model, loader, log_transform=True, tta_transforms=None):
        """Generate predictions (inference).

        Returns ``(preds, image_ids)`` with predictions in the original scale,
        clipped at zero. ``tta_transforms`` is accepted for interface
        compatibility but is not used by this method.
        """
        model.eval()
        all_preds = []
        all_ids = []
        for batch in loader:
            images = batch['image'].to(self.device)
            ndvi = batch.get('ndvi', None)
            if ndvi is not None:
                ndvi = ndvi.to(self.device)
            with self._amp_context():
                preds = model(images, ndvi)
            # Upcast fp16 autocast outputs before the inverse transform.
            all_preds.append(preds.float().cpu().numpy())
            all_ids.extend(batch['image_id'])
        all_preds = np.concatenate(all_preds, axis=0)
        if log_transform:
            all_preds = np.expm1(all_preds)
        all_preds = np.clip(all_preds, 0, None)
        return all_preds, all_ids

    def train_fold(self, fold: int, train_df: pd.DataFrame, val_df: pd.DataFrame,
                   train_targets: pd.DataFrame, val_targets: pd.DataFrame,
                   image_dir: str):
        """Train a single fold.

        Trains with early stopping on the validation weighted R², saves the
        best checkpoint to ``<output_dir>/fold_<fold>/best_model.pth``, reloads
        it and returns ``(model, val_metrics)`` for that best state.

        Raises:
            ValueError: If the training loader is empty (the batch size exceeds
                the fold size while ``drop_last`` is on).
        """
        backbone_cfg = BACKBONE_CONFIGS[self.args.backbone]
        img_size = self.args.img_size or backbone_cfg['default_size']

        # Datasets
        train_dataset = BiomassDataset(
            image_dir=image_dir,
            df=train_df,
            targets=train_targets,
            transform=get_train_transforms(img_size, self.args.aug_strength),
            img_size=img_size,
            use_ndvi=self.args.use_ndvi,
            log_transform=self.args.log_transform,
        )
        val_dataset = BiomassDataset(
            image_dir=image_dir,
            df=val_df,
            targets=val_targets,
            transform=get_val_transforms(img_size),
            img_size=img_size,
            use_ndvi=self.args.use_ndvi,
            log_transform=self.args.log_transform,
        )

        # Optional: LDS sample weights
        if self.args.use_lds:
            sample_weights = get_lds_weights(
                train_targets[TARGET_COLS].values,
                bins=self.args.lds_bins,
                kernel_size=self.args.lds_kernel_size,
                sigma=self.args.lds_sigma,
            )
            sampler = WeightedRandomSampler(
                weights=sample_weights,
                num_samples=len(train_dataset),
                replacement=True,
            )
            train_loader = DataLoader(
                train_dataset, batch_size=self.args.batch_size,
                sampler=sampler, num_workers=self.args.num_workers,
                pin_memory=True, drop_last=True,
            )
        else:
            train_loader = DataLoader(
                train_dataset, batch_size=self.args.batch_size,
                shuffle=True, num_workers=self.args.num_workers,
                pin_memory=True, drop_last=True,
            )
        val_loader = DataLoader(
            val_dataset, batch_size=self.args.batch_size * 2,
            shuffle=False, num_workers=self.args.num_workers,
            pin_memory=True,
        )
        if len(train_loader) == 0:
            raise ValueError(
                f"Fold {fold}: training loader is empty "
                f"({len(train_dataset)} samples, batch_size={self.args.batch_size}, drop_last=True)"
            )

        # Model, optimizer, scheduler
        model = self.build_model()
        optimizer = self.build_optimizer(model)
        # train_one_epoch also steps on a trailing partial window, so the
        # number of optimizer updates per epoch is a ceiling division.
        accum_steps = max(1, self.args.grad_accum_steps)
        steps_per_epoch = (len(train_loader) + accum_steps - 1) // accum_steps
        num_training_steps = steps_per_epoch * self.args.epochs
        scheduler = self.build_scheduler(optimizer, num_training_steps)

        # Loss
        # NOTE(review): when log_transform is on, the consistency term sums
        # log1p-space predictions, where Total = Green + Dead + Clover does not
        # hold — consider evaluating the constraint in the original scale.
        loss_fn = CombinedLoss(
            smoothl1_weight=1.0,
            mse_weight=self.args.mse_weight,
            consistency_weight=self.args.consistency_weight,
            target_weights=TARGET_WEIGHTS,
        ).to(self.device)

        # Training loop
        best_r2 = -float('inf')
        best_epoch = 0
        patience_counter = 0
        save_dir = Path(self.args.output_dir) / f"fold_{fold}"
        save_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_path = save_dir / 'best_model.pth'
        logger.info(f"\n{'='*60}")
        logger.info(f"FOLD {fold}")
        logger.info(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}")
        logger.info(f"Backbone: {backbone_cfg['name']}, img_size: {img_size}")
        logger.info(f"{'='*60}")

        for epoch in range(1, self.args.epochs + 1):
            t0 = time.time()
            # Train
            train_loss = self.train_one_epoch(model, train_loader, optimizer, scheduler, loss_fn, epoch)
            # Validate
            val_metrics = self.validate(model, val_loader, loss_fn, self.args.log_transform)
            # LR scheduler step (for ReduceLROnPlateau)
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                lrs_before = [group['lr'] for group in optimizer.param_groups]
                scheduler.step(val_metrics['weighted_r2'])
                lrs_after = [group['lr'] for group in optimizer.param_groups]
                if lrs_after != lrs_before:
                    logger.info(f"ReduceLROnPlateau changed LR: {lrs_before} -> {lrs_after}")
            elapsed = time.time() - t0

            # Logging
            logger.info(
                f"Epoch {epoch}/{self.args.epochs} | "
                f"train_loss={train_loss:.4f} | "
                f"val_loss={val_metrics['loss']:.4f} | "
                f"val_R²={val_metrics['weighted_r2']:.4f} | "
                f"time={elapsed:.1f}s"
            )
            for name, r2 in val_metrics['per_target_r2'].items():
                logger.info(f" {name}: R²={r2:.4f}")

            # Save best model
            if val_metrics['weighted_r2'] > best_r2:
                best_r2 = val_metrics['weighted_r2']
                best_epoch = epoch
                patience_counter = 0
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'weighted_r2': best_r2,
                    'per_target_r2': val_metrics['per_target_r2'],
                    'args': vars(self.args),
                }, checkpoint_path)
                logger.info(f" *** New best R²={best_r2:.4f} (epoch {epoch}) ***")
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= self.args.patience:
                logger.info(f"Early stopping at epoch {epoch}. Best R²={best_r2:.4f} (epoch {best_epoch})")
                break

        # Load best model for final predictions. Only trust a checkpoint written
        # by THIS run: if no epoch improved (e.g. the metric was NaN throughout)
        # the file is either missing or left over from an earlier run.
        if best_epoch > 0 and checkpoint_path.exists():
            # weights_only=False unpickles arbitrary objects; acceptable here
            # because the file was produced by this very process.
            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            logger.warning(
                f"Fold {fold}: no epoch produced a valid best score; "
                f"keeping the weights from the last epoch"
            )

        # OOF predictions
        val_metrics = self.validate(model, val_loader, loss_fn, self.args.log_transform)
        logger.info(f"\nFold {fold} Final: R²={val_metrics['weighted_r2']:.4f}")
        return model, val_metrics

    def train_kfold(self, df: pd.DataFrame, targets: pd.DataFrame, image_dir: str):
        """Train with K-Fold cross-validation.

        Folds are stratified on quantile bins of ``Dry_Total_g``. Out-of-fold
        predictions are written to ``<output_dir>/oof_predictions.csv``.

        Returns:
            ``(overall_oof_r2, per_fold_r2_list)``.
        """
        n_folds = self.args.n_folds

        # Stratified bins based on Dry_Total_g
        bins = pd.qcut(targets['Dry_Total_g'], q=min(10, n_folds), labels=False, duplicates='drop')
        kf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=self.args.seed)
        oof_preds = np.zeros((len(df), 5))
        fold_scores = []

        for fold, (train_idx, val_idx) in enumerate(kf.split(df, bins)):
            train_df = df.iloc[train_idx]
            val_df = df.iloc[val_idx]
            train_targets = targets.iloc[train_idx]
            val_targets = targets.iloc[val_idx]
            model, val_metrics = self.train_fold(
                fold, train_df, val_df, train_targets, val_targets, image_dir
            )
            oof_preds[val_idx] = val_metrics['preds']
            fold_scores.append(val_metrics['weighted_r2'])
            logger.info(f"Fold {fold} R²: {val_metrics['weighted_r2']:.4f}")
            # The fold's weights are already on disk; release the model so two
            # backbones never sit on the GPU at once.
            del model
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()

        # Overall OOF score
        targets_arr = targets[TARGET_COLS].values
        overall_r2 = compute_weighted_r2(oof_preds, targets_arr)
        logger.info(f"\n{'='*60}")
        logger.info(f"Overall OOF R²: {overall_r2:.4f}")
        logger.info(f"Per-fold R²: {[f'{s:.4f}' for s in fold_scores]}")
        logger.info(f"Mean fold R²: {np.mean(fold_scores):.4f} ± {np.std(fold_scores):.4f}")
        logger.info(f"{'='*60}")

        # Save OOF predictions
        oof_df = df[['image_id']].copy()
        for i, col in enumerate(TARGET_COLS):
            oof_df[col] = oof_preds[:, i]
        output_dir = Path(self.args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        oof_df.to_csv(output_dir / 'oof_predictions.csv', index=False)
        return overall_r2, fold_scores
# ============================================================
# Data Loading Utilities
# ============================================================
def load_competition_data(data_dir: str):
    """
    Load the competition CSVs and locate the image folders.

    Expected structure:
        data_dir/
            train.csv
            test.csv
            train_images/   (falls back to train/ when missing)
            test_images/    (falls back to test/ when missing)
            sample_submission.csv   (optional)

    Returns:
        ``(train_df, test_df, sample_sub, train_img_dir, test_img_dir)`` where
        ``sample_sub`` is ``None`` if the file is absent and both directories
        are returned as strings.
    """
    root = Path(data_dir)

    # Load CSVs
    train_df = pd.read_csv(root / 'train.csv')
    test_df = pd.read_csv(root / 'test.csv')
    submission_path = root / 'sample_submission.csv'
    sample_sub = pd.read_csv(submission_path) if submission_path.exists() else None

    # Determine image directories, preferring the *_images naming.
    train_img_dir = root / 'train_images'
    if not train_img_dir.exists():
        train_img_dir = root / 'train'
    test_img_dir = root / 'test_images'
    if not test_img_dir.exists():
        test_img_dir = root / 'test'

    logger.info(f"Train samples: {len(train_df)}")
    logger.info(f"Test samples: {len(test_df)}")
    logger.info(f"Train columns: {list(train_df.columns)}")
    logger.info(f"Test columns: {list(test_df.columns)}")

    # Check for target columns
    available_targets = [name for name in TARGET_COLS if name in train_df.columns]
    logger.info(f"Available targets: {available_targets}")

    # Print target statistics
    if available_targets:
        logger.info("\nTarget statistics:")
    for col in available_targets:
        series = train_df[col]
        logger.info(f" {col}: mean={series.mean():.2f}, "
                    f"median={series.median():.2f}, "
                    f"std={series.std():.2f}, "
                    f"min={series.min():.2f}, "
                    f"max={series.max():.2f}")
    return train_df, test_df, sample_sub, str(train_img_dir), str(test_img_dir)
def create_submission(preds: np.ndarray, image_ids: List[str], output_path: str):
    """
    Write the submission CSV in long format and return it as a DataFrame.

    One row is produced per (image, target) pair, with
    ``sample_id = "<image_id>__<target_name>"`` and a non-negative ``target``.

    Args:
        preds: [N, 5] predictions, columns ordered as TARGET_COLS
        image_ids: list of image IDs, row-aligned with ``preds``
        output_path: path to save CSV
    """
    records = [
        {
            'sample_id': f"{img_id}__{target_name}",
            'target': max(0, preds[row_idx, col_idx]),  # Ensure non-negative
        }
        for row_idx, img_id in enumerate(image_ids)
        for col_idx, target_name in enumerate(TARGET_COLS)
    ]
    sub_df = pd.DataFrame(records)
    sub_df.to_csv(output_path, index=False)
    logger.info(f"Submission saved to {output_path} ({len(sub_df)} rows)")
    return sub_df
# ============================================================
# Seed and Reproducibility
# ============================================================
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    cuDNN autotuning (``benchmark``) is switched off because it can select
    different kernels from run to run.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
# ============================================================
# Main
# ============================================================
def get_args(argv: Optional[List[str]] = None):
    """Build the command-line parser and return the parsed arguments.

    Args:
        argv: argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour when the script is
            run from the command line.

    Returns:
        argparse.Namespace with one attribute per option. ``log_transform`` is
        True unless ``--no_log_transform`` was given, and ``mixed_precision``
        is True unless ``--no_mixed_precision`` was given.

    Raises:
        SystemExit: raised by argparse on invalid arguments, including a
            ``--fold`` outside ``[0, n_folds - 1]``.
    """
    parser = argparse.ArgumentParser(description='CSIRO Image2Biomass Training')

    # Data
    parser.add_argument('--data_dir', type=str, required=True, help='Competition data directory')
    parser.add_argument('--output_dir', type=str, default='./output', help='Output directory')

    # Model
    parser.add_argument('--backbone', type=str, default='dinov2_base',
                        choices=list(BACKBONE_CONFIGS), help='Backbone architecture')
    parser.add_argument('--img_size', type=int, default=None, help='Image size (default: backbone native)')
    parser.add_argument('--hidden_dim', type=int, default=512, help='Hidden dim in MLP head')
    parser.add_argument('--dropout', type=float, default=0.3, help='Dropout rate')
    parser.add_argument('--separate_heads', action='store_true', help='Use separate heads per target')
    parser.add_argument('--grad_checkpointing', action='store_true', help='Enable gradient checkpointing')
    parser.add_argument('--use_ndvi', action='store_true', help='Use NDVI features')

    # Training
    parser.add_argument('--epochs', type=int, default=50, help='Max epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--backbone_lr', type=float, default=5e-5, help='Backbone learning rate')
    parser.add_argument('--head_lr', type=float, default=1e-3, help='Head learning rate')
    parser.add_argument('--min_lr', type=float, default=1e-7, help='Min learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-2, help='Weight decay')
    parser.add_argument('--optimizer', type=str, default='adamw', choices=['adamw', 'sgd'])
    parser.add_argument('--scheduler', type=str, default='cosine', choices=['cosine', 'plateau', 'none'])
    parser.add_argument('--warmup_ratio', type=float, default=0.05, help='Warmup ratio')
    parser.add_argument('--max_grad_norm', type=float, default=1.0, help='Max gradient norm')
    parser.add_argument('--grad_accum_steps', type=int, default=1, help='Gradient accumulation steps')
    parser.add_argument('--patience', type=int, default=10, help='Early stopping patience')
    parser.add_argument('--log_interval', type=int, default=10, help='Log every N batches')

    # Augmentation
    parser.add_argument('--aug_strength', type=str, default='medium', choices=['light', 'medium', 'heavy'])
    # --log_transform defaults to True, so passing it is a no-op; it is kept
    # for compatibility. Use --no_log_transform to turn the transform off.
    parser.add_argument('--log_transform', action='store_true', default=True, help='Log-transform targets')
    parser.add_argument('--no_log_transform', action='store_true', help='Disable log-transform')

    # LDS
    parser.add_argument('--use_lds', action='store_true', help='Use Label Distribution Smoothing')
    parser.add_argument('--lds_bins', type=int, default=100)
    parser.add_argument('--lds_kernel_size', type=int, default=5)
    parser.add_argument('--lds_sigma', type=float, default=2.0)

    # Loss
    parser.add_argument('--mse_weight', type=float, default=0.0, help='MSE loss weight')
    parser.add_argument('--consistency_weight', type=float, default=0.1, help='Consistency loss weight')

    # CV
    parser.add_argument('--n_folds', type=int, default=5, help='Number of CV folds')
    parser.add_argument('--fold', type=int, default=None, help='Train single fold (None=all)')

    # Misc
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_workers', type=int, default=4)
    # Same situation as --log_transform: the positive flag is always True, so
    # an explicit negative switch is the only way to disable mixed precision.
    parser.add_argument('--mixed_precision', action='store_true', default=True)
    parser.add_argument('--no_mixed_precision', action='store_false', dest='mixed_precision',
                        help='Disable mixed precision')

    args = parser.parse_args(argv)

    if args.no_log_transform:
        args.log_transform = False

    # A fold index that matches no split would make single-fold mode train
    # nothing, so reject it while we can still print a usage error.
    if args.fold is not None and not 0 <= args.fold < args.n_folds:
        parser.error(f"--fold must be in [0, {args.n_folds - 1}], got {args.fold}")

    return args
def main():
    """Entry point: parse arguments, load data and run training.

    With ``--fold`` set, only that fold of a stratified K-fold split is
    trained; the split is stratified on quantile bins of ``Dry_Total_g``.
    Otherwise the full K-fold procedure is delegated to
    ``Trainer.train_kfold``. The parsed arguments are saved as ``args.json``
    in the output directory before training starts.

    Raises:
        ValueError: if ``--fold`` is outside ``[0, n_folds - 1]``.
    """
    args = get_args()

    # Fail fast: an out-of-range fold matches no split, which previously meant
    # "Training complete!" was logged without any training having happened.
    if args.fold is not None and not 0 <= args.fold < args.n_folds:
        raise ValueError(
            f"--fold must be in [0, {args.n_folds - 1}], got {args.fold}"
        )

    set_seed(args.seed)

    # Load data (only the training table and image directory are needed here)
    train_df, _test_df, _sample_sub, train_img_dir, _test_img_dir = load_competition_data(args.data_dir)

    # Separate features and targets
    targets = train_df[TARGET_COLS].copy()

    # Create output directory and save args
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / 'args.json', 'w', encoding='utf-8') as f:
        json.dump(vars(args), f, indent=2)

    # Train
    trainer = Trainer(args)
    if args.fold is not None:
        # Single fold training
        bins = pd.qcut(targets['Dry_Total_g'], q=min(10, args.n_folds), labels=False, duplicates='drop')
        kf = StratifiedKFold(n_splits=args.n_folds, shuffle=True, random_state=args.seed)
        # Materialise the splits so the requested fold can be picked directly.
        train_idx, val_idx = list(kf.split(train_df, bins))[args.fold]
        trainer.train_fold(
            args.fold,
            train_df.iloc[train_idx], train_df.iloc[val_idx],
            targets.iloc[train_idx], targets.iloc[val_idx],
            train_img_dir,
        )
    else:
        # Full K-fold training
        trainer.train_kfold(train_df, targets, train_img_dir)

    logger.info("Training complete!")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()