# csiro-image2biomass / train_ensemble.py
# Uploaded by notRaphael
# Commit message: Upload train_ensemble.py with huggingface_hub
# Commit f500b8c (verified)
#!/usr/bin/env python3
"""
CSIRO Image2Biomass - Multi-Backbone Ensemble Training
=======================================================

Trains multiple backbone architectures with K-fold CV, then searches for the
ensemble weights that maximize the weighted R² of the combined out-of-fold
(OOF) predictions and saves them to ``ensemble_config.json``.

Recommended ensemble:
    1. DINOv2-Base (ViT, self-supervised, best generalization)
    2. ConvNeXt-Large (CNN, strong spatial features)
    3. EfficientNet-B4 (lightweight, fast, diverse predictions)

Usage:
    python train_ensemble.py --data_dir /path/to/data --backbones dinov2_base convnext_large
"""
import os
import json
import argparse
from pathlib import Path
import numpy as np
import pandas as pd
# Import from main training script
from train import (
Trainer, set_seed, load_competition_data, create_submission,
compute_weighted_r2, TARGET_COLS, BACKBONE_CONFIGS,
BiomassDataset, get_val_transforms
)
def train_backbone(args, backbone_name: str):
    """Train a single backbone with K-fold CV and save a summary of its scores.

    Args:
        args: Parsed command-line namespace. It is mutated in place:
            ``backbone`` is always overwritten, ``img_size`` is overwritten
            when ``auto_img_size`` is set and ``batch_size`` when
            ``auto_batch_size`` is set. ``output_dir`` points at a
            per-backbone sub-directory for the duration of the call and is
            restored before returning, even if training raises.
        backbone_name: Key into ``BACKBONE_CONFIGS``.

    Returns:
        dict: JSON-serializable summary (backbone name, model name, image
        size, overall R², per-fold scores and their mean/std). The same dict
        is written to ``<output_dir>/<backbone_name>/backbone_info.json``.

    Raises:
        KeyError: If ``backbone_name`` is not in ``BACKBONE_CONFIGS``.
    """
    print(f"\n{'#'*70}")
    print(f"# Training backbone: {backbone_name}")
    print(f"{'#'*70}")

    # Update args for this backbone
    args.backbone = backbone_name
    cfg = BACKBONE_CONFIGS[backbone_name]

    # Auto-adjust image size
    if args.auto_img_size:
        args.img_size = cfg['default_size']

    # Auto-adjust batch size based on model size: wider feature dimensions
    # get smaller batches.
    if args.auto_batch_size:
        feat_dim = cfg['feat_dim']
        if feat_dim <= 384:
            args.batch_size = 64
        elif feat_dim <= 768:
            args.batch_size = 32
        else:
            args.batch_size = 16

    # Output directory per backbone. The caller's value is restored in the
    # ``finally`` clause so a failed run cannot leave ``args`` pointing at a
    # nested directory for the next backbone.
    original_output = args.output_dir
    args.output_dir = str(Path(original_output) / backbone_name)
    try:
        # Load data (only the training split is needed here)
        train_df, _test_df, _sample_sub, train_img_dir, _test_img_dir = (
            load_competition_data(args.data_dir)
        )
        targets = train_df[TARGET_COLS].copy()

        # Train
        set_seed(args.seed)
        trainer = Trainer(args)
        overall_r2, fold_scores = trainer.train_kfold(train_df, targets, train_img_dir)

        # Save backbone info
        info = {
            'backbone': backbone_name,
            'backbone_model': cfg['name'],
            'img_size': args.img_size or cfg['default_size'],
            'overall_r2': float(overall_r2),
            'fold_scores': [float(s) for s in fold_scores],
            'mean_r2': float(np.mean(fold_scores)),
            'std_r2': float(np.std(fold_scores)),
        }
        info_dir = Path(args.output_dir)
        # Do not rely on the trainer having created the directory.
        info_dir.mkdir(parents=True, exist_ok=True)
        with open(info_dir / 'backbone_info.json', 'w') as f:
            json.dump(info, f, indent=2)
    finally:
        args.output_dir = original_output

    return info
def ensemble_oof_predictions(output_dir: str, backbones: list, weights: dict = None):
    """Blend the out-of-fold (OOF) predictions of several backbones.

    For every backbone whose ``<output_dir>/<backbone>/oof_predictions.csv``
    exists, the ``TARGET_COLS`` columns are loaded and combined as a weighted
    average. No score is computed here because the ground-truth targets are
    not available to this function.

    Args:
        output_dir: Root directory holding one sub-directory per backbone.
        backbones: Backbone names to look for; missing files are skipped.
        weights: Optional mapping ``backbone -> weight``. Defaults to equal
            weights. Weights are re-normalized to sum to 1 over the backbones
            that were actually loaded.

    Returns:
        ``None`` if no OOF file was found, otherwise a tuple
        ``(ensemble, weights)`` where ``ensemble`` is the blended float64
        array and ``weights`` the normalized weights of the loaded backbones.

    Raises:
        KeyError: If ``weights`` lacks an entry for a loaded backbone.
        ValueError: If the loaded OOF arrays differ in shape or the weights
            of the loaded backbones sum to zero.
    """
    output_dir = Path(output_dir)

    all_oof = {}
    for backbone in backbones:
        oof_path = output_dir / backbone / 'oof_predictions.csv'
        if not oof_path.exists():
            continue
        oof_df = pd.read_csv(oof_path)
        # Force float so the in-place weighted accumulation below is valid
        # even if a prediction column happened to be parsed as integers.
        all_oof[backbone] = np.asarray(oof_df[TARGET_COLS].values, dtype=np.float64)
        print(f" {backbone}: loaded {len(oof_df)} OOF predictions")

    if not all_oof:
        print("No OOF predictions found!")
        return None

    shapes = {b: oof.shape for b, oof in all_oof.items()}
    if len(set(shapes.values())) > 1:
        raise ValueError(f"OOF predictions have mismatched shapes: {shapes}")

    # Default: equal weights
    if weights is None:
        weights = {b: 1.0 / len(all_oof) for b in all_oof}

    missing = [b for b in all_oof if b not in weights]
    if missing:
        raise KeyError(f"No ensemble weight given for backbone(s): {missing}")

    # Normalize over the loaded backbones only, so the returned weights
    # describe exactly the blend that was computed.
    total_w = sum(weights[b] for b in all_oof)
    if total_w == 0:
        raise ValueError("Ensemble weights of the loaded backbones sum to zero")
    weights = {b: weights[b] / total_w for b in all_oof}

    # Weighted average
    ensemble = np.zeros(next(iter(shapes.values())), dtype=np.float64)
    for backbone, oof in all_oof.items():
        ensemble += weights[backbone] * oof

    print("\nEnsemble analysis:")
    print(f" Backbones: {list(all_oof.keys())}")
    print(f" Weights: {weights}")
    return ensemble, weights
def optimize_ensemble_weights(output_dir: str, backbones: list, train_df: pd.DataFrame):
    """Find ensemble weights that maximize the OOF weighted R².

    Two or three models are tuned with a coarse grid search (step 0.05);
    four or more models are tuned with SciPy's Nelder-Mead optimizer starting
    from equal weights.

    Args:
        output_dir: Root directory holding one sub-directory per backbone,
            each with an ``oof_predictions.csv``.
        backbones: Backbone names to consider; those without an OOF file are
            skipped.
        train_df: Training frame providing the ground-truth ``TARGET_COLS``.
            NOTE(review): rows are assumed to be in the same order as the
            rows of every OOF file - confirm against the trainer.

    Returns:
        ``None`` if no OOF file was found, otherwise a dict
        ``backbone -> weight`` (plain floats). A single available backbone
        gets weight 1.0.

    Raises:
        ValueError: If an OOF array's shape differs from the targets' shape.
    """
    output_dir = Path(output_dir)

    # Load OOF predictions
    all_oof = {}
    for backbone in backbones:
        oof_path = output_dir / backbone / 'oof_predictions.csv'
        if oof_path.exists():
            all_oof[backbone] = pd.read_csv(oof_path)[TARGET_COLS].values

    backbone_list = list(all_oof)
    n_backbones = len(backbone_list)
    if n_backbones == 0:
        return None
    if n_backbones == 1:
        return {backbone_list[0]: 1.0}

    # Ground truth is only needed once there is something to optimize.
    targets = train_df[TARGET_COLS].values
    for backbone, oof in all_oof.items():
        # Without this check a mismatched file could be silently broadcast.
        if oof.shape != targets.shape:
            raise ValueError(
                f"OOF predictions for '{backbone}' have shape {oof.shape}, "
                f"expected {targets.shape}"
            )

    def objective(w):
        """Negative weighted R² (to minimize)."""
        w = np.abs(w) / np.abs(w).sum()  # Normalize
        ensemble = np.zeros_like(targets, dtype=np.float64)
        for i, backbone in enumerate(backbone_list):
            ensemble += w[i] * all_oof[backbone]
        return -compute_weighted_r2(ensemble, targets)

    best_r2 = -float('inf')
    best_weights = None

    if n_backbones <= 3:
        # Grid search for 2-3 models
        candidates = []
        if n_backbones == 2:
            for w1 in np.arange(0.1, 1.0, 0.05):
                candidates.append([w1, 1.0 - w1])
        else:
            for w1 in np.arange(0.1, 0.9, 0.05):
                for w2 in np.arange(0.05, 0.9 - w1, 0.05):
                    w3 = 1.0 - w1 - w2
                    if w3 > 0.05:
                        candidates.append([w1, w2, w3])
        for candidate in candidates:
            r2 = -objective(candidate)
            if r2 > best_r2:
                best_r2 = r2
                best_weights = candidate
    else:
        # Scipy optimize for 4+ models; imported here so the grid-search
        # paths do not require SciPy.
        from scipy.optimize import minimize

        x0 = np.ones(n_backbones) / n_backbones
        result = minimize(objective, x0, method='Nelder-Mead')
        best_weights = np.abs(result.x) / np.abs(result.x).sum()
        best_r2 = -result.fun

    # A NaN score never compares greater than the running best, so NaNs in
    # the predictions or targets would leave no winner; use equal weights
    # instead of failing on the lookup below.
    if best_weights is None or not np.all(np.isfinite(best_weights)):
        print("Warning: no finite ensemble R² found; falling back to equal weights")
        best_weights = [1.0 / n_backbones] * n_backbones

    weights = {b: float(w) for b, w in zip(backbone_list, best_weights)}
    print(f"\nOptimal ensemble weights (R²={best_r2:.4f}):")
    for b, w in weights.items():
        print(f" {b}: {w:.3f}")
    return weights
def _add_default_on_flag(parser, name):
    """Register ``--<name>`` (on by default) plus ``--no_<name>`` to disable it."""
    parser.add_argument(f'--{name}', dest=name, action='store_true', default=True)
    parser.add_argument(f'--no_{name}', dest=name, action='store_false')


def _build_arg_parser():
    """Build the command-line parser for ensemble training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--output_dir', type=str, default='./ensemble_output')
    parser.add_argument('--backbones', nargs='+',
                        default=['dinov2_base', 'convnext_large'],
                        choices=list(BACKBONE_CONFIGS.keys()))
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--n_folds', type=int, default=5)
    parser.add_argument('--seed', type=int, default=42)
    # These flags default to True. With a bare ``store_true`` they could never
    # be turned off, so each one also gets a ``--no_<flag>`` counterpart.
    _add_default_on_flag(parser, 'auto_img_size')
    _add_default_on_flag(parser, 'auto_batch_size')

    # Common training args
    parser.add_argument('--hidden_dim', type=int, default=512)
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--backbone_lr', type=float, default=3e-5)
    parser.add_argument('--head_lr', type=float, default=1e-3)
    parser.add_argument('--min_lr', type=float, default=1e-7)
    parser.add_argument('--weight_decay', type=float, default=1e-2)
    parser.add_argument('--warmup_ratio', type=float, default=0.05)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--grad_accum_steps', type=int, default=2)
    parser.add_argument('--patience', type=int, default=8)
    parser.add_argument('--aug_strength', type=str, default='medium')
    # ``--no_log_transform`` already exists as its own attribute, so
    # ``log_transform`` keeps its original declaration; main() syncs the two.
    parser.add_argument('--log_transform', action='store_true', default=True)
    _add_default_on_flag(parser, 'use_lds')
    parser.add_argument('--mse_weight', type=float, default=0.0)
    parser.add_argument('--consistency_weight', type=float, default=0.1)
    parser.add_argument('--use_ndvi', action='store_true')
    parser.add_argument('--separate_heads', action='store_true')
    _add_default_on_flag(parser, 'grad_checkpointing')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--img_size', type=int, default=None)
    parser.add_argument('--optimizer', type=str, default='adamw')
    parser.add_argument('--scheduler', type=str, default='cosine')
    parser.add_argument('--log_interval', type=int, default=10)
    parser.add_argument('--fold', type=int, default=None)
    parser.add_argument('--no_log_transform', action='store_true')
    parser.add_argument('--lds_bins', type=int, default=100)
    parser.add_argument('--lds_kernel_size', type=int, default=5)
    parser.add_argument('--lds_sigma', type=float, default=2.0)
    _add_default_on_flag(parser, 'mixed_precision')
    return parser


def main():
    """Train every requested backbone, then optimize and save ensemble weights.

    Writes ``ensemble_config.json`` (per-backbone summaries plus the optimal
    weights) into ``--output_dir``.
    """
    args = _build_arg_parser().parse_args()

    # Keep the two log-transform attributes consistent so a consumer reading
    # either one sees the same setting.
    # NOTE(review): confirm which attribute the Trainer actually reads.
    if args.no_log_transform:
        args.log_transform = False

    if args.img_size is not None and args.auto_img_size:
        print("Note: --img_size is overridden per backbone while auto_img_size "
              "is enabled; pass --no_auto_img_size to keep it.")

    # Drop repeated names (order preserved) so no backbone is trained twice
    # into the same output directory.
    args.backbones = list(dict.fromkeys(args.backbones))

    # Train each backbone
    backbone_results = [train_backbone(args, backbone) for backbone in args.backbones]

    # Print summary
    print(f"\n{'='*70}")
    print("ENSEMBLE TRAINING SUMMARY")
    print(f"{'='*70}")
    for info in backbone_results:
        print(f" {info['backbone']}: R²={info['overall_r2']:.4f} "
              f"(mean={info['mean_r2']:.4f} ± {info['std_r2']:.4f})")

    # Optimize ensemble weights
    train_df, _, _, _, _ = load_competition_data(args.data_dir)
    weights = optimize_ensemble_weights(
        args.output_dir, args.backbones, train_df
    )

    # Save ensemble config
    ensemble_config = {
        'backbones': backbone_results,
        'optimal_weights': weights,
    }
    config_dir = Path(args.output_dir)
    config_dir.mkdir(parents=True, exist_ok=True)
    with open(config_dir / 'ensemble_config.json', 'w') as f:
        json.dump(ensemble_config, f, indent=2)
    print(f"\nEnsemble config saved to {args.output_dir}/ensemble_config.json")
    print("Done!")


if __name__ == '__main__':
    main()