"""
src/train_single.py
-------------------
Train any single model by name. Designed for running baselines
one at a time with breaks between them.
Available models
----------------
proposed       - ViT-Base + hierarchical KL+MSE (main model)
b1_resnet_mse  - ResNet-18 + independent MSE (sigmoid)
b2_resnet_kl   - ResNet-18 + hierarchical KL+MSE
b3_vit_mse     - ViT-Base + hierarchical MSE only (no KL)
b4_vit_dir     - ViT-Base + Dirichlet NLL (Zoobot-style)
Usage
-----
# Train proposed model
python -m src.train_single --model proposed --config configs/full_train.yaml
# Train one baseline at a time
python -m src.train_single --model b1_resnet_mse --config configs/full_train.yaml
python -m src.train_single --model b2_resnet_kl --config configs/full_train.yaml
python -m src.train_single --model b3_vit_mse --config configs/full_train.yaml
python -m src.train_single --model b4_vit_dir --config configs/full_train.yaml
# With nohup (recommended)
nohup python -m src.train_single --model b3_vit_mse \\
--config configs/full_train.yaml \\
> outputs/logs/train_b3_vit_mse.log 2>&1 &
echo "PID: $!"
Each model saves its checkpoint independently, so you can run them
in any order and resume from any point. Already-trained models are
detected by their checkpoint file and skipped unless --force is passed.
"""
import argparse
import logging
import sys
from pathlib import Path
import torch
from omegaconf import OmegaConf
logging.basicConfig(
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("train_single")
# ── Checkpoint paths per model ─────────────────────────────────────────────────
CHECKPOINT_NAMES = {
"proposed" : "best_full_train.pt",
"b1_resnet_mse" : "baseline_resnet18_mse.pt",
"b2_resnet_kl" : "baseline_resnet18_klmse.pt",
"b3_vit_mse" : "baseline_vit_mse.pt",
"b4_vit_dir" : "baseline_vit_dirichlet.pt",
}
# ── Human-readable labels ──────────────────────────────────────────────────────
MODEL_LABELS = {
"proposed" : "ViT-Base + hierarchical KL+MSE (proposed)",
"b1_resnet_mse" : "ResNet-18 + independent MSE (sigmoid, no hierarchy)",
"b2_resnet_kl" : "ResNet-18 + hierarchical KL+MSE",
"b3_vit_mse" : "ViT-Base + hierarchical MSE only (no KL)",
"b4_vit_dir" : "ViT-Base + Dirichlet NLL (Zoobot-style)",
}
def train_proposed(cfg, device, ckpt_path):
"""Train the proposed ViT + hierarchical KL+MSE model."""
from src.train import (
train_one_epoch, validate, EarlyStopping, set_seed
)
from src.dataset import build_dataloaders
from src.model import build_model
from src.loss import HierarchicalLoss
from src.attention_viz import plot_attention_grid
import pandas as pd
import wandb
from torch.amp import GradScaler
import matplotlib.pyplot as plt
set_seed(cfg.seed)
log.info("Training: %s", MODEL_LABELS["proposed"])
Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    # attention grids are written to a per-experiment subdir; create it up front
    Path(cfg.outputs.figures_dir, cfg.experiment_name).mkdir(parents=True, exist_ok=True)
Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)
history_path = str(
Path(cfg.outputs.log_dir) / "training_full_train_history.csv"
)
if cfg.wandb.enabled:
wandb.init(
project=cfg.wandb.project,
name=cfg.experiment_name,
config=OmegaConf.to_container(cfg, resolve=True),
)
train_loader, val_loader, _ = build_dataloaders(cfg)
model = build_model(cfg).to(device)
loss_fn = HierarchicalLoss(cfg)
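    # Layer-wise learning rates: the pretrained ViT backbone trains at 10%
    # of the head LR so fine-tuning does not wash out pretrained features.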
optimizer = torch.optim.AdamW(
[
{"params": model.backbone.parameters(),
"lr": cfg.training.learning_rate * 0.1},
{"params": model.head.parameters(),
"lr": cfg.training.learning_rate},
],
weight_decay=cfg.training.weight_decay,
)
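    # Cosine annealing decays both group LRs towards eta_min over T_max
    # epochs (scheduler.step() is called once per epoch below).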
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
)
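    # Mixed precision: GradScaler scales the loss to avoid fp16 gradient
    # underflow; train_one_epoch is expected to handle unscale/step.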
scaler = GradScaler("cuda")
early_stop = EarlyStopping(
patience=cfg.early_stopping.patience,
min_delta=cfg.early_stopping.min_delta,
checkpoint_path=ckpt_path,
)
history = []
for epoch in range(1, cfg.training.epochs + 1):
train_loss = train_one_epoch(
model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch
)
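        # Collect attention rollouts only every N epochs (per
        # cfg.wandb.log_attention_every_n_epochs) to keep validation cheap.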
collect_attn = (epoch % cfg.wandb.log_attention_every_n_epochs == 0)
val_logs, attn_data = validate(
model, val_loader, loss_fn, device, cfg,
collect_attn=collect_attn,
n_attn=cfg.wandb.n_attention_samples,
epoch=epoch,
)
scheduler.step()
lr = scheduler.get_last_lr()[0]
val_mae = val_logs.get("val/mae/weighted_avg", 0)
val_loss = val_logs["val/loss_total"]
log.info("Epoch %d train=%.4f val=%.4f mae=%.4f lr=%.2e",
epoch, train_loss, val_loss, val_mae, lr)
history.append({
"epoch": epoch, "train_loss": train_loss,
"val_loss": val_loss, "val_mae": val_mae, "lr": lr,
})
if cfg.wandb.enabled:
log_dict = {"train/loss": train_loss, **val_logs,
"lr": lr, "epoch": epoch}
if attn_data is not None:
imgs, layers, ids = attn_data
fig = plot_attention_grid(
imgs, layers, ids,
save_path=(f"{cfg.outputs.figures_dir}/{cfg.experiment_name}/"
f"attn_epoch{epoch:03d}.png"),
n_cols=4, rollout_mode="full",
)
log_dict["attention/rollout_full"] = wandb.Image(fig)
plt.close(fig)
wandb.log(log_dict, step=epoch)
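        # step() checkpoints the best weights to ckpt_path and returns True
        # once val_loss has not improved by min_delta for `patience` epochs.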
if early_stop.step(val_loss, model, epoch):
log.info("Early stopping at epoch %d", epoch)
break
pd.DataFrame(history).to_csv(history_path, index=False)
early_stop.restore_best(model)
if cfg.wandb.enabled:
wandb.finish()
log.info("Done. Checkpoint: %s", ckpt_path)
def train_baseline(cfg, device, ckpt_path, model_key):
"""Train any of the four baselines."""
import wandb
from torch.amp import GradScaler
from src.dataset import build_dataloaders
from src.model import build_model, build_dirichlet_model
from src.loss import HierarchicalLoss, DirichletLoss, MSEOnlyLoss
from src.baselines import (
ResNet18Baseline, IndependentMSELoss, EarlyStopping,
set_seed, _train_epoch, _val_epoch,
_train_epoch_dirichlet, _val_epoch_dirichlet,
)
import pandas as pd
from omegaconf import OmegaConf as OC
set_seed(cfg.seed)
log.info("Training: %s", MODEL_LABELS[model_key])
Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
# ── Build model and loss ───────────────────────────────────
if model_key == "b1_resnet_mse":
model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
loss_fn = IndependentMSELoss()
use_sigmoid = True
is_dirichlet = False
use_layerwise_lr = False
wandb_name = "B1-ResNet18-MSE"
elif model_key == "b2_resnet_kl":
model = ResNet18Baseline(dropout=cfg.model.dropout).to(device)
loss_fn = HierarchicalLoss(cfg)
use_sigmoid = False
is_dirichlet = False
use_layerwise_lr = False
wandb_name = "B2-ResNet18-KL+MSE"
elif model_key == "b3_vit_mse":
vit_mse_cfg = OC.merge(
cfg, OC.create({"loss": {"lambda_kl": 0.0, "lambda_mse": 1.0}})
)
model = build_model(vit_mse_cfg).to(device)
loss_fn = MSEOnlyLoss(vit_mse_cfg)
cfg = vit_mse_cfg # use updated cfg for optimizer
use_sigmoid = False
is_dirichlet = False
use_layerwise_lr = True
wandb_name = "B3-ViT-MSE"
elif model_key == "b4_vit_dir":
model = build_dirichlet_model(cfg).to(device)
loss_fn = DirichletLoss(cfg)
use_sigmoid = False
is_dirichlet = True
use_layerwise_lr = True
wandb_name = "B4-ViT-Dirichlet"
else:
raise ValueError(f"Unknown model key: {model_key}")
total = sum(p.numel() for p in model.parameters())
log.info("Parameters: %s", f"{total:,}")
# ── Optimizer ──────────────────────────────────────────────
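    # ViT baselines reuse the layer-wise LR scheme of the proposed model
    # (backbone at 0.1x); the ResNet baselines use a single uniform LR.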
if use_layerwise_lr and hasattr(model, "backbone") and hasattr(model, "head"):
optimizer = torch.optim.AdamW(
[
{"params": model.backbone.parameters(),
"lr": cfg.training.learning_rate * 0.1},
{"params": model.head.parameters(),
"lr": cfg.training.learning_rate},
],
weight_decay=cfg.training.weight_decay,
)
else:
optimizer = torch.optim.AdamW(
model.parameters(),
lr=cfg.training.learning_rate,
weight_decay=cfg.training.weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
)
scaler = GradScaler("cuda")
early_stop = EarlyStopping(
patience=cfg.early_stopping.patience,
min_delta=cfg.early_stopping.min_delta,
checkpoint_path=ckpt_path,
)
train_loader, val_loader, test_loader = build_dataloaders(cfg)
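    # Unlike train_proposed, baseline runs always log to W&B;
    # reinit=True allows several baseline runs within one process.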
wandb.init(
project=cfg.wandb.project, name=wandb_name,
config={"model": wandb_name, "seed": cfg.seed,
"epochs": cfg.training.epochs,
"lambda_kl": cfg.loss.lambda_kl},
reinit=True,
)
# ── Training loop ──────────────────────────────────────────
history = []
for epoch in range(1, cfg.training.epochs + 1):
if is_dirichlet:
train_loss = _train_epoch_dirichlet(
model, train_loader, loss_fn, optimizer, scaler,
device, cfg, epoch, wandb_name
)
val_loss, val_metrics = _val_epoch_dirichlet(
model, val_loader, loss_fn, device, cfg, epoch, wandb_name
)
else:
train_loss = _train_epoch(
model, train_loader, loss_fn, optimizer, scaler,
device, cfg, epoch, wandb_name
)
val_loss, val_metrics = _val_epoch(
model, val_loader, loss_fn, device, cfg, epoch, wandb_name,
use_sigmoid=use_sigmoid
)
scheduler.step()
lr = scheduler.get_last_lr()[0]
val_mae = val_metrics.get("mae/weighted_avg", 0)
log.info("%s epoch=%d train=%.4f val=%.4f mae=%.4f lr=%.2e",
wandb_name, epoch, train_loss, val_loss, val_mae, lr)
history.append({
"epoch": epoch, "train_loss": train_loss,
"val_loss": val_loss, "val_mae": val_mae,
})
wandb.log({
"train_loss": train_loss, "val_loss": val_loss,
"val_mae": val_mae, "lr": lr,
}, step=epoch)
if early_stop.step(val_loss, model, epoch):
log.info("%s: early stopping at epoch %d", wandb_name, epoch)
break
best_val = early_stop.restore_best(model)
wandb.finish()
# ── Test evaluation ────────────────────────────────────────
log.info("Evaluating on test set...")
if is_dirichlet:
_, test_metrics = _val_epoch_dirichlet(
model, test_loader, loss_fn, device, cfg,
epoch=0, label=f"{wandb_name}-test"
)
else:
_, test_metrics = _val_epoch(
model, test_loader, loss_fn, device, cfg,
epoch=0, label=f"{wandb_name}-test", use_sigmoid=use_sigmoid
)
log.info("%s β€” Test MAE=%.5f RMSE=%.5f",
wandb_name,
test_metrics["mae/weighted_avg"],
test_metrics["rmse/weighted_avg"])
# ── Save per-model history ─────────────────────────────────
hist_path = Path(cfg.outputs.log_dir) / f"training_{model_key}_history.csv"
pd.DataFrame(history).to_csv(hist_path, index=False)
log.info("History saved: %s", hist_path)
log.info("Done. Checkpoint: %s", ckpt_path)
return test_metrics, best_val, early_stop.best_epoch
# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(
description="Train a single model. Run multiple times to train "
"different models with breaks in between."
)
parser.add_argument(
"--model",
required=True,
choices=list(CHECKPOINT_NAMES.keys()),
        help=(
            "Which model to train:\n"
            "  proposed       - ViT-Base + hierarchical KL+MSE (main)\n"
            "  b1_resnet_mse  - ResNet-18 + independent MSE (sigmoid)\n"
            "  b2_resnet_kl   - ResNet-18 + hierarchical KL+MSE\n"
            "  b3_vit_mse     - ViT-Base + hierarchical MSE only\n"
            "  b4_vit_dir     - ViT-Base + Dirichlet NLL\n"
        ),
)
parser.add_argument("--config", required=True)
parser.add_argument(
"--force",
action="store_true",
help="Retrain even if checkpoint already exists.",
)
args = parser.parse_args()
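    # Experiment config overrides configs/base.yaml key by key.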
base_cfg = OmegaConf.load("configs/base.yaml")
exp_cfg = OmegaConf.load(args.config)
cfg = OmegaConf.merge(base_cfg, exp_cfg)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_dir = Path(cfg.outputs.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)
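    # TF32 matmuls/convolutions: faster on Ampere+ GPUs at negligible
    # precision cost for this workload.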
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
ckpt_path = str(ckpt_dir / CHECKPOINT_NAMES[args.model])
# ── Skip if already done ───────────────────────────────────
if Path(ckpt_path).exists() and not args.force:
log.info("Checkpoint already exists: %s", ckpt_path)
log.info("Model '%s' is already trained. Skipping.", args.model)
log.info("Use --force to retrain.")
return
log.info("=" * 60)
log.info("Training: %s", MODEL_LABELS[args.model])
log.info("Device : %s", device)
log.info("Config : %s", args.config)
log.info("Ckpt : %s", ckpt_path)
log.info("=" * 60)
if args.model == "proposed":
train_proposed(cfg, device, ckpt_path)
else:
train_baseline(cfg, device, ckpt_path, args.model)
log.info("=" * 60)
log.info("FINISHED: %s", MODEL_LABELS[args.model])
log.info("=" * 60)
if __name__ == "__main__":
main()