| |
| """ |
| CSIRO Image2Biomass Prediction - Inference Pipeline |
| ===================================================== |
| Generates submission.csv from trained model ensemble. |
| |
| Supports: |
| - Single model or ensemble of fold models |
| - Test-Time Augmentation (TTA) |
| - Multi-backbone ensembling |
| - Post-processing (non-negative, consistency check) |
| |
| For Kaggle submission notebook: |
| - All models must be saved in a Kaggle dataset |
| - No internet access during inference |
| """ |
|
|
| import os |
| import sys |
| import json |
| import glob |
| import time |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from torch.cuda.amp import autocast |
| import timm |
|
|
| try: |
| import albumentations as A |
| from albumentations.pytorch import ToTensorV2 |
| HAS_ALBUMENTATIONS = True |
| except ImportError: |
| HAS_ALBUMENTATIONS = False |
| from torchvision import transforms |
|
|
| from PIL import Image |
|
|
| |
| |
| |
# Regression targets, in the column order the model outputs them.
TARGET_COLS = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']
# Per-target weights aligned with TARGET_COLS. Not referenced by this inference
# script; presumably the weights of the weighted-R² metric (TODO confirm).
TARGET_WEIGHTS = [0.1, 0.1, 0.1, 0.2, 0.5]
# ImageNet channel statistics used by Normalize in the eval/TTA transforms.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


# Short backbone key (as stored in checkpoint args['backbone']) -> timm config.
#   'name'         : timm model name passed to timm.create_model.
#   'feat_dim'     : informational only; BiomassModel reads backbone.num_features.
#   'default_size' : input resolution used when a checkpoint does not record
#                    img_size; entries without it fall back to 224.
# eva02_large needs 448: BiomassModel only forwards img_size to timm for names
# containing 'vit'/'dinov2', so this backbone is always built at its native
# 448x448 resolution and must be fed 448x448 crops.
BACKBONE_CONFIGS = {
    'dinov2_small': {'name': 'vit_small_patch14_dinov2.lvd142m', 'feat_dim': 384},
    'dinov2_base': {'name': 'vit_base_patch14_dinov2.lvd142m', 'feat_dim': 768},
    'dinov2_large': {'name': 'vit_large_patch14_dinov2.lvd142m', 'feat_dim': 1024},
    'dinov2_base_reg': {'name': 'vit_base_patch14_reg4_dinov2.lvd142m', 'feat_dim': 768},
    'convnext_large': {'name': 'convnext_large.fb_in22k_ft_in1k', 'feat_dim': 1536},
    'convnextv2_large': {'name': 'convnextv2_large.fcmae_ft_in22k_in1k', 'feat_dim': 1536},
    'efficientnet_b4': {'name': 'efficientnet_b4.ra2_in1k', 'feat_dim': 1792},
    'swin_large': {'name': 'swin_large_patch4_window7_224.ms_in22k_ft_in1k', 'feat_dim': 1536,
                   'default_size': 224},
    'eva02_large': {'name': 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', 'feat_dim': 1024,
                    'default_size': 448},
}
|
|
|
|
| |
| |
| |
class BiomassModel(nn.Module):
    """timm image backbone followed by an MLP regression head.

    The backbone is created with ``num_classes=0`` so it returns pooled
    features of width ``backbone.num_features``. When ``use_ndvi`` is set, a
    scalar NDVI value per sample is embedded to 64 dims and concatenated to
    those features before the head.

    Args:
        backbone_name: timm model name passed to ``timm.create_model``.
        num_targets: Number of regression outputs.
        hidden_dim: Width of the first hidden layer of the head(s).
        dropout: Base dropout rate; deeper head layers use fractions of it.
        pretrained: Whether timm should download/load pretrained weights.
        img_size: Input resolution; only forwarded to timm for backbone names
            containing 'vit' or 'dinov2'.
        use_ndvi: Build the NDVI embedding branch; ``forward`` then requires
            an ``ndvi`` tensor.
        separate_heads: Use one independent single-output head per target
            instead of one shared multi-output head.
        grad_checkpointing: Enable gradient checkpointing on the backbone when
            it exposes timm's ``set_grad_checkpointing``.
    """

    def __init__(
        self,
        backbone_name: str = 'vit_base_patch14_dinov2.lvd142m',
        num_targets: int = 5,
        hidden_dim: int = 512,
        dropout: float = 0.3,
        pretrained: bool = False,
        img_size: int = 224,
        use_ndvi: bool = False,
        separate_heads: bool = False,
        grad_checkpointing: bool = False,
    ):
        super().__init__()
        self.use_ndvi = use_ndvi
        self.separate_heads = separate_heads
        self.num_targets = num_targets

        kwargs = {'pretrained': pretrained, 'num_classes': 0}
        # ViT-style models have fixed-size position embeddings; tell timm the
        # resolution we will actually feed so they are built/interpolated to it.
        if 'vit' in backbone_name or 'dinov2' in backbone_name:
            kwargs['img_size'] = img_size

        self.backbone = timm.create_model(backbone_name, **kwargs)
        # This flag used to be accepted and silently ignored.
        if grad_checkpointing and hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(True)
        feat_dim = self.backbone.num_features

        # NOTE: module names and Sequential layouts below define the state_dict
        # keys of saved checkpoints — keep them unchanged.
        if use_ndvi:
            self.ndvi_embed = nn.Sequential(
                nn.Linear(1, 32),
                nn.GELU(),
                nn.Linear(32, 64),
            )
            feat_dim += 64

        if separate_heads:
            self.heads = nn.ModuleList([
                nn.Sequential(
                    nn.LayerNorm(feat_dim),
                    nn.Dropout(dropout),
                    nn.Linear(feat_dim, hidden_dim),
                    nn.GELU(),
                    nn.Dropout(dropout * 0.5),
                    nn.Linear(hidden_dim, 1),
                )
                for _ in range(num_targets)
            ])
        else:
            self.head = nn.Sequential(
                nn.LayerNorm(feat_dim),
                nn.Dropout(dropout),
                nn.Linear(feat_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.GELU(),
                nn.Dropout(dropout * 0.3),
                nn.Linear(hidden_dim // 2, num_targets),
            )

    def forward(self, x, ndvi=None):
        """Return a tensor of ``num_targets`` predictions per sample.

        Args:
            x: Image batch accepted by the backbone.
            ndvi: Per-sample NDVI values, shape (batch,) or (batch, 1).
                Required when the model was built with ``use_ndvi=True``;
                ignored otherwise.

        Raises:
            ValueError: if ``use_ndvi`` is set but ``ndvi`` is None. The heads
                were sized for the extra 64 NDVI features, so without them the
                forward pass would fail later with an opaque shape error.
        """
        features = self.backbone(x)
        if self.use_ndvi:
            if ndvi is None:
                raise ValueError(
                    "BiomassModel was built with use_ndvi=True but forward() "
                    "received ndvi=None"
                )
            if ndvi.dim() == 1:
                ndvi = ndvi.unsqueeze(-1)  # (B,) -> (B, 1) for Linear(1, 32)
            features = torch.cat([features, self.ndvi_embed(ndvi)], dim=-1)
        if self.separate_heads:
            return torch.cat([head(features) for head in self.heads], dim=-1)
        return self.head(features)
|
|
|
|
| |
| |
| |
class TestDataset(Dataset):
    """Dataset yielding transformed test images (and optional NDVI) per row.

    Each row of ``df`` identifies one image by its ``image_id`` column (or by
    the row's index label when that column is absent). The file is looked up
    as ``<image_dir>/<id>.jpg``, then ``.png``, then any ``<id>.*``.

    Items are dicts with keys ``'image'`` (tensor), ``'image_id'`` (str) and,
    when ``use_ndvi`` is set and ``df`` has an ``NDVI`` column, ``'ndvi'``
    (float32 scalar tensor).
    """

    # Extensions tried in order before falling back to a glob on "<id>.*".
    _PREFERRED_EXTS = ('.jpg', '.png')

    def __init__(self, image_dir, df, transform=None, img_size=224, use_ndvi=False):
        self.image_dir = Path(image_dir)
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.img_size = img_size  # kept for API compatibility; resizing is the transform's job
        self.use_ndvi = use_ndvi

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, img_id) -> Path:
        """Return the image file for ``img_id`` or raise FileNotFoundError."""
        for ext in self._PREFERRED_EXTS:
            path = self.image_dir / f"{img_id}{ext}"
            if path.exists():
                return path
        # Sorted so the choice is deterministic when several files match.
        candidates = sorted(self.image_dir.glob(f"{img_id}.*"))
        if candidates:
            return candidates[0]
        raise FileNotFoundError(f"No image found for id {img_id!r} in {self.image_dir}")

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_id = row['image_id'] if 'image_id' in row.index else row.name

        with Image.open(self._resolve_path(img_id)) as pil_img:
            img = np.array(pil_img.convert('RGB'))  # uint8, HWC

        if self.transform is None:
            # Same result as torchvision's ToTensor (uint8 HWC -> float CHW in
            # [0, 1]) without depending on torchvision, which is only imported
            # when albumentations is unavailable.
            img_tensor = torch.from_numpy(img).permute(2, 0, 1).contiguous().float().div(255.0)
        elif HAS_ALBUMENTATIONS:
            img_tensor = self.transform(image=img)['image']
        else:
            img_tensor = self.transform(Image.fromarray(img))

        result = {'image': img_tensor, 'image_id': str(img_id)}

        if self.use_ndvi and 'NDVI' in self.df.columns:
            result['ndvi'] = torch.tensor(row['NDVI'], dtype=torch.float32)

        return result
|
|
|
|
| |
| |
| |
def get_val_transforms(img_size=224, resize_ratio=1.14):
    """Return the deterministic evaluation transform.

    Pipeline: square resize to ``int(img_size * resize_ratio)``, center-crop
    to ``img_size``, ImageNet normalization, conversion to a CHW tensor.
    Uses albumentations when available (expects a numpy HWC image), otherwise
    torchvision (expects a PIL image).

    Args:
        img_size: Side length of the output crop.
        resize_ratio: Resize side relative to ``img_size`` before cropping.
    """
    resize = int(img_size * resize_ratio)
    if HAS_ALBUMENTATIONS:
        return A.Compose([
            A.Resize(height=resize, width=resize),
            A.CenterCrop(height=img_size, width=img_size),
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ])
    return transforms.Compose([
        # An (h, w) tuple forces the same square resize as A.Resize above. An
        # int would rescale only the shorter side and keep the aspect ratio,
        # so the two backends would crop different regions of non-square
        # images.
        transforms.Resize((resize, resize)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
|
|
|
|
def get_tta_transforms(img_size=224):
    """Return the list of test-time-augmentation transforms.

    Order: [plain, horizontal flip, vertical flip, horizontal + vertical flip].
    Index 0 is always ``get_val_transforms(img_size)``; callers slice the
    list to use the first n views. Both backends return all four views.
    """
    plain = get_val_transforms(img_size)
    ttas = [plain]

    if HAS_ALBUMENTATIONS:
        flip_sets = [
            [A.HorizontalFlip(p=1.0)],
            [A.VerticalFlip(p=1.0)],
            [A.HorizontalFlip(p=1.0), A.VerticalFlip(p=1.0)],
        ]
        resize = int(img_size * 1.14)
        for flips in flip_sets:
            ttas.append(A.Compose([
                A.Resize(height=resize, width=resize),
                A.CenterCrop(height=img_size, width=img_size),
                *flips,
                A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
                ToTensorV2(),
            ]))
    else:
        # Torchvision fallback: flip the PIL image, then apply the plain
        # pipeline. Flipping before resize + center-crop matches flipping
        # after them up to a one-pixel crop offset.
        flip_sets = [
            [transforms.RandomHorizontalFlip(p=1.0)],
            [transforms.RandomVerticalFlip(p=1.0)],
            [transforms.RandomHorizontalFlip(p=1.0), transforms.RandomVerticalFlip(p=1.0)],
        ]
        for flips in flip_sets:
            ttas.append(transforms.Compose([*flips, plain]))

    return ttas
|
|
|
|
| |
| |
| |
def load_model(checkpoint_path: str, device: torch.device) -> Tuple[BiomassModel, Dict]:
    """Rebuild a BiomassModel from a training checkpoint and load its weights.

    The checkpoint is expected to be a dict with ``'model_state_dict'`` and,
    optionally, ``'args'`` (training hyper-parameters as a dict or an
    argparse.Namespace) and ``'weighted_r2'``.

    Args:
        checkpoint_path: Path to the ``.pth`` file.
        device: Device to map the weights to and move the model onto.

    Returns:
        ``(model, args)`` — the model in eval mode, and the training args as a
        plain dict (empty when the checkpoint stored none).

    Raises:
        KeyError: if ``args['backbone']`` is not a key of BACKBONE_CONFIGS.
    """
    # weights_only=False performs full pickle deserialization, which can run
    # arbitrary code: only load checkpoints you produced or otherwise trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    args = checkpoint.get('args') or {}
    if not isinstance(args, dict):
        # e.g. an argparse.Namespace saved directly instead of vars(args)
        args = vars(args)

    backbone_key = args.get('backbone', 'dinov2_base')
    if backbone_key not in BACKBONE_CONFIGS:
        raise KeyError(
            f"Unknown backbone {backbone_key!r} in {checkpoint_path}; "
            f"expected one of {sorted(BACKBONE_CONFIGS)}"
        )
    backbone_cfg = BACKBONE_CONFIGS[backbone_key]
    img_size = args.get('img_size', None) or backbone_cfg.get('default_size', 224)

    model = BiomassModel(
        backbone_name=backbone_cfg['name'],
        num_targets=len(TARGET_COLS),
        hidden_dim=args.get('hidden_dim', 512),
        dropout=args.get('dropout', 0.3),
        pretrained=False,  # weights come from the checkpoint; no download offline
        img_size=img_size,
        use_ndvi=args.get('use_ndvi', False),
        separate_heads=args.get('separate_heads', False),
    )

    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    print(f"Loaded model from {checkpoint_path}")
    print(f"  Backbone: {backbone_key}, img_size: {img_size}")
    print(f"  Best R²: {checkpoint.get('weighted_r2', 'N/A')}")

    return model, args
|
|
|
|
@torch.no_grad()
def predict_single(model, loader, device, log_transform=True):
    """Run ``model`` over every batch of ``loader`` and collect predictions.

    Args:
        model: Module called as ``model(images, ndvi)``.
        loader: Iterable of dict batches with ``'image'``, ``'image_id'`` and
            optionally ``'ndvi'``.
        device: Device to run on.
        log_transform: If True the model predicts ``log1p(target)``, so
            outputs are mapped back with ``expm1``.

    Returns:
        ``(preds, ids)`` — a float32 array with one row per sample (in loader
        order), and the list of image ids.
    """
    from contextlib import nullcontext

    device = torch.device(device)
    # fp16 autocast only on CUDA; torch.autocast replaces the deprecated
    # torch.cuda.amp.autocast, which merely warns and disables itself on CPU.
    use_amp = device.type == 'cuda'

    model.eval()
    batch_preds = []
    all_ids = []

    for batch in loader:
        images = batch['image'].to(device, non_blocking=True)
        ndvi = batch.get('ndvi')
        if ndvi is not None:
            ndvi = ndvi.to(device, non_blocking=True)

        amp_ctx = torch.autocast(device_type='cuda', dtype=torch.float16) if use_amp else nullcontext()
        with amp_ctx:
            preds = model(images, ndvi)

        # Back to fp32 before leaving the device so expm1 and the later TTA /
        # ensemble averaging are not computed in half precision.
        batch_preds.append(preds.float().cpu().numpy())
        all_ids.extend(batch['image_id'])

    all_preds = np.concatenate(batch_preds, axis=0)

    if log_transform:
        all_preds = np.expm1(all_preds)

    return all_preds, all_ids
|
|
|
|
def predict_with_tta(model, test_df, image_dir, device, img_size=224,
                     log_transform=True, batch_size=32, num_workers=4,
                     use_ndvi=False, n_tta=4):
    """Predict ``test_df`` with test-time augmentation and average the views.

    Runs ``predict_single`` once per TTA view (the first ``n_tta`` entries of
    ``get_tta_transforms``) and returns the element-wise mean. When
    ``log_transform`` is set, ``predict_single`` already applied ``expm1``,
    so the averaging happens in target space.

    Args:
        n_tta: Number of views; values below 1 are treated as 1 (plain view
            only), values above the number available use all of them.

    Returns:
        ``(avg_preds, image_ids)`` with ids in ``test_df`` order.
    """
    device = torch.device(device)
    # Slicing with n_tta <= 0 would leave no views and np.mean would return NaN.
    tta_transforms = get_tta_transforms(img_size)[:max(1, int(n_tta))]

    all_tta_preds = []
    image_ids = None

    for tta_idx, tfm in enumerate(tta_transforms):
        dataset = TestDataset(image_dir, test_df, transform=tfm,
                              img_size=img_size, use_ndvi=use_ndvi)
        # shuffle=False keeps row order identical across views, so the
        # per-view prediction arrays can be averaged element-wise.
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers,
                            pin_memory=device.type == 'cuda')

        preds, ids = predict_single(model, loader, device, log_transform)
        all_tta_preds.append(preds)

        if image_ids is None:
            image_ids = ids

        print(f"  TTA {tta_idx}: mean pred = {preds.mean():.2f}")

    avg_preds = np.mean(all_tta_preds, axis=0)
    return avg_preds, image_ids
|
|
|
|
def _find_checkpoint(model_dir: Path) -> Optional[Path]:
    """Return ``best_model.pth`` in ``model_dir``, else the first ``*.pth`` by name, else None."""
    preferred = model_dir / 'best_model.pth'
    if preferred.exists():
        return preferred
    candidates = sorted(model_dir.glob('*.pth'))  # sorted: glob order is arbitrary
    return candidates[0] if candidates else None


def ensemble_predict(
    model_dirs: List[str],
    test_df: pd.DataFrame,
    image_dir: str,
    device: torch.device,
    batch_size: int = 32,
    num_workers: int = 4,
    n_tta: int = 4,
    model_weights: Optional[List[float]] = None,
):
    """
    Ensemble predictions from multiple model checkpoints.

    Args:
        model_dirs: List of model checkpoint directories (each containing best_model.pth;
            otherwise the first *.pth by name is used, and directories without
            any checkpoint are skipped)
        test_df: Test dataframe
        image_dir: Test image directory
        device: torch device
        batch_size: Batch size for inference
        num_workers: DataLoader workers
        n_tta: Number of TTA augmentations
        model_weights: Optional weights aligned with model_dirs (default: equal weights).
            Weights of skipped directories are dropped and the remaining ones renormalized.

    Returns:
        (ensemble_preds, image_ids): weighted mean of per-model predictions (float64)
        and the image ids in test_df order.

    Raises:
        ValueError: if model_weights and model_dirs differ in length.
        RuntimeError: if no directory yields a checkpoint, or models disagree on image ids.
    """
    if model_weights is not None and len(model_weights) != len(model_dirs):
        raise ValueError(
            f"model_weights has {len(model_weights)} entries but there are "
            f"{len(model_dirs)} model_dirs"
        )

    all_preds = []
    used_weights = []  # kept in lockstep with all_preds so skips don't shift weights
    image_ids = None

    for i, model_dir in enumerate(model_dirs):
        model_dir = Path(model_dir)

        ckpt_path = _find_checkpoint(model_dir)
        if ckpt_path is None:
            print(f"No checkpoint found in {model_dir}, skipping")
            continue

        print(f"\nModel {i+1}/{len(model_dirs)}: {ckpt_path}")

        model, args = load_model(str(ckpt_path), device)

        backbone_key = args.get('backbone', 'dinov2_base')
        backbone_cfg = BACKBONE_CONFIGS[backbone_key]
        img_size = args.get('img_size', None) or backbone_cfg.get('default_size', 224)

        preds, ids = predict_with_tta(
            model, test_df, image_dir, device,
            img_size=img_size,
            log_transform=args.get('log_transform', True),
            batch_size=batch_size,
            num_workers=num_workers,
            use_ndvi=args.get('use_ndvi', False),
            n_tta=n_tta,
        )

        if image_ids is None:
            image_ids = ids
        elif list(ids) != list(image_ids):
            raise RuntimeError(f"Image order from {ckpt_path} differs from the first model")

        all_preds.append(preds)
        used_weights.append(1.0 if model_weights is None else float(model_weights[i]))

        # Free this model's memory before loading the next one.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not all_preds:
        raise RuntimeError(f"No usable checkpoints found in: {[str(d) for d in model_dirs]}")

    # np.average renormalizes by sum(used_weights); accumulate in float64 so
    # half-precision per-model outputs do not lose precision.
    ensemble_preds = np.average(
        np.stack(all_preds).astype(np.float64), axis=0, weights=used_weights
    )
    return ensemble_preds, image_ids
|
|
|
|
| |
| |
| |
def postprocess_predictions(preds: np.ndarray, *, enforce_total: bool = True) -> np.ndarray:
    """Make raw ensemble predictions safe to submit.

    Steps, on a float64 copy (the input array is never modified):
      1. NaN -> 0, so a failed prediction cannot leak NaN into the submission.
      2. Clip every value to be non-negative.
      3. If ``enforce_total``: wherever column 4 (Dry_Total_g) is smaller than
         the sum of columns 0-2 (Dry_Green_g + Dry_Dead_g + Dry_Clover_g),
         raise it to that sum. This is a hard constraint applied per row.

    Args:
        preds: Array of shape (n_samples, >= 5) in TARGET_COLS order.
        enforce_total: Apply step 3 (default True, as before).

    Returns:
        The post-processed float64 array, same shape as ``preds``.
    """
    out = np.array(preds, dtype=np.float64)  # always a copy
    out[np.isnan(out)] = 0.0
    np.clip(out, 0.0, None, out=out)

    if enforce_total:
        component_sum = out[:, 0] + out[:, 1] + out[:, 2]
        out[:, 4] = np.maximum(out[:, 4], component_sum)

    return out
|
|
|
|
def create_submission(preds: np.ndarray, image_ids: List[str], output_path: str = 'submission.csv'):
    """Create submission CSV in competition format.

    Writes one row per (image, target) with columns ``sample_id``
    (``"<image_id>__<target_name>"``, images in ``image_ids`` order, targets in
    TARGET_COLS order) and ``target`` (the prediction, with negative and NaN
    values written as 0).

    Args:
        preds: Array of shape (len(image_ids), >= len(TARGET_COLS)); extra
            columns are ignored.
        image_ids: Image ids aligned with the rows of ``preds``.
        output_path: Destination CSV path.

    Returns:
        The submission DataFrame that was written.

    Raises:
        ValueError: if ``preds`` does not match ``image_ids`` / TARGET_COLS in
            shape. Checked before anything is written.
    """
    preds = np.asarray(preds, dtype=np.float64)
    n_targets = len(TARGET_COLS)
    if preds.ndim != 2 or preds.shape[0] != len(image_ids) or preds.shape[1] < n_targets:
        raise ValueError(
            f"preds must have shape ({len(image_ids)}, >={n_targets}), got {preds.shape}"
        )

    target_block = preds[:, :n_targets]
    # Same semantics as max(0, x): keep x only when x > 0, so NaN -> 0 too.
    clipped = np.where(target_block > 0, target_block, 0.0)

    sub_df = pd.DataFrame({
        'sample_id': [f"{img_id}__{name}" for img_id in image_ids for name in TARGET_COLS],
        # Row-major ravel matches the image-major, target-minor sample_id order.
        'target': clipped.ravel(),
    })
    sub_df.to_csv(output_path, index=False)
    print(f"Submission saved: {output_path} ({len(sub_df)} rows)")

    # Summary statistics of the raw (unclipped) predictions; skipped when
    # empty because min/max of an empty array raise.
    if len(image_ids):
        for j, name in enumerate(TARGET_COLS):
            col_preds = target_block[:, j]
            print(f"  {name}: mean={col_preds.mean():.2f}, std={col_preds.std():.2f}, "
                  f"min={col_preds.min():.2f}, max={col_preds.max():.2f}")

    return sub_df
|
|
|
|
| |
| |
| |
def main():
    """Command-line entry point: run the checkpoint ensemble and write a submission CSV.

    Reads ``<data_dir>/test.csv`` and images from ``<data_dir>/test_images``
    (falling back to ``<data_dir>/test``). Checkpoints come from the
    ``<model_dir>/fold_*`` subdirectories, or from ``<model_dir>`` itself
    when there are none.
    """
    import argparse
    parser = argparse.ArgumentParser(description='CSIRO Image2Biomass inference')
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--model_dir', type=str, required=True,
                        help='Directory with fold_0/best_model.pth, fold_1/best_model.pth, etc.')
    parser.add_argument('--output', type=str, default='submission.csv')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--n_tta', type=int, default=4)
    parser.add_argument('--no_tta', action='store_true')
    args = parser.parse_args()

    n_tta = 1 if args.no_tta else args.n_tta
    if n_tta < 1:
        parser.error('--n_tta must be >= 1')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    data_dir = Path(args.data_dir)
    # NOTE(review): TestDataset reads an 'image_id' column and create_submission
    # emits one row per (image, target), so test.csv is assumed to contain one
    # row per image — confirm against the competition's test.csv layout.
    test_df = pd.read_csv(data_dir / 'test.csv')

    test_img_dir = data_dir / 'test_images'
    if not test_img_dir.exists():
        test_img_dir = data_dir / 'test'

    print(f"Test samples: {len(test_df)}")

    model_dir = Path(args.model_dir)
    # Only directories: a stray file such as "fold_0.log" is not a model dir.
    fold_dirs = sorted(p for p in model_dir.glob('fold_*') if p.is_dir())
    if not fold_dirs:
        # No fold layout: treat model_dir itself as a single model directory.
        fold_dirs = [model_dir]

    print(f"Found {len(fold_dirs)} fold model(s)")

    preds, image_ids = ensemble_predict(
        model_dirs=[str(d) for d in fold_dirs],
        test_df=test_df,
        image_dir=str(test_img_dir),
        device=device,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        n_tta=n_tta,
    )

    preds = postprocess_predictions(preds)

    create_submission(preds, image_ids, args.output)
    print("\nDone!")


if __name__ == '__main__':
    main()
|
|