csiro-image2biomass / kaggle_inference_notebook.py
notRaphael's picture
Upload kaggle_inference_notebook.py with huggingface_hub
a1c4356 verified
#!/usr/bin/env python3
"""
CSIRO Image2Biomass Prediction - Kaggle Inference Notebook
============================================================
This notebook loads trained models and generates submission.csv.
Requirements:
    - Trained model weights saved as a Kaggle dataset
    - No internet access (all models pre-downloaded)

Expected model dataset structure:
    /kaggle/input/biomass-models/
        fold_0/best_model.pth
        fold_1/best_model.pth
        ...
        training_info.json
"""
import os
import sys
import json
import time
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast
from PIL import Image
warnings.filterwarnings('ignore')

# Install timm / albumentations only when they are not already importable.
# The notebook is meant to run without internet access, so an unconditional
# `pip install` is at best a no-op; when something really is missing, install
# it into the interpreter that is running this script (sys.executable) rather
# than whatever `pip` a shell happens to resolve.
import importlib.util
import subprocess

_missing_pkgs = [
    pkg for pkg in ('timm', 'albumentations')
    if importlib.util.find_spec(pkg) is None
]
if _missing_pkgs:
    # check=False mirrors the original best-effort behaviour: a failed install
    # surfaces as an ImportError on the import lines that follow.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-q', *_missing_pkgs],
        check=False,
    )
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
# ============================================================
# Configuration
# ============================================================
class CFG:
    """Static settings for the inference run: paths, batching, TTA and device."""

    # --- Locations -------------------------------------------------------
    # Competition slug; it is also the folder name under /kaggle/input.
    COMPETITION = 'csiro-biomass'
    DATA_DIR = Path(f'/kaggle/input/{COMPETITION}')
    MODEL_DIR = Path('/kaggle/input/biomass-models')  # Your uploaded model weights
    OUTPUT_DIR = Path('/kaggle/working')

    # --- Data loading ----------------------------------------------------
    BATCH_SIZE = 32
    NUM_WORKERS = 2

    # --- Test-time augmentation ------------------------------------------
    N_TTA = 4  # Number of TTA augmentations

    # --- Hardware --------------------------------------------------------
    # Use the GPU whenever torch can see one, otherwise fall back to the CPU.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Target names, in the column order used when predictions are written out.
TARGET_COLS = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']

# Per-channel normalisation statistics handed to A.Normalize.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# (short alias, timm model name, feature width) for every known backbone.
_BACKBONE_TABLE = (
    ('dinov2_small', 'vit_small_patch14_dinov2.lvd142m', 384),
    ('dinov2_base', 'vit_base_patch14_dinov2.lvd142m', 768),
    ('dinov2_large', 'vit_large_patch14_dinov2.lvd142m', 1024),
    ('dinov2_base_reg', 'vit_base_patch14_reg4_dinov2.lvd142m', 768),
    ('convnext_large', 'convnext_large.fb_in22k_ft_in1k', 1536),
    ('convnextv2_large', 'convnextv2_large.fcmae_ft_in22k_in1k', 1536),
    ('efficientnet_b4', 'efficientnet_b4.ra2_in1k', 1792),
    ('swin_large', 'swin_large_patch4_window7_224.ms_in22k_ft_in1k', 1536),
    ('eva02_large', 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', 1024),
)

# Alias -> {'name': timm model name, 'feat_dim': feature width}.
# Checkpoints may store either the alias or the full timm name.
BACKBONE_CONFIGS = {
    alias: {'name': timm_name, 'feat_dim': width}
    for alias, timm_name, width in _BACKBONE_TABLE
}
# ============================================================
# Model Definition (must match training)
# ============================================================
class BiomassModel(nn.Module):
    """timm backbone followed by an MLP regression head.

    The sub-module names (``backbone``, ``ndvi_embed``, ``head`` / ``heads``)
    and the layer order inside each ``nn.Sequential`` are the keys of the
    saved state dict, so they must stay exactly as they were at training time.

    Args:
        backbone_name: timm model name passed to ``timm.create_model``.
        num_targets: Number of regression outputs.
        hidden_dim: Width of the first hidden layer of the head(s).
        dropout: Base dropout rate; later layers use scaled-down fractions.
        pretrained: Forwarded to timm; inference loads weights from a
            checkpoint instead.
        img_size: Input resolution, forwarded to timm only for backbones
            whose name contains ``'vit'`` or ``'dinov2'``.
        use_ndvi: Append a 64-d embedding of a scalar NDVI value to the
            image features.
        separate_heads: Build one small head per target instead of a single
            shared head.
    """

    def __init__(self, backbone_name, num_targets=5, hidden_dim=512,
                 dropout=0.3, pretrained=False, img_size=224,
                 use_ndvi=False, separate_heads=False):
        super().__init__()
        self.use_ndvi = use_ndvi
        self.separate_heads = separate_heads

        # num_classes=0 requests a backbone without its classifier; the heads
        # below are sized from the backbone's reported num_features.
        create_kwargs = dict(pretrained=pretrained, num_classes=0)
        if any(tag in backbone_name for tag in ('vit', 'dinov2')):
            create_kwargs['img_size'] = img_size
        self.backbone = timm.create_model(backbone_name, **create_kwargs)

        width = self.backbone.num_features
        if use_ndvi:
            self.ndvi_embed = nn.Sequential(nn.Linear(1, 32), nn.GELU(), nn.Linear(32, 64))
            width = width + 64

        if not separate_heads:
            # Single shared head predicting all targets at once.
            self.head = nn.Sequential(
                nn.LayerNorm(width), nn.Dropout(dropout),
                nn.Linear(width, hidden_dim), nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(hidden_dim, hidden_dim // 2), nn.GELU(),
                nn.Dropout(dropout * 0.3),
                nn.Linear(hidden_dim // 2, num_targets),
            )
            return

        def _single_target_head():
            # One scalar-output MLP; instantiated once per target.
            return nn.Sequential(
                nn.LayerNorm(width), nn.Dropout(dropout),
                nn.Linear(width, hidden_dim), nn.GELU(),
                nn.Dropout(dropout * 0.5), nn.Linear(hidden_dim, 1),
            )

        self.heads = nn.ModuleList([_single_target_head() for _ in range(num_targets)])

    def forward(self, x, ndvi=None):
        """Return predictions for image batch ``x`` (optionally with ``ndvi``).

        The NDVI embedding is concatenated only when the model was built with
        ``use_ndvi=True`` *and* an ``ndvi`` tensor is supplied.
        """
        feats = self.backbone(x)
        if self.use_ndvi and ndvi is not None:
            ndvi_feats = self.ndvi_embed(ndvi.unsqueeze(-1))
            feats = torch.cat([feats, ndvi_feats], dim=-1)

        if not self.separate_heads:
            return self.head(feats)

        per_target = [branch(feats) for branch in self.heads]
        return torch.cat(per_target, dim=-1)
# ============================================================
# Dataset
# ============================================================
class TestDataset(Dataset):
    """Dataset yielding one transformed image (plus optional NDVI) per row.

    Args:
        image_dir: Directory that contains the test images.
        df: DataFrame with one row per sample. The image id is read from the
            ``image_id`` column when present, otherwise from the row index.
        transform: Albumentations-style callable; invoked as
            ``transform(image=ndarray)['image']``.
        use_ndvi: When True and ``df`` has an ``NDVI`` column, each item also
            carries that value as a float32 tensor under ``'ndvi'``.
    """

    # Extensions tried, in this order, when turning an image id into a file.
    _EXTENSIONS = ('.jpg', '.jpeg', '.png', '.JPG')

    def __init__(self, image_dir, df, transform, use_ndvi=False):
        self.image_dir = Path(image_dir)
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.use_ndvi = use_ndvi

    def __len__(self):
        return len(self.df)

    def _resolve_image_path(self, img_id):
        """Return the file path for ``img_id``.

        Resolution order:
          1. ``<img_id><ext>`` for each known extension.
          2. Any file whose name starts with ``img_id``; files whose stem or
             full name equals the id win over mere prefix matches, so
             ``img_1`` can no longer be served by ``img_10.jpg`` when
             ``img_1.tif`` exists. Matches are sorted so the choice does not
             depend on directory listing order.
          3. ``<img_id>.jpg`` as a last resort (opening it fails if missing).
        """
        for ext in self._EXTENSIONS:
            candidate = self.image_dir / f"{img_id}{ext}"
            if candidate.exists():
                return candidate

        matches = sorted(self.image_dir.glob(f"{img_id}*"))
        wanted = Path(str(img_id)).name
        exact = [m for m in matches if wanted in (m.stem, m.name)]
        if exact:
            return exact[0]
        if matches:
            return matches[0]
        return self.image_dir / f"{img_id}.jpg"

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_id = row['image_id'] if 'image_id' in row.index else row.name
        img_path = self._resolve_image_path(img_id)

        # The context manager releases the file handle as soon as the pixels
        # have been decoded into the RGB copy.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        img_tensor = self.transform(image=img)['image']

        result = {'image': img_tensor, 'image_id': str(img_id)}
        if self.use_ndvi and 'NDVI' in self.df.columns:
            result['ndvi'] = torch.tensor(float(row['NDVI']), dtype=torch.float32)
        return result
# ============================================================
# TTA Transforms
# ============================================================
def get_tta_transforms(img_size=224, n_tta=4):
    """Build the test-time-augmentation pipelines.

    Every pipeline resizes to ``int(img_size * 1.14)``, center-crops to
    ``img_size``, applies its flips, normalises and converts to a tensor.
    The four views are, in order: no flip, horizontal flip, vertical flip,
    and both flips.

    Args:
        img_size: Side length of the square crop fed to the model.
        n_tta: How many of the four views to return (taken from the front).

    Returns:
        A list with the first ``n_tta`` ``A.Compose`` pipelines.
    """
    resized = int(img_size * 1.14)

    def _pipeline(*flips):
        # Shared resize/crop prefix and normalise/tensor suffix around `flips`.
        return A.Compose([
            A.Resize(height=resized, width=resized),
            A.CenterCrop(height=img_size, width=img_size),
            *flips,
            A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
            ToTensorV2(),
        ])

    flip_sets = (
        (),                                   # 0: Standard center crop
        (A.HorizontalFlip,),                  # 1: HFlip
        (A.VerticalFlip,),                    # 2: VFlip
        (A.HorizontalFlip, A.VerticalFlip),   # 3: Both flips
    )
    views = [
        _pipeline(*(flip(p=1.0) for flip in combo))
        for combo in flip_sets
    ]
    return views[:n_tta]
# ============================================================
# Inference Functions
# ============================================================
def load_model(ckpt_path, device):
    """Rebuild a ``BiomassModel`` from a training checkpoint.

    The checkpoint must contain ``'model_state_dict'`` and may contain
    ``'args'`` with the training hyper-parameters (``backbone``, ``img_size``,
    ``hidden_dim``, ``dropout``, ``use_ndvi``, ``separate_heads``); missing
    entries fall back to the defaults below.

    Args:
        ckpt_path: Path to the ``.pth`` checkpoint file.
        device: Device to map the tensors to and place the model on.

    Returns:
        ``(model, args)`` - the model in eval mode on ``device`` and the
        hyper-parameters as a plain dict (empty when the checkpoint has none).

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # NOTE(security): weights_only=False runs the full pickle machinery, which
    # can execute arbitrary code. Only load checkpoints you produced yourself.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # Callers use args.get(...), so always hand back a dict: tolerate a
    # missing/None entry and namespace-style objects (anything with __dict__).
    raw_args = ckpt.get('args') or {}
    args = raw_args if isinstance(raw_args, dict) else dict(vars(raw_args))

    # Resolve backbone name: accept either a short alias or a timm model name.
    backbone_key = args.get('backbone', 'vit_base_patch14_dinov2.lvd142m')
    if backbone_key in BACKBONE_CONFIGS:
        backbone_name = BACKBONE_CONFIGS[backbone_key]['name']
    else:
        backbone_name = backbone_key

    img_size = args.get('img_size', 224)
    model = BiomassModel(
        backbone_name=backbone_name,
        num_targets=5,
        hidden_dim=args.get('hidden_dim', 512),
        dropout=args.get('dropout', 0.3),
        pretrained=False,  # weights come from the checkpoint, not from timm
        img_size=img_size,
        use_ndvi=args.get('use_ndvi', False),
        separate_heads=args.get('separate_heads', False),
    )
    model.load_state_dict(ckpt['model_state_dict'])
    model = model.to(device).eval()
    return model, args
@torch.no_grad()
def predict(model, loader, device, log_transform=True):
    """Run ``model`` over ``loader`` and collect predictions and image ids.

    Args:
        model: Module called as ``model(images, ndvi)``.
        loader: Iterable of batch dicts with ``'image'`` and ``'image_id'``
            entries and an optional ``'ndvi'`` entry.
        device: Device the batch tensors are moved to.
        log_transform: When True, ``expm1`` is applied to the raw outputs.

    Returns:
        ``(preds, ids)`` - a float32 NumPy array of stacked predictions and
        the list of image ids in the same order.
    """
    model.eval()
    preds_list, ids_list = [], []
    for batch in loader:
        images = batch['image'].to(device)
        ndvi = batch.get('ndvi', None)
        if ndvi is not None:
            ndvi = ndvi.to(device)
        with autocast(dtype=torch.float16):
            preds = model(images, ndvi)
        # Autocast can hand back float16 outputs. Promote to float32 before
        # leaving torch so that expm1 and the later TTA/fold averaging are not
        # done in half precision (float16 expm1 overflows to inf above ~11.09).
        preds_list.append(preds.float().cpu().numpy())
        ids_list.extend(batch['image_id'])
    preds = np.concatenate(preds_list)
    if log_transform:
        preds = np.expm1(preds)
    return preds, ids_list
def predict_tta(model, test_df, img_dir, device, img_size, log_transform,
                use_ndvi, batch_size, num_workers, n_tta):
    """Predict with test-time augmentation and average the views.

    One ``TestDataset`` / ``DataLoader`` pair is built per TTA transform; the
    per-view predictions (already passed through ``expm1`` when
    ``log_transform`` is set) are averaged element-wise.

    Args:
        model: Model handed to ``predict``.
        test_df: DataFrame describing the test samples.
        img_dir: Directory containing the test images.
        device: Device used for inference.
        img_size: Crop size used to build the TTA transforms.
        log_transform: Forwarded to ``predict``.
        use_ndvi: Forwarded to ``TestDataset``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
        n_tta: Number of TTA views to use.

    Returns:
        ``(mean_preds, ids)`` - the averaged predictions and the image ids
        from the first view (all views iterate in the same, unshuffled order).

    Raises:
        ValueError: If ``n_tta`` selects no transform at all.
    """
    tta_tfms = get_tta_transforms(img_size, n_tta)
    if not tta_tfms:
        # Without this guard np.mean([]) would quietly yield NaN and ids=None.
        raise ValueError(f"n_tta={n_tta} selects no TTA transforms; use a value >= 1")

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
    pin_memory = torch.device(device).type == 'cuda'

    all_preds = []
    ids = None
    for tfm in tta_tfms:
        ds = TestDataset(img_dir, test_df, tfm, use_ndvi)
        loader = DataLoader(ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)
        p, image_ids = predict(model, loader, device, log_transform)
        all_preds.append(p)
        if ids is None:
            ids = image_ids
    return np.mean(all_preds, axis=0), ids
# ============================================================
# Main Inference
# ============================================================
# Script body: locate the inputs, run every fold model with TTA, average the
# folds, post-process and write submission.csv into the current directory.
device = torch.device(CFG.DEVICE)
print(f"Device: {device}")

# Find data: use the first of these mount points that exists.
for alt in ['/kaggle/input/csiro-biomass', '/kaggle/input/csiro-image2biomass-prediction',
            '/kaggle/input/csiro-image2biomass']:
    if Path(alt).exists():
        CFG.DATA_DIR = Path(alt)
        break

# Find model weights: first candidate directory holding at least one fold_* entry.
for alt in ['/kaggle/input/biomass-models', '/kaggle/input/biomass-weights',
            '/kaggle/working']:
    if Path(alt).exists() and any(Path(alt).glob('fold_*')):
        CFG.MODEL_DIR = Path(alt)
        break
print(f"Data: {CFG.DATA_DIR}")
print(f"Models: {CFG.MODEL_DIR}")

# Load test data
test_csv = None
for fname in ['test.csv', 'Test.csv']:
    if (CFG.DATA_DIR / fname).exists():
        test_csv = CFG.DATA_DIR / fname
        break
if test_csv is None:
    # Fail here with a readable message instead of inside pd.read_csv(None).
    raise FileNotFoundError(f"No test.csv / Test.csv found in {CFG.DATA_DIR}")
test_df = pd.read_csv(test_csv)
print(f"Test samples: {len(test_df)}")

# Find test images
test_img_dir = None
for d in ['test_images', 'test', 'images/test']:
    if (CFG.DATA_DIR / d).exists():
        test_img_dir = CFG.DATA_DIR / d
        break
if test_img_dir is None:
    # str(None) would otherwise become the literal directory name 'None'.
    raise FileNotFoundError(
        f"No test image directory (test_images, test, images/test) found in {CFG.DATA_DIR}"
    )
print(f"Test images: {test_img_dir}")

# Find fold models
fold_dirs = sorted(CFG.MODEL_DIR.glob('fold_*'))
print(f"Found {len(fold_dirs)} fold models")

# Ensemble prediction: one TTA-averaged prediction array per usable fold.
all_fold_preds = []
image_ids = None
for fold_dir in fold_dirs:
    ckpt_path = fold_dir / 'best_model.pth'
    if not ckpt_path.exists():
        print(f"Skipping {fold_dir}: no best_model.pth")
        continue
    print(f"\nLoading {ckpt_path}...")
    model, args = load_model(str(ckpt_path), device)
    img_size = args.get('img_size', 224)
    log_transform = args.get('log_transform', True)
    use_ndvi = args.get('use_ndvi', False)
    preds, ids = predict_tta(
        model, test_df, str(test_img_dir), device,
        img_size=img_size,
        log_transform=log_transform,
        use_ndvi=use_ndvi,
        batch_size=CFG.BATCH_SIZE,
        num_workers=CFG.NUM_WORKERS,
        n_tta=CFG.N_TTA,
    )
    all_fold_preds.append(preds)
    if image_ids is None:
        image_ids = ids
    print(f" Mean predictions: {preds.mean(axis=0)}")
    # Release the model before loading the next fold to keep GPU memory flat.
    del model
    torch.cuda.empty_cache()

if not all_fold_preds:
    # Averaging an empty list would only produce NaN and a confusing crash later.
    raise FileNotFoundError(
        f"No fold_*/best_model.pth checkpoints found under {CFG.MODEL_DIR}"
    )

# Average across folds
ensemble_preds = np.mean(all_fold_preds, axis=0)
ensemble_preds = np.clip(ensemble_preds, 0, None)

# Post-process: ensure total >= component sum
# (columns 0, 1, 2 are the components; column 4 is the total, per TARGET_COLS).
comp_sum = ensemble_preds[:, 0] + ensemble_preds[:, 1] + ensemble_preds[:, 2]
mask = ensemble_preds[:, 4] < comp_sum
ensemble_preds[mask, 4] = comp_sum[mask]

print(f"\nEnsemble predictions summary:")
for i, name in enumerate(TARGET_COLS):
    col = ensemble_preds[:, i]
    print(f" {name}: mean={col.mean():.2f}, std={col.std():.2f}, "
          f"min={col.min():.2f}, max={col.max():.2f}")

# Create submission: one row per (image, target) pair.
rows = []
for i, img_id in enumerate(image_ids):
    for j, target_name in enumerate(TARGET_COLS):
        rows.append({
            'sample_id': f"{img_id}__{target_name}",
            # max(0, x) also maps a NaN prediction to 0, which np.clip does not.
            'target': float(max(0, ensemble_preds[i, j])),
        })
submission = pd.DataFrame(rows)
submission.to_csv('submission.csv', index=False)
print(f"\nSubmission saved: submission.csv ({len(submission)} rows)")
print(submission.head(10))

# Verify format
assert submission.columns.tolist() == ['sample_id', 'target'], \
    f"unexpected submission columns: {submission.columns.tolist()}"
assert len(submission) == len(test_df) * 5, \
    f"expected {len(test_df) * 5} rows, got {len(submission)}"
print("\n✅ Submission format verified!")