| """ | |
| src/train.py | |
| ------------ | |
| Main training loop for the proposed hierarchical probabilistic ViT | |
| regression model on Galaxy Zoo 2. | |
| Model : GalaxyViT (ViT-Base/16 + linear head) | |
| Loss : HierarchicalLoss (KL + MSE, Ξ»=0.5 each) | |
| Scheduler: CosineAnnealingLR | |
| Dropout : 0.3 (increased from 0.1 β see base.yaml rationale) | |
| Saves | |
| ----- | |
| outputs/checkpoints/best_<experiment_name>.pt β best checkpoint | |
| outputs/logs/training_<experiment_name>_history.csv β epoch history | |
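
The checkpoint is a plain dict with keys "epoch", "model_state", and
"val_loss" (see EarlyStopping.step), so it can be reloaded with, e.g.:

    ckpt = torch.load("outputs/checkpoints/best_<experiment_name>.pt",
                      map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state"])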

Usage
-----
cd ~/galaxy
nohup python -m src.train --config configs/full_train.yaml \
    > outputs/logs/train_full.log 2>&1 &
echo "PID: $!"
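
Config
------
Settings come from configs/base.yaml merged with the --config file
(the experiment file wins). The keys this script reads, with
illustrative values only (the real defaults live in base.yaml):

    experiment_name: full_train
    seed: 42
    training:
      epochs: 100
      learning_rate: 1.0e-4
      weight_decay: 0.05
      grad_clip: 1.0
      mixed_precision: true
    scheduler:
      T_max: 100
      eta_min: 1.0e-6
    early_stopping:
      patience: 10
      min_delta: 1.0e-4
    outputs:
      checkpoint_dir: outputs/checkpoints
      figures_dir: outputs/figures
      log_dir: outputs/logs
    wandb:
      enabled: true
      project: galaxy-vit
      log_attention_every_n_epochs: 5
      n_attention_samples: 8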
| """ | |
import argparse
import logging
import random
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import wandb
from omegaconf import OmegaConf
from torch.amp import autocast, GradScaler
from tqdm import tqdm

from src.attention_viz import plot_attention_grid
from src.dataset import build_dataloaders
from src.loss import HierarchicalLoss
from src.metrics import compute_metrics, predictions_to_numpy
from src.model import build_model

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
    datefmt="%H:%M:%S", level=logging.INFO, stream=sys.stdout,
)
log = logging.getLogger("train")

# ─────────────────────────────────────────────────────────────
# Utilities
# ─────────────────────────────────────────────────────────────
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class EarlyStopping:
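    """Track validation loss; checkpoint on improvement, stop on stagnation.

    step() saves the model whenever val_loss improves on the best seen
    value by more than min_delta, and returns True once `patience`
    consecutive epochs have passed without such an improvement.
    """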

    def __init__(self, patience, min_delta, checkpoint_path):
        self.patience = patience
        self.min_delta = min_delta
        self.checkpoint_path = checkpoint_path
        self.best_loss = float("inf")
        self.counter = 0
        self.best_epoch = 0

    def step(self, val_loss, model, epoch) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.best_epoch = epoch
            torch.save(
                {"epoch": epoch, "model_state": model.state_dict(),
                 "val_loss": val_loss},
                self.checkpoint_path,
            )
            log.info("  [ckpt] saved epoch=%d val_loss=%.6f", epoch, val_loss)
        else:
            self.counter += 1
            log.info("  [early_stop] %d/%d best=%.6f",
                     self.counter, self.patience, self.best_loss)
        return self.counter >= self.patience

    def restore_best(self, model):
        ckpt = torch.load(self.checkpoint_path, map_location="cpu",
                          weights_only=True)
        model.load_state_dict(ckpt["model_state"])
        log.info("Restored best weights epoch=%d val_loss=%.6f",
                 ckpt["epoch"], ckpt["val_loss"])


# ─────────────────────────────────────────────────────────────
# Training / validation steps
# ─────────────────────────────────────────────────────────────
def train_one_epoch(model, loader, loss_fn, optimizer,
                    scaler, device, cfg, epoch):
    model.train()
    total = 0.0
    nb = 0
    for images, targets, weights, _ in tqdm(
        loader, desc=f"Train E{epoch}", leave=False
    ):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=cfg.training.mixed_precision):
            logits = model(images)
            loss, _ = loss_fn(logits, targets, weights)
        scaler.scale(loss).backward()
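        # Gradients are still multiplied by the loss scale here; unscale
        # them first so clip_grad_norm_ clips true gradient magnitudes.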
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.training.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        total += loss.item()
        nb += 1
    return total / nb


def validate(model, loader, loss_fn, device, cfg,
             collect_attn=False, n_attn=8, epoch=0):
    model.eval()
    total = 0.0
    nb = 0
    all_preds, all_targets, all_weights = [], [], []
    attn_imgs, all_layers_list, attn_ids = [], [], []
    attn_done = False
    with torch.no_grad():
        for images, targets, weights, image_ids in tqdm(
            loader, desc=f"Val E{epoch}", leave=False
        ):
            images = images.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            weights = weights.to(device, non_blocking=True)
            with autocast("cuda", enabled=cfg.training.mixed_precision):
                logits = model(images)
                loss, _ = loss_fn(logits, targets, weights)
            total += loss.item()
            nb += 1
            p, t, w = predictions_to_numpy(logits, targets, weights)
            all_preds.append(p)
            all_targets.append(t)
            all_weights.append(w)
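            # Keep attention maps from only the first few batches: enough
            # for the n_attn-sample visualisation grid without holding
            # every layer's maps for the whole validation set in memory.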
            if collect_attn and not attn_done:
                all_layers = model.get_all_attention_weights()
                if all_layers is not None:
                    n = min(n_attn, images.shape[0])
                    attn_imgs.append(images[:n].cpu())
                    all_layers_list.append([l[:n].cpu() for l in all_layers])
                    attn_ids.extend([int(i) for i in image_ids[:n]])
                    if len(attn_ids) >= n_attn:
                        attn_done = True
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_weights = np.concatenate(all_weights)
    metrics = compute_metrics(all_preds, all_targets, all_weights)
    val_logs = {"val/loss_total": total / nb}
    val_logs.update({f"val/{k}": v for k, v in metrics.items()})
    val_logs["val/reached_mae_w050"] = metrics.get("mae_w050/conditional_avg", 0)
    attn_data = None
    if collect_attn and attn_imgs:
        attn_data = (
            torch.cat(attn_imgs, dim=0),
            [torch.cat([b[li] for b in all_layers_list], dim=0)
             for li in range(len(all_layers_list[0]))],
            attn_ids,
        )
    return val_logs, attn_data


# ─────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()
    base_cfg = OmegaConf.load("configs/base.yaml")
    exp_cfg = OmegaConf.load(args.config)
    cfg = OmegaConf.merge(base_cfg, exp_cfg)

    set_seed(cfg.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info("Device: %s", device)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    Path(cfg.outputs.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.figures_dir).mkdir(parents=True, exist_ok=True)
    Path(cfg.outputs.log_dir).mkdir(parents=True, exist_ok=True)
    # Attention figures are saved under a per-experiment subfolder in the
    # epoch loop below; create it up front so the save cannot fail.
    Path(cfg.outputs.figures_dir, cfg.experiment_name).mkdir(
        parents=True, exist_ok=True
    )

    checkpoint_path = str(
        Path(cfg.outputs.checkpoint_dir) / f"best_{cfg.experiment_name}.pt"
    )
    history_path = str(
        Path(cfg.outputs.log_dir) / f"training_{cfg.experiment_name}_history.csv"
    )

    if cfg.wandb.enabled:
        wandb.init(
            project=cfg.wandb.project,
            name=cfg.experiment_name,
            config=OmegaConf.to_container(cfg, resolve=True),
        )

    log.info("Building dataloaders...")
    train_loader, val_loader, _ = build_dataloaders(cfg)
    log.info("Building model...")
    model = build_model(cfg).to(device)
    loss_fn = HierarchicalLoss(cfg)
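    # Two parameter groups: the pretrained backbone gets a 10x smaller
    # learning rate than the freshly initialised head, the usual
    # fine-tuning recipe for pretrained ViTs.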
    optimizer = torch.optim.AdamW(
        [
            {"params": model.backbone.parameters(),
             "lr": cfg.training.learning_rate * 0.1},
            {"params": model.head.parameters(),
             "lr": cfg.training.learning_rate},
        ],
        weight_decay=cfg.training.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=cfg.scheduler.T_max, eta_min=cfg.scheduler.eta_min
    )
    # Disable loss scaling when mixed precision is off; scale()/step()
    # then become no-op pass-throughs.
    scaler = GradScaler("cuda", enabled=cfg.training.mixed_precision)
    early_stop = EarlyStopping(
        patience=cfg.early_stopping.patience,
        min_delta=cfg.early_stopping.min_delta,
        checkpoint_path=checkpoint_path,
    )

    log.info("Starting training: %s", cfg.experiment_name)
    history = []
    for epoch in range(1, cfg.training.epochs + 1):
        train_loss = train_one_epoch(
            model, train_loader, loss_fn, optimizer, scaler, device, cfg, epoch
        )
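        # Attention rollouts are expensive to extract and plot, so
        # collect them only every log_attention_every_n_epochs epochs.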
        collect_attn = (epoch % cfg.wandb.log_attention_every_n_epochs == 0)
        val_logs, attn_data = validate(
            model, val_loader, loss_fn, device, cfg,
            collect_attn=collect_attn,
            n_attn=cfg.wandb.n_attention_samples,
            epoch=epoch,
        )
        scheduler.step()
        lr = scheduler.get_last_lr()[0]

        val_mae = val_logs.get("val/mae/weighted_avg", 0)
        val_loss = val_logs["val/loss_total"]
        reached = val_logs.get("val/reached_mae_w050", 0)
        log.info(
            "Epoch %d train=%.4f val=%.4f mae=%.4f reached_mae=%.4f lr=%.2e",
            epoch, train_loss, val_loss, val_mae, reached, lr,
        )
        history.append({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_mae": val_mae,
            "reached_mae": reached,
            "lr": lr,
        })

        if cfg.wandb.enabled:
            log_dict = {
                "train/loss": train_loss,
                **val_logs,
                "lr": lr, "epoch": epoch,
            }
            if attn_data is not None:
                # Lazy import: matplotlib is only needed when attention
                # figures are actually logged.
                import matplotlib.pyplot as plt
                imgs, layers, ids = attn_data
                fig = plot_attention_grid(
                    imgs, layers, ids,
                    save_path=(
                        f"{cfg.outputs.figures_dir}/{cfg.experiment_name}/"
                        f"attn_epoch{epoch:03d}.png"
                    ),
                    n_cols=4, rollout_mode="full",
                )
                log_dict["attention/rollout_full"] = wandb.Image(fig)
                plt.close(fig)
            wandb.log(log_dict, step=epoch)

        if early_stop.step(val_loss, model, epoch):
            log.info("Early stopping at epoch %d best=%d loss=%.6f",
                     epoch, early_stop.best_epoch, early_stop.best_loss)
            break

    # Save history
    pd.DataFrame(history).to_csv(history_path, index=False)
    log.info("Saved history: %s", history_path)
    early_stop.restore_best(model)
    if cfg.wandb.enabled:
        wandb.finish()
    log.info("Done. Best checkpoint: %s", checkpoint_path)


if __name__ == "__main__":
    main()