""" src/train.py ------------ Main training loop for the proposed hierarchical probabilistic ViT regression model on Galaxy Zoo 2. Model : GalaxyViT (ViT-Base/16 + linear head) Loss : HierarchicalLoss (KL + MSE, λ=0.5 each) Scheduler: CosineAnnealingLR Dropout : 0.3 (increased from 0.1 — see base.yaml rationale) Saves ----- outputs/checkpoints/best_.pt — best checkpoint outputs/logs/training__history.csv — epoch history Usage ----- cd ~/galaxy nohup python -m src.train --config configs/full_train.yaml \ > outputs/logs/train_full.log 2>&1 & echo "PID: $!" """ import argparse import logging import random import sys from pathlib import Path import numpy as np import torch import torch.nn as nn from torch.amp import autocast, GradScaler from omegaconf import OmegaConf import pandas as pd import wandb from tqdm import tqdm from src.dataset import build_dataloaders from src.loss import HierarchicalLoss from src.metrics import compute_metrics, predictions_to_numpy from src.model import build_model from src.attention_viz import plot_attention_grid logging.basicConfig( format="%(asctime)s %(levelname)s %(name)s %(message)s", datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout, ) log = logging.getLogger("train") # ───────────────────────────────────────────────────────────── # Utilities # ───────────────────────────────────────────────────────────── def set_seed(seed: int): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False class EarlyStopping: def __init__(self, patience, min_delta, checkpoint_path): self.patience = patience self.min_delta = min_delta self.checkpoint_path = checkpoint_path self.best_loss = float("inf") self.counter = 0 self.best_epoch = 0 def step(self, val_loss, model, epoch) -> bool: if val_loss < self.best_loss - self.min_delta: self.best_loss = val_loss self.counter = 0 self.best_epoch = epoch torch.save( {"epoch": epoch, "model_state": model.state_dict(), "val_loss": val_loss}, self.checkpoint_path, ) log.info(" [ckpt] saved epoch=%d val_loss=%.6f", epoch, val_loss) else: self.counter += 1 log.info(" [early_stop] %d/%d best=%.6f", self.counter, self.patience, self.best_loss) return self.counter >= self.patience def restore_best(self, model): ckpt = torch.load(self.checkpoint_path, map_location="cpu", weights_only=True) model.load_state_dict(ckpt["model_state"]) log.info("Restored best weights epoch=%d val_loss=%.6f", ckpt["epoch"], ckpt["val_loss"]) # ───────────────────────────────────────────────────────────── # Training / validation steps # ───────────────────────────────────────────────────────────── def train_one_epoch(model, loader, loss_fn, optimizer, scaler, device, cfg, epoch): model.train() total = 0.0 nb = 0 for images, targets, weights, _ in tqdm( loader, desc=f"Train E{epoch}", leave=False ): images = images.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) weights = weights.to(device, non_blocking=True) optimizer.zero_grad(set_to_none=True) with autocast("cuda", enabled=cfg.training.mixed_precision): logits = model(images) loss, _ = loss_fn(logits, targets, weights) scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip) scaler.step(optimizer) scaler.update() total += loss.item() nb += 1 return total / nb def validate(model, loader, loss_fn, device, cfg, collect_attn=False, n_attn=8, epoch=0): model.eval() total = 0.0 nb = 0 


def validate(model, loader, loss_fn, device, cfg, collect_attn=False, n_attn=8, epoch=0):
    model.eval()
    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []
    attn_imgs, all_layers_list, attn_ids = [], [], []
    attn_done = False

    with torch.no_grad():
        for images, targets, weights, image_ids in tqdm(loader, desc=f"Val E{epoch}", leave=False):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)

            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)
            total += loss.item()
            nb += 1

            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)

            # Collect attention maps from the first few batches only, until
            # n_attn samples have been gathered.
            if collect_attn and not attn_done:
                all_layers = model.get_all_attention_weights()
                if all_layers is not None:
                    n = min(n_attn, images.shape[0])
                    attn_imgs.append(images[:n].cpu())
                    all_layers_list.append([l[:n].cpu() for l in all_layers])
                    attn_ids.extend([int(i) for i in image_ids[:n]])
                    if len(attn_ids) >= n_attn:
                        attn_done = True

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)

    val_logs = {"val/loss_total": total / nb}
    val_logs.update({f"val/{k}": v for k, v in metrics.items()})
    val_logs["val/reached_mae_w050"] = metrics.get("mae_w050/conditional_avg", 0)

    attn_data = None
    if collect_attn and attn_imgs:
        # Re-assemble per-layer tensors: concatenate the batch slices layer by
        # layer, so each list entry is one [n_attn, ...] tensor per block.
        attn_data = (
            torch.cat(attn_imgs, dim=0),
            [
                torch.cat([b[li] for b in all_layers_list], dim=0)
                for li in range(len(all_layers_list[0]))
            ],
            attn_ids,
        )
    return val_logs, attn_data
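
# Illustrative only: "full" attention rollout in the Abnar & Zuidema (2020)
# sense, which is what the rollout_mode="full" argument passed to
# plot_attention_grid in main() is assumed to select. The real rendering lives
# in src.attention_viz; this hypothetical helper just documents the expected
# shape of the per-layer tensors that validate() collects.
def _sketch_attention_rollout(per_layer_attn):
    # per_layer_attn: one [N, heads, tokens, tokens] tensor per transformer block.
    rollout = None
    for attn in per_layer_attn:
        a = attn.mean(dim=1)                                       # average heads -> [N, T, T]
        a = a + torch.eye(a.shape[-1], device=a.device, dtype=a.dtype)  # residual (skip) path
        a = a / a.sum(dim=-1, keepdim=True)                        # keep rows row-stochastic
        rollout = a if rollout is None else torch.bmm(a, rollout)  # compose layer by layer
    # CLS-token attention over the image patches: one relevance map per sample.
    return rollout[:, 0, 1:]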


# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()

    # Experiment config overrides the shared base config.
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    set_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info("Device: %s", device)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.figures_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)
    checkpoint_path = str(
        Path(cfg.outputs.checkpoint_dir) / f"best_{cfg.experiment_name}.pt"
    )
    history_path = str(
        Path(cfg.outputs.log_dir) / f"training_{cfg.experiment_name}_history.csv"
    )

    if cfg.wandb.enabled:
        wandb.init(
            project=cfg.wandb.project,
            name=cfg.experiment_name,
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    log.info("Building dataloaders...")
    train_loader, val_loader, _ = build_dataloaders(cfg)
    log.info("Building model...")
    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)

    # Discriminative learning rates: the pretrained backbone trains at 0.1x
    # the head's learning rate.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(), "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(), "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    # Disable the scaler together with autocast so full-precision runs skip
    # the (harmless but pointless) scale/unscale round-trip.
    scaler = GradScaler("cuda", enabled=cfg.training.mixed_precision)

    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )

    log.info("Starting training: %s", cfg.experiment_name)
    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = train_one_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch
        )

        collect_attn = (epoch % cfg.wandb.log_attention_every_n_epochs == 0)
        val_logs, attn_data = validate(
            model,
            val_loader,
            loss_fn,
            device,
            cfg,
            collect_attn=collect_attn,
            n_attn=cfg.wandb.n_attention_samples,
            epoch=epoch,
        )

        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        val_mae = val_logs.get("val/mae/weighted_avg", 0)
        val_loss = val_logs["val/loss_total"]
        reached = val_logs.get("val/reached_mae_w050", 0)
        log.info(
            "Epoch %d train=%.4f val=%.4f mae=%.4f reached_mae=%.4f lr=%.2e",
            epoch, train_loss, val_loss, val_mae, reached, lr,
        )
        history.append({
            "epoch"      : epoch,
            "train_loss" : train_loss,
            "val_loss"   : val_loss,
            "val_mae"    : val_mae,
            "reached_mae": reached,
            "lr"         : lr,
        })

        if cfg.wandb.enabled:
            log_dict = {
                "train/loss": train_loss,
                **val_logs,
                "lr": lr,
                "epoch": epoch,
            }
            if attn_data is not None:
                # Lazy import: matplotlib is only needed when attention
                # figures are actually logged.
                import matplotlib.pyplot as plt

                imgs, layers, ids = attn_data
                # Ensure the per-experiment figures subdirectory exists before
                # plot_attention_grid tries to write into it.
                fig_dir = Path(cfg.outputs.figures_dir) / cfg.experiment_name
                fig_dir.mkdir(parents=True, exist_ok=True)
                fig = plot_attention_grid(
                    imgs,
                    layers,
                    ids,
                    save_path=str(fig_dir / f"attn_epoch{epoch:03d}.png"),
                    n_cols=4,
                    rollout_mode="full",
                )
                log_dict["attention/rollout_full"] = wandb.Image(fig)
                plt.close(fig)
            wandb.log(log_dict, step=epoch)

        if early_stop.step(val_loss, model, epoch):
            log.info(
                "Early stopping at epoch %d best=%d loss=%.6f",
                epoch, early_stop.best_epoch, early_stop.best_loss,
            )
            break

    # Save history
    pd.DataFrame(history).to_csv(history_path, index=False)
    log.info("Saved history: %s", history_path)
    early_stop.restore_best(model)
    if cfg.wandb.enabled:
        wandb.finish()
    log.info("Done. Best checkpoint: %s", checkpoint_path)


if __name__ == "__main__":
    main()
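
# Illustrative only: a minimal skeleton of the config keys this script reads
# (merged from configs/base.yaml and the --config file). Every key below is
# read somewhere above; the values are placeholders, not the project's actual
# settings (see configs/base.yaml for those).
#
#   seed: 42
#   experiment_name: full_train
#   training:
#     epochs: 100
#     learning_rate: 1.0e-4
#     weight_decay: 0.05
#     grad_clip: 1.0
#     mixed_precision: true
#   scheduler:
#     T_max: 100
#     eta_min: 1.0e-6
#   early_stopping:
#     patience: 10
#     min_delta: 1.0e-4
#   outputs:
#     checkpoint_dir: outputs/checkpoints
#     figures_dir: outputs/figures
#     log_dir: outputs/logs
#   wandb:
#     enabled: true
#     project: galaxy-vit
#     log_attention_every_n_epochs: 5
#     n_attention_samples: 8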