""" src/train_single.py ------------------- Train any single model by name. Designed for running baselines one at a time with breaks between them. Available models ---------------- proposed — ViT-Base + hierarchical KL+MSE (main model) b1_resnet_mse — ResNet-18 + independent MSE (sigmoid) b2_resnet_kl — ResNet-18 + hierarchical KL+MSE b3_vit_mse — ViT-Base + hierarchical MSE only (no KL) b4_vit_dir — ViT-Base + Dirichlet NLL (Zoobot-style) Usage ----- # Train proposed model python -m src.train_single --model proposed --config configs/full_train.yaml # Train one baseline at a time python -m src.train_single --model b1_resnet_mse --config configs/full_train.yaml python -m src.train_single --model b2_resnet_kl --config configs/full_train.yaml python -m src.train_single --model b3_vit_mse --config configs/full_train.yaml python -m src.train_single --model b4_vit_dir --config configs/full_train.yaml # With nohup (recommended) nohup python -m src.train_single --model b3_vit_mse \\ --config configs/full_train.yaml \\ > outputs/logs/train_b3_vit_mse.log 2>&1 & echo "PID: $!" Each model saves its checkpoint independently, so you can run them in any order and resume from any point. Already-trained models are detected by their checkpoint file and skipped unless --force is passed. """ import argparse import logging import sys from pathlib import Path import numpy as np import torch from omegaconf import OmegaConf logging.basicConfig( format="%(asctime)s %(levelname)s %(message)s", datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout, ) log = logging.getLogger("train_single") # ── Checkpoint paths per model ───────────────────────────────────────────────── CHECKPOINT_NAMES = { "proposed" : "best_full_train.pt", "b1_resnet_mse" : "baseline_resnet18_mse.pt", "b2_resnet_kl" : "baseline_resnet18_klmse.pt", "b3_vit_mse" : "baseline_vit_mse.pt", "b4_vit_dir" : "baseline_vit_dirichlet.pt", } # ── Human-readable labels ────────────────────────────────────────────────────── MODEL_LABELS = { "proposed" : "ViT-Base + hierarchical KL+MSE (proposed)", "b1_resnet_mse" : "ResNet-18 + independent MSE (sigmoid, no hierarchy)", "b2_resnet_kl" : "ResNet-18 + hierarchical KL+MSE", "b3_vit_mse" : "ViT-Base + hierarchical MSE only (no KL)", "b4_vit_dir" : "ViT-Base + Dirichlet NLL (Zoobot-style)", } def train_proposed(cfg, device, ckpt_path): """Train the proposed ViT + hierarchical KL+MSE model.""" from src.train import ( train_one_epoch, validate, EarlyStopping, set_seed ) from src.dataset import build_dataloaders from src.model import build_model from src.loss import HierarchicalLoss from src.attention_viz import plot_attention_grid import pandas as pd import wandb from torch.amp import GradScaler import matplotlib.pyplot as plt set_seed(cfg.seed) log.info("Training: %s", MODEL_LABELS["proposed"]) Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True) Path(cfg.outputs.figures_dir).mkdir(parents=True, exist_ok=True) Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True) history_path = str( Path(cfg.outputs.log_dir) / "training_full_train_history.csv" ) if cfg.wandb.enabled: wandb.init( project=cfg.wandb.project, name=cfg.experiment_name, config=OmegaConf.to_container(cfg, resolve=True), ) train_loader, val_loader, _ = build_dataloaders(cfg) model = build_model(cfg).to(device) loss_fn = HierarchicalLoss(cfg) optimizer = torch.optim.AdamW( [ {"params": model.backbone.parameters(), "lr": cfg.training.learning_rate * 0.1}, {"params": model.head.parameters(), "lr": cfg.training.learning_rate}, ], 
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    scaler = GradScaler("cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=ckpt_path,
    )

    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = train_one_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch
        )
        collect_attn = (epoch % cfg.wandb.log_attention_every_n_epochs == 0)
        val_logs, attn_data = validate(
            model, val_loader, loss_fn, device, cfg,
            collect_attn=collect_attn,
            n_attn=cfg.wandb.n_attention_samples,
            epoch=epoch,
        )
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        val_mae = val_logs.get("val/mae/weighted_avg", 0)
        val_loss = val_logs["val/loss_total"]
        log.info("Epoch %d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                 epoch, train_loss, val_loss, val_mae, lr)

        history.append({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_mae": val_mae,
            "lr": lr,
        })

        if cfg.wandb.enabled:
            log_dict = {"train/loss": train_loss, **val_logs,
                        "lr": lr, "epoch": epoch}
            if attn_data is not None:
                imgs, layers, ids = attn_data
                fig = plot_attention_grid(
                    imgs, layers, ids,
                    save_path=(f"{cfg.outputs.figures_dir}/{cfg.experiment_name}/"
                               f"attn_epoch{epoch:03d}.png"),
                    n_cols=4,
                    rollout_mode="full",
                )
                log_dict["attention/rollout_full"] = wandb.Image(fig)
                plt.close(fig)
            wandb.log(log_dict, step=epoch)

        if early_stop.step(val_loss, model, epoch):
            log.info("Early stopping at epoch %d", epoch)
            break

    pd.DataFrame(history).to_csv(history_path, index=False)
    early_stop.restore_best(model)
    if cfg.wandb.enabled:
        wandb.finish()
    log.info("Done. Checkpoint: %s", ckpt_path)


def train_baseline(cfg, device, ckpt_path, model_key):
    """Train any of the four baselines."""
    import wandb
    from torch.amp import GradScaler
    from src.dataset import build_dataloaders
    from src.model import build_model, build_dirichlet_model
    from src.loss import HierarchicalLoss, DirichletLoss, MSEOnlyLoss
    from src.metrics import (compute_metrics, predictions_to_numpy,
                             dirichlet_predictions_to_numpy)
    from src.baselines import (
        ResNet18Baseline, IndependentMSELoss, EarlyStopping, set_seed,
        _train_epoch, _val_epoch, _train_epoch_dirichlet, _val_epoch_dirichlet,
    )
    import pandas as pd
    from omegaconf import OmegaConf as OC

    set_seed(cfg.seed)
    log.info("Training: %s", MODEL_LABELS[model_key])
    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # ── Build model and loss ─────────────────────────────────
    if model_key == "b1_resnet_mse":
        model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        loss_fn = IndependentMSELoss()
        use_sigmoid = True
        is_dirichlet = False
        use_layerwise_lr = False
        wandb_name = "B1-ResNet18-MSE"

    elif model_key == "b2_resnet_kl":
        model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        loss_fn = HierarchicalLoss(cfg)
        use_sigmoid = False
        is_dirichlet = False
        use_layerwise_lr = False
        wandb_name = "B2-ResNet18-KL+MSE"

    elif model_key == "b3_vit_mse":
        vit_mse_cfg = OC.merge(
            cfg, OC.create({"loss": {"lambda_kl": 0.0, "lambda_mse": 1.0}})
        )
        model = build_model(vit_mse_cfg).to(device)
        loss_fn = MSEOnlyLoss(vit_mse_cfg)
        cfg = vit_mse_cfg  # use updated cfg for optimizer
        use_sigmoid = False
        is_dirichlet = False
        use_layerwise_lr = True
        wandb_name = "B3-ViT-MSE"
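    # B4 follows Zoobot (Walmsley et al.), which models volunteer votes with
    # a Dirichlet-Multinomial; build_dirichlet_model and DirichletLoss
    # presumably implement that head and its negative log-likelihood here.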
wandb_name = "B4-ViT-Dirichlet" else: raise ValueError(f"Unknown model key: {model_key}") total = sum(p.numel() for p in model.parameters()) log.info("Parameters: %s", f"{total:,}") # ── Optimizer ────────────────────────────────────────────── if use_layerwise_lr and hasattr(model, "backbone") and hasattr(model, "head"): optimizer = torch.optim.AdamW( [ {"params": model.backbone.parameters(), "lr": cfg.training.learning_rate * 0.1}, {"params": model.head.parameters(), "lr": cfg.training.learning_rate}, ], weight_decay=cfg.training.weight_decay, ) else: optimizer = torch.optim.AdamW( model.parameters(), lr=cfg.training.learning_rate, weight_decay=cfg.training.weight_decay, ) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min ) scaler = GradScaler("cuda") early_stop = EarlyStopping( patience=cfg.early_stopping.patience, min_delta=cfg.early_stopping.min_delta, checkpoint_path=ckpt_path, ) train_loader, val_loader, test_loader = build_dataloaders(cfg) wandb.init( project=cfg.wandb.project, name=wandb_name, config={"model": wandb_name, "seed": cfg.seed, "epochs": cfg.training.epochs, "lambda_kl": cfg.loss.lambda_kl}, reinit=True, ) # ── Training loop ────────────────────────────────────────── history = [] for epoch in range(1, cfg.training.epochs + 1): if is_dirichlet: train_loss = _train_epoch_dirichlet( model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, wandb_name ) val_loss, val_metrics = _val_epoch_dirichlet( model, val_loader, loss_fn, device, cfg, epoch, wandb_name ) else: train_loss = _train_epoch( model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch, wandb_name ) val_loss, val_metrics = _val_epoch( model, val_loader, loss_fn, device, cfg, epoch, wandb_name, use_sigmoid=use_sigmoid ) scheduler.step() lr = scheduler.get_last_lr()[0] val_mae = val_metrics.get("mae/weighted_avg", 0) log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e", wandb_name, epoch, train_loss, val_loss, val_mae, lr) history.append({ "epoch": epoch, "train_loss": train_loss, "val_loss": val_loss, "val_mae": val_mae, }) wandb.log({ "train_loss": train_loss, "val_loss": val_loss, "val_mae": val_mae, "lr": lr, }, step=epoch) if early_stop.step(val_loss, model, epoch): log.info("%s: early stopping at epoch %d", wandb_name, epoch) break best_val = early_stop.restore_best(model) wandb.finish() # ── Test evaluation ──────────────────────────────────────── log.info("Evaluating on test set...") if is_dirichlet: _, test_metrics = _val_epoch_dirichlet( model, test_loader, loss_fn, device, cfg, epoch=0, label=f"{wandb_name}-test" ) else: _, test_metrics = _val_epoch( model, test_loader, loss_fn, device, cfg, epoch=0, label=f"{wandb_name}-test", use_sigmoid=use_sigmoid ) log.info("%s — Test MAE=%.5f RMSE=%.5f", wandb_name, test_metrics["mae/weighted_avg"], test_metrics["rmse/weighted_avg"]) # ── Save per-model history ───────────────────────────────── hist_path = Path(cfg.outputs.log_dir) / f"training_{model_key}_history.csv" pd.DataFrame(history).to_csv(hist_path, index=False) log.info("History saved: %s", hist_path) log.info("Done. Checkpoint: %s", ckpt_path) return test_metrics, best_val, early_stop.best_epoch # ───────────────────────────────────────────────────────────── # Main # ───────────────────────────────────────────────────────────── def main(): parser = argparse.ArgumentParser( description="Train a single model. Run multiple times to train " "different models with breaks in between." 
    )
    parser.add_argument(
        "--model",
        required=True,
        choices=list(CHECKPOINT_NAMES.keys()),
        help=(
            "Which model to train:\n"
            "  proposed       — ViT-Base + hierarchical KL+MSE (main)\n"
            "  b1_resnet_mse  — ResNet-18 + independent MSE (sigmoid)\n"
            "  b2_resnet_kl   — ResNet-18 + hierarchical KL+MSE\n"
            "  b3_vit_mse     — ViT-Base + hierarchical MSE only\n"
            "  b4_vit_dir     — ViT-Base + Dirichlet NLL\n"
        ),
    )
    parser.add_argument("--config", required=True)
    parser.add_argument(
        "--force",
        action="store_true",
        help="Retrain even if checkpoint already exists.",
    )
    args = parser.parse_args()

    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    ckpt_path = str(ckpt_dir / CHECKPOINT_NAMES[args.model])

    # ── Skip if already done ─────────────────────────────────
    if Path(ckpt_path).exists() and not args.force:
        log.info("Checkpoint already exists: %s", ckpt_path)
        log.info("Model '%s' is already trained. Skipping.", args.model)
        log.info("Use --force to retrain.")
        return

    log.info("=" * 60)
    log.info("Training: %s", MODEL_LABELS[args.model])
    log.info("Device : %s", device)
    log.info("Config : %s", args.config)
    log.info("Ckpt : %s", ckpt_path)
    log.info("=" * 60)

    if args.model == "proposed":
        train_proposed(cfg, device, ckpt_path)
    else:
        train_baseline(cfg, device, ckpt_path, args.model)

    log.info("=" * 60)
    log.info("FINISHED: %s", MODEL_LABELS[args.model])
    log.info("=" * 60)


if __name__ == "__main__":
    main()
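
# A minimal sketch for training all five models back to back (assumes each
# run exits cleanly; already-finished models are skipped via their
# checkpoints, per the docstring):
#
#   for m in proposed b1_resnet_mse b2_resnet_kl b3_vit_mse b4_vit_dir; do
#       python -m src.train_single --model "$m" \
#           --config configs/full_train.yaml
#   done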