"""
src/train.py
------------
Main training loop for the proposed hierarchical probabilistic ViT
regression model on Galaxy Zoo 2.
Model : GalaxyViT (ViT-Base/16 + linear head)
Loss : HierarchicalLoss (KL + MSE, Ξ»=0.5 each)
Scheduler: CosineAnnealingLR
Dropout : 0.3 (increased from 0.1 β€” see base.yaml rationale)
Saves
-----
outputs/checkpoints/best_<experiment_name>.pt β€” best checkpoint
outputs/logs/training_<experiment_name>_history.csv β€” epoch history
Usage
-----
cd ~/galaxy
nohup python -m src.train --config configs/full_train.yaml \
> outputs/logs/train_full.log 2>&1 &
echo "PID: $!"
"""
import argparse
import logging
import random
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import wandb
from omegaconf import OmegaConf
from torch.amp import autocast, GradScaler
from tqdm import tqdm

from src.attention_viz import plot_attention_grid
from src.dataset import build_dataloaders
from src.loss import HierarchicalLoss
from src.metrics import compute_metrics, predictions_to_numpy
from src.model import build_model
logging.basicConfig(
format="%(asctime)s %(levelname)s %(name)s %(message)s",
datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("train")
# ─────────────────────────────────────────────────────────────
# Utilities
# ─────────────────────────────────────────────────────────────
def set_seed(seed: int):
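    """Seed every RNG in use and force deterministic cuDNN kernels
    (slower, but makes runs reproducible)."""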
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
class EarlyStopping:
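    """Track the best validation loss, checkpointing on every improvement
    and signalling a stop after `patience` epochs without one."""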
def __init__(self, patience, min_delta, checkpoint_path):
self.patience = patience
self.min_delta = min_delta
self.checkpoint_path = checkpoint_path
self.best_loss = float("inf")
self.counter = 0
self.best_epoch = 0
def step(self, val_loss, model, epoch) -> bool:
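        """Record `val_loss`; save a checkpoint if it improved by at least
        `min_delta`. Returns True when training should stop."""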
if val_loss < self.best_loss - self.min_delta:
self.best_loss = val_loss
self.counter = 0
self.best_epoch = epoch
torch.save(
{"epoch": epoch, "model_state": model.state_dict(),
"val_loss": val_loss},
self.checkpoint_path,
)
log.info(" [ckpt] saved epoch=%d val_loss=%.6f", epoch, val_loss)
else:
self.counter += 1
log.info(" [early_stop] %d/%d best=%.6f",
self.counter, self.patience, self.best_loss)
return self.counter >= self.patience
def restore_best(self, model):
ckpt = torch.load(self.checkpoint_path, map_location="cpu",
weights_only=True)
model.load_state_dict(ckpt["model_state"])
log.info("Restored best weights epoch=%d val_loss=%.6f",
ckpt["epoch"], ckpt["val_loss"])
# ─────────────────────────────────────────────────────────────
# Training / validation steps
# ─────────────────────────────────────────────────────────────
def train_one_epoch(model, loader, loss_fn, optimizer,
scaler, device, cfg, epoch):
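    """Run one optimisation pass over `loader`; returns the mean batch loss."""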
model.train()
total = 0.0
nb = 0
for images, targets, weights, _ in tqdm(
loader, desc=f"Train E{epoch}", leave=False
):
images = images.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
weights = weights.to(device, non_blocking=True)
optimizer.zero_grad(set_to_none=True)
with autocast("cuda", enabled=cfg.training.mixed_precision):
logits = model(images)
loss, _ = loss_fn(logits, targets, weights)
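        # AMP bookkeeping: scale the loss for backward, then unscale the
        # gradients *before* clipping so the norm threshold applies to true
        # gradient magnitudes.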
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip)
scaler.step(optimizer)
scaler.update()
total += loss.item()
nb += 1
return total / nb
def validate(model, loader, loss_fn, device, cfg,
collect_attn=False, n_attn=8, epoch=0):
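    """Evaluate on `loader`; returns (val_logs dict, attn_data or None).
    attn_data is (images, per-layer attention tensors, image ids) for up
    to `n_attn` samples when `collect_attn` is set."""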
model.eval()
total = 0.0
nb = 0
all_preds, all_targets, all_weights = [], [], []
attn_imgs, all_layers_list, attn_ids = [], [], []
attn_done = False
with torch.no_grad():
for images, targets, weights, image_ids in tqdm(
loader, desc=f"Val E{epoch}", leave=False
):
images = images.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
weights = weights.to(device, non_blocking=True)
with autocast("cuda", enabled=cfg.training.mixed_precision):
logits = model(images)
loss, _ = loss_fn(logits, targets, weights)
total += loss.item()
nb += 1
p, t, w = predictions_to_numpy(logits, targets, weights)
all_preds.append(p)
all_targets.append(t)
all_weights.append(w)
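            # Harvest attention maps from the earliest batches only, up to
            # n_attn sample images in total.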
if collect_attn and not attn_done:
all_layers = model.get_all_attention_weights()
if all_layers is not None:
n = min(n_attn, images.shape[0])
attn_imgs.append(images[:n].cpu())
all_layers_list.append([l[:n].cpu() for l in all_layers])
attn_ids.extend([int(i) for i in image_ids[:n]])
if len(attn_ids) >= n_attn:
attn_done = True
all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)
all_weights = np.concatenate(all_weights)
metrics = compute_metrics(all_preds, all_targets, all_weights)
val_logs = {"val/loss_total": total / nb}
val_logs.update({f"val/{k}": v for k, v in metrics.items()})
val_logs["val/reached_mae_w050"] = metrics.get("mae_w050/conditional_avg", 0)
attn_data = None
if collect_attn and attn_imgs:
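        # Re-assemble per-layer attention: concatenate the li-th layer's
        # maps across every collected batch.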
attn_data = (
torch.cat(attn_imgs, dim=0),
[torch.cat([b[li] for b in all_layers_list], dim=0)
for li in range(len(all_layers_list[0]))],
attn_ids,
)
return val_logs, attn_data
# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True)
args = parser.parse_args()
base_cfg = OmegaConf.load("configs/base.yaml")
exp_cfg = OmegaConf.load(args.config)
cfg = OmegaConf.merge(base_cfg, exp_cfg)
set_seed(cfg.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log.info("Device: %s", device)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
Path(cfg.outputs.figures_dir).mkdir(parents=True, exist_ok=True)
Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)
checkpoint_path = str(
Path(cfg.outputs.checkpoint_dir) / f"best_{cfg.experiment_name}.pt"
)
history_path = str(
Path(cfg.outputs.log_dir) / f"training_{cfg.experiment_name}_history.csv"
)
if cfg.wandb.enabled:
wandb.init(
project=cfg.wandb.project,
name=cfg.experiment_name,
config=OmegaConf.to_container(cfg, resolve=True),
)
log.info("Building dataloaders...")
train_loader, val_loader, _ = build_dataloaders(cfg)
log.info("Building model...")
model = build_model(cfg).to(device)
loss_fn = HierarchicalLoss(cfg)
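    # Discriminative learning rates: the pretrained ViT backbone trains at
    # a 10x smaller LR than the freshly initialised regression head.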
optimizer = torch.optim.AdamW(
[
{"params": model.backbone.parameters(),
"lr": cfg.training.learning_rate * 0.1},
{"params": model.head.parameters(),
"lr": cfg.training.learning_rate},
],
weight_decay=cfg.training.weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
)
    # Keep the scaler a no-op when AMP is disabled in the config.
    scaler = GradScaler("cuda", enabled=cfg.training.mixed_precision)
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )
log.info("Starting training: %s", cfg.experiment_name)
history = []
for epoch in range(1, cfg.training.epochs + 1):
train_loss = train_one_epoch(
model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch
)
collect_attn = (epoch % cfg.wandb.log_attention_every_n_epochs == 0)
val_logs, attn_data = validate(
model, val_loader, loss_fn, device, cfg,
collect_attn=collect_attn,
n_attn=cfg.wandb.n_attention_samples,
epoch=epoch,
)
scheduler.step()
lr = scheduler.get_last_lr()[0]
val_mae = val_logs.get("val/mae/weighted_avg", 0)
val_loss = val_logs["val/loss_total"]
reached = val_logs.get("val/reached_mae_w050", 0)
log.info(
"Epoch %d train=%.4f val=%.4f mae=%.4f reached_mae=%.4f lr=%.2e",
epoch, train_loss, val_loss, val_mae, reached, lr,
)
history.append({
"epoch" : epoch,
"train_loss" : train_loss,
"val_loss" : val_loss,
"val_mae" : val_mae,
"reached_mae": reached,
"lr" : lr,
})
if cfg.wandb.enabled:
log_dict = {
"train/loss": train_loss,
**val_logs,
"lr": lr, "epoch": epoch,
}
            if attn_data is not None:
                import matplotlib.pyplot as plt  # lazy import: only needed for figures
                imgs, layers, ids = attn_data
                # The per-experiment figure subdir is not created at startup
                # (only figures_dir itself is), so ensure it exists here.
                fig_dir = Path(cfg.outputs.figures_dir) / cfg.experiment_name
                fig_dir.mkdir(parents=True, exist_ok=True)
                fig = plot_attention_grid(
                    imgs, layers, ids,
                    save_path=str(fig_dir / f"attn_epoch{epoch:03d}.png"),
                    n_cols=4, rollout_mode="full",
                )
log_dict["attention/rollout_full"] = wandb.Image(fig)
plt.close(fig)
wandb.log(log_dict, step=epoch)
if early_stop.step(val_loss, model, epoch):
log.info("Early stopping at epoch %d best=%d loss=%.6f",
epoch, early_stop.best_epoch, early_stop.best_loss)
break
# Save history
pd.DataFrame(history).to_csv(history_path, index=False)
log.info("Saved history: %s", history_path)
early_stop.restore_best(model)
if cfg.wandb.enabled:
wandb.finish()
log.info("Done. Best checkpoint: %s", checkpoint_path)
if __name__ == "__main__":
main()