Spaces:
Running
Running
"""
src/train_single.py
-------------------
Train any single model by name. Designed for running baselines
one at a time with breaks between them.

Available models
----------------
proposed        — ViT-Base + hierarchical KL+MSE (main model)
b1_resnet_mse   — ResNet-18 + independent MSE (sigmoid)
b2_resnet_kl    — ResNet-18 + hierarchical KL+MSE
b3_vit_mse      — ViT-Base + hierarchical MSE only (no KL)
b4_vit_dir      — ViT-Base + Dirichlet NLL (Zoobot-style)

Usage
-----
    # Train proposed model
    python -m src.train_single --model proposed --config configs/full_train.yaml

    # Train one baseline at a time
    python -m src.train_single --model b1_resnet_mse --config configs/full_train.yaml
    python -m src.train_single --model b2_resnet_kl --config configs/full_train.yaml
    python -m src.train_single --model b3_vit_mse --config configs/full_train.yaml
    python -m src.train_single --model b4_vit_dir --config configs/full_train.yaml

    # With nohup (recommended)
    nohup python -m src.train_single --model b3_vit_mse \\
        --config configs/full_train.yaml \\
        > outputs/logs/train_b3_vit_mse.log 2>&1 &
    echo "PID: $!"

Each model saves its checkpoint independently, so you can run them
in any order and resume from any point. Already-trained models are
detected by their checkpoint file and skipped unless --force is passed.
"""
| import argparse | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from omegaconf import OmegaConf | |
# Emit timestamped INFO-and-above records on stdout (not the default stderr),
# so a plain ``> file.log`` redirect captures the training log.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)

# Module-wide logger shared by every function in this file.
log = logging.getLogger("train_single")
# ── Checkpoint paths per model ────────────────────────────────────────────────
# Maps each ``--model`` key to the checkpoint file name that is joined onto
# ``cfg.outputs.checkpoint_dir``. The keys also define the valid CLI choices.
CHECKPOINT_NAMES = {
    "proposed": "best_full_train.pt",
    "b1_resnet_mse": "baseline_resnet18_mse.pt",
    "b2_resnet_kl": "baseline_resnet18_klmse.pt",
    "b3_vit_mse": "baseline_vit_mse.pt",
    "b4_vit_dir": "baseline_vit_dirichlet.pt",
}
# ── Human-readable labels ─────────────────────────────────────────────────────
# Display name for each ``--model`` key; used only in log output.
MODEL_LABELS = {
    "proposed": "ViT-Base + hierarchical KL+MSE (proposed)",
    "b1_resnet_mse": "ResNet-18 + independent MSE (sigmoid, no hierarchy)",
    "b2_resnet_kl": "ResNet-18 + hierarchical KL+MSE",
    "b3_vit_mse": "ViT-Base + hierarchical MSE only (no KL)",
    "b4_vit_dir": "ViT-Base + Dirichlet NLL (Zoobot-style)",
}
def train_proposed(cfg, device, ckpt_path):
    """Train the proposed ViT + hierarchical KL+MSE model.

    Builds the data loaders, model, loss, optimizer and scheduler from
    ``cfg``, runs the epoch loop with early stopping, writes the per-epoch
    history to ``<cfg.outputs.log_dir>/training_full_train_history.csv`` and
    finally calls ``EarlyStopping.restore_best`` on the model.

    Parameters
    ----------
    cfg : omegaconf.DictConfig
        Merged configuration. Reads ``seed``, ``experiment_name``,
        ``outputs.*``, ``wandb.*``, ``training.*``, ``scheduler.*`` and
        ``early_stopping.*``.
    device : torch.device or str
        Device the model is moved to and training runs on.
    ckpt_path : str
        Checkpoint location handed to ``EarlyStopping``.

    Notes
    -----
    The history CSV is written (and the W&B run closed) even when the loop
    is interrupted or raises, so partial runs still leave a usable log.
    """
    from src.train import (
        train_one_epoch, validate, EarlyStopping, set_seed
    )
    from src.dataset import build_dataloaders
    from src.model import build_model
    from src.loss import HierarchicalLoss
    from src.attention_viz import plot_attention_grid
    import pandas as pd
    import wandb
    from torch.amp import GradScaler
    import matplotlib.pyplot as plt

    set_seed(cfg.seed)
    log.info("Training: %s", MODEL_LABELS["proposed"])

    log_dir = Path(cfg.outputs.log_dir)
    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.figures_dir).mkdir(parents=True, exist_ok=True)
    log_dir.mkdir(parents=True, exist_ok=True)
    history_path = str(log_dir / "training_full_train_history.csv")

    use_wandb = bool(cfg.wandb.enabled)
    attn_dir = None
    attn_every = 0
    if use_wandb:
        # Attention figures go into a per-experiment sub-directory; create it
        # up front so saving the first figure cannot fail on a missing folder.
        attn_dir = Path(cfg.outputs.figures_dir) / cfg.experiment_name
        attn_dir.mkdir(parents=True, exist_ok=True)
        attn_every = cfg.wandb.log_attention_every_n_epochs
        wandb.init(
            project=cfg.wandb.project,
            name=cfg.experiment_name,
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    train_loader, val_loader, _ = build_dataloaders(cfg)
    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)

    # Layer-wise learning rates: the backbone trains at 0.1x the head's rate.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    # Gradient scaling only applies on CUDA; keep the scaler inert elsewhere.
    scaler = GradScaler("cuda", enabled=torch.device(device).type == "cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=ckpt_path,
    )

    history = []
    try:
        for epoch in range(1, cfg.training.epochs + 1):
            train_loss = train_one_epoch(
                model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch
            )

            # Attention maps are only consumed by the W&B branch below, so
            # skip collecting them when W&B is off; a non-positive interval
            # disables collection instead of dividing by zero.
            collect_attn = bool(
                attn_every and attn_every > 0 and epoch % attn_every == 0
            )
            val_logs, attn_data = validate(
                model, val_loader, loss_fn, device, cfg,
                collect_attn=collect_attn,
                n_attn=cfg.wandb.n_attention_samples,
                epoch=epoch,
            )

            scheduler.step()
            # First param group is the backbone, so this is the backbone LR.
            lr = scheduler.get_last_lr()[0]
            val_mae = val_logs.get("val/mae/weighted_avg", 0)
            val_loss = val_logs["val/loss_total"]
            log.info("Epoch %d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                     epoch, train_loss, val_loss, val_mae, lr)
            history.append({
                "epoch": epoch, "train_loss": train_loss,
                "val_loss": val_loss, "val_mae": val_mae, "lr": lr,
            })

            if use_wandb:
                log_dict = {"train/loss": train_loss, **val_logs,
                            "lr": lr, "epoch": epoch}
                if attn_data is not None:
                    imgs, layers, ids = attn_data
                    fig = plot_attention_grid(
                        imgs, layers, ids,
                        save_path=str(attn_dir / f"attn_epoch{epoch:03d}.png"),
                        n_cols=4, rollout_mode="full",
                    )
                    log_dict["attention/rollout_full"] = wandb.Image(fig)
                    # Close the figure so figures do not pile up over epochs.
                    plt.close(fig)
                wandb.log(log_dict, step=epoch)

            if early_stop.step(val_loss, model, epoch):
                log.info("Early stopping at epoch %d", epoch)
                break
    finally:
        # Persist whatever was recorded, even after a crash or Ctrl-C.
        pd.DataFrame(history).to_csv(history_path, index=False)
        if use_wandb:
            wandb.finish()

    early_stop.restore_best(model)
    log.info("Done. Checkpoint: %s", ckpt_path)
def train_baseline(cfg, device, ckpt_path, model_key):
    """Train one of the four baselines and evaluate it on the test split.

    Parameters
    ----------
    cfg : omegaconf.DictConfig
        Merged configuration. Reads ``seed``, ``model.dropout``,
        ``outputs.*``, ``wandb.*``, ``training.*``, ``scheduler.*``,
        ``early_stopping.*`` and ``loss.lambda_kl``. For ``b3_vit_mse`` a
        copy with ``loss.lambda_kl=0.0`` / ``loss.lambda_mse=1.0`` is used.
    device : torch.device or str
        Device the model is moved to and training runs on.
    ckpt_path : str
        Checkpoint location handed to ``EarlyStopping``.
    model_key : str
        One of ``b1_resnet_mse``, ``b2_resnet_kl``, ``b3_vit_mse``,
        ``b4_vit_dir``.

    Returns
    -------
    tuple
        ``(test_metrics, best_val, best_epoch)``: the metrics mapping from
        the test-set pass, the value returned by
        ``EarlyStopping.restore_best`` (presumably the best validation loss;
        confirm in ``src.baselines``), and ``EarlyStopping.best_epoch``.

    Raises
    ------
    ValueError
        If ``model_key`` is not a baseline key.
    """
    # Validate first so an unknown key fails fast with the intended error,
    # before the heavy imports, seeding, or any directory creation.
    baseline_keys = ("b1_resnet_mse", "b2_resnet_kl", "b3_vit_mse", "b4_vit_dir")
    if model_key not in baseline_keys:
        raise ValueError(f"Unknown model key: {model_key}")

    import pandas as pd
    import wandb
    from torch.amp import GradScaler
    from src.dataset import build_dataloaders
    from src.model import build_model, build_dirichlet_model
    from src.loss import HierarchicalLoss, DirichletLoss, MSEOnlyLoss
    from src.baselines import (
        ResNet18Baseline, IndependentMSELoss, EarlyStopping,
        set_seed, _train_epoch, _val_epoch,
        _train_epoch_dirichlet, _val_epoch_dirichlet,
    )

    set_seed(cfg.seed)
    log.info("Training: %s", MODEL_LABELS[model_key])

    log_dir = Path(cfg.outputs.log_dir)
    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    # The history CSV is written here, so do not rely on the caller for this.
    log_dir.mkdir(parents=True, exist_ok=True)
    hist_path = log_dir / f"training_{model_key}_history.csv"

    # ── Build model and loss ───────────────────────────────────
    use_sigmoid = False
    is_dirichlet = False
    use_layerwise_lr = False
    if model_key == "b1_resnet_mse":
        model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        loss_fn = IndependentMSELoss()
        use_sigmoid = True
        wandb_name = "B1-ResNet18-MSE"
    elif model_key == "b2_resnet_kl":
        model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
        loss_fn = HierarchicalLoss(cfg)
        wandb_name = "B2-ResNet18-KL+MSE"
    elif model_key == "b3_vit_mse":
        # Override the loss weights, then keep using the overridden config so
        # the rest of the run (e.g. the logged lambda_kl) reflects it.
        cfg = OmegaConf.merge(
            cfg, OmegaConf.create({"loss": {"lambda_kl": 0.0, "lambda_mse": 1.0}})
        )
        model = build_model(cfg).to(device)
        loss_fn = MSEOnlyLoss(cfg)
        use_layerwise_lr = True
        wandb_name = "B3-ViT-MSE"
    elif model_key == "b4_vit_dir":
        model = build_dirichlet_model(cfg).to(device)
        loss_fn = DirichletLoss(cfg)
        is_dirichlet = True
        use_layerwise_lr = True
        wandb_name = "B4-ViT-Dirichlet"
    else:
        raise ValueError(f"Unknown model key: {model_key}")

    total = sum(p.numel() for p in model.parameters())
    log.info("Parameters: %s", f"{total:,}")

    # ── Optimizer ──────────────────────────────────────────────
    if use_layerwise_lr and hasattr(model, "backbone") and hasattr(model, "head"):
        # Backbone trains at 0.1x the head's learning rate.
        optimizer = torch.optim.AdamW(
            [
                {"params": model.backbone.parameters(),
                 "lr": cfg.training.learning_rate * 0.1},
                {"params": model.head.parameters(),
                 "lr": cfg.training.learning_rate},
            ],
            weight_decay=cfg.training.weight_decay,
        )
    else:
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.training.learning_rate,
            weight_decay=cfg.training.weight_decay,
        )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    # Gradient scaling only applies on CUDA; keep the scaler inert elsewhere.
    scaler = GradScaler("cuda", enabled=torch.device(device).type == "cuda")
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=ckpt_path,
    )
    train_loader, val_loader, test_loader = build_dataloaders(cfg)

    # Honour cfg.wandb.enabled like train_proposed does. A disabled-mode run
    # turns every later wandb.log / wandb.finish call into a no-op, including
    # any made inside the src.baselines epoch helpers.
    wandb.init(
        project=cfg.wandb.project, name=wandb_name,
        config={"model": wandb_name, "seed": cfg.seed,
                "epochs": cfg.training.epochs,
                "lambda_kl": cfg.loss.lambda_kl},
        reinit=True,
        mode=None if cfg.wandb.enabled else "disabled",
    )

    # ── Training loop ──────────────────────────────────────────
    history = []
    try:
        for epoch in range(1, cfg.training.epochs + 1):
            if is_dirichlet:
                train_loss = _train_epoch_dirichlet(
                    model, train_loader, loss_fn, optimizer, scaler,
                    device, cfg, epoch, wandb_name
                )
                val_loss, val_metrics = _val_epoch_dirichlet(
                    model, val_loader, loss_fn, device, cfg, epoch, wandb_name
                )
            else:
                train_loss = _train_epoch(
                    model, train_loader, loss_fn, optimizer, scaler,
                    device, cfg, epoch, wandb_name
                )
                val_loss, val_metrics = _val_epoch(
                    model, val_loader, loss_fn, device, cfg, epoch, wandb_name,
                    use_sigmoid=use_sigmoid
                )

            scheduler.step()
            lr = scheduler.get_last_lr()[0]
            val_mae = val_metrics.get("mae/weighted_avg", 0)
            log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
                     wandb_name, epoch, train_loss, val_loss, val_mae, lr)
            history.append({
                "epoch": epoch, "train_loss": train_loss,
                "val_loss": val_loss, "val_mae": val_mae,
            })
            wandb.log({
                "train_loss": train_loss, "val_loss": val_loss,
                "val_mae": val_mae, "lr": lr,
            }, step=epoch)

            if early_stop.step(val_loss, model, epoch):
                log.info("%s: early stopping at epoch %d", wandb_name, epoch)
                break
    finally:
        # ── Save per-model history ─────────────────────────────
        # Written here so an interrupted run, or a failure during the test
        # evaluation below, does not discard the recorded epochs.
        pd.DataFrame(history).to_csv(hist_path, index=False)
        log.info("History saved: %s", hist_path)
        wandb.finish()

    best_val = early_stop.restore_best(model)

    # ── Test evaluation ────────────────────────────────────────
    log.info("Evaluating on test set...")
    if is_dirichlet:
        _, test_metrics = _val_epoch_dirichlet(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{wandb_name}-test"
        )
    else:
        _, test_metrics = _val_epoch(
            model, test_loader, loss_fn, device, cfg,
            epoch=0, label=f"{wandb_name}-test", use_sigmoid=use_sigmoid
        )
    log.info("%s — Test MAE=%.5f RMSE=%.5f",
             wandb_name,
             test_metrics["mae/weighted_avg"],
             test_metrics["rmse/weighted_avg"])

    log.info("Done. Checkpoint: %s", ckpt_path)
    return test_metrics, best_val, early_stop.best_epoch
# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: train the model selected with ``--model``.

    Loads ``configs/base.yaml`` (relative to the current working directory),
    merges the file given by ``--config`` over it, and dispatches to
    ``train_proposed`` or ``train_baseline``. If the model's checkpoint file
    already exists the run is skipped unless ``--force`` is passed.
    """
    # Build the per-model help from MODEL_LABELS so it cannot drift from the
    # labels used in the logs. '%' is escaped because argparse %-formats help.
    model_help = "Which model to train:\n" + "\n".join(
        f"  {key:<14} {label.replace('%', '%%')}"
        for key, label in MODEL_LABELS.items()
    )

    parser = argparse.ArgumentParser(
        description="Train a single model. Run multiple times to train "
                    "different models with breaks in between.",
        # The default formatter collapses newlines and would flatten the
        # aligned model table above into one paragraph.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--model",
        required=True,
        choices=list(CHECKPOINT_NAMES),
        help=model_help,
    )
    parser.add_argument("--config", required=True)
    parser.add_argument(
        "--force",
        action="store_true",
        help="Retrain even if checkpoint already exists.",
    )
    args = parser.parse_args()

    # Experiment config overrides the shared base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ckpt_dir = Path(cfg.outputs.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)

    # Allow TF32 for CUDA matmuls and cuDNN ops.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    ckpt_file = ckpt_dir / CHECKPOINT_NAMES[args.model]
    ckpt_path = str(ckpt_file)

    # ── Skip if already done ───────────────────────────────────
    if ckpt_file.exists() and not args.force:
        log.info("Checkpoint already exists: %s", ckpt_path)
        log.info("Model '%s' is already trained. Skipping.", args.model)
        log.info("Use --force to retrain.")
        return

    banner = "=" * 60
    log.info(banner)
    log.info("Training: %s", MODEL_LABELS[args.model])
    log.info("Device : %s", device)
    log.info("Config : %s", args.config)
    log.info("Ckpt : %s", ckpt_path)
    log.info(banner)

    if args.model == "proposed":
        train_proposed(cfg, device, ckpt_path)
    else:
        train_baseline(cfg, device, ckpt_path, args.model)

    log.info(banner)
    log.info("FINISHED: %s", MODEL_LABELS[args.model])
    log.info(banner)
# Run the CLI only when executed as a script (e.g. ``python -m
# src.train_single``), not when this module is imported.
if __name__ == "__main__":
    main()