Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
eb725f8 verified | """ | |
| Training script for conditional diffusion on CAMELS LH (6 cosmological parameters). | |
| Same training theory as DDPM_HI_Emulation_improved (2-label): DDPM noise prediction, | |
| DDIM sampling, ConditionalUNet with time + label embeddings, label z-score from train split, | |
| EMA, optional AMP, cosine LR, early stopping. | |
| Changes from original: | |
| - EMA weights are now applied before validation and sampling | |
| - Training args are saved to args.txt for evaluation script | |
| - Fixed --normalize_labels and --use_ddim flags (were un-disableable) | |
| - Added mixed-precision (AMP) training support | |
| - Fixed loss averaging to be per-sample rather than per-batch | |
| - torch.load on --resume passes an explicit weights_only flag (currently False, so only resume from trusted checkpoints) | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| import time | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import torch.optim as optim | |
| from tqdm import tqdm | |
| from dataset_conditional import DEFAULT_DATA_DIR, get_conditional_dataloaders | |
| from diffusion_conditional import ConditionalDiffusionModel, GaussianDiffusion | |
| from unet_conditional import ConditionalUNet | |
| # Weights & Biases (optional) | |
| try: | |
| import wandb | |
| WANDB_AVAILABLE = True | |
| except ImportError: | |
| WANDB_AVAILABLE = False | |
| print("Warning: wandb not available. Install with: pip install wandb") | |
class EMA:
    """Exponential Moving Average (EMA) of a model's trainable parameters.

    ``shadow`` maps parameter names (as yielded by ``model.named_parameters()``)
    to averaged tensors. Only parameters with ``requires_grad=True`` are tracked;
    buffers are not averaged.

    Typical use: call ``update()`` after every optimizer step, and wrap
    evaluation/sampling in ``apply_shadow()`` ... ``restore()``.
    """

    def __init__(self, model, decay=0.9999):
        self.model = model
        self.decay = decay
        # Independent copies, so later optimizer steps do not touch the average.
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        # Live training weights saved by apply_shadow(); empty when not swapped.
        self.backup = {}

    @torch.no_grad()
    def update(self):
        """Blend current weights into the average: shadow = decay*shadow + (1-decay)*param."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # In-place update avoids allocating a new tensor per parameter per step.
                self.shadow[name].mul_(self.decay).add_(param.detach(), alpha=1.0 - self.decay)

    def apply_shadow(self):
        """Load the EMA weights into the model, backing up the live weights first."""
        self.backup = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.detach().clone()
                # Copy values instead of rebinding ``param.data`` to the shadow tensor:
                # rebinding would alias the two, so any in-place op on the parameter
                # (e.g. an optimizer step before restore()) would corrupt the EMA.
                param.data.copy_(self.shadow[name])

    def restore(self):
        """Put back the live weights saved by the preceding apply_shadow() call."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # backup holds private clones, so rebinding here cannot alias anything.
                param.data = self.backup[name]
        self.backup = {}
def train_epoch(model, dataloader, optimizer, device, epoch, ema=None, use_wandb=False, scaler=None,
                max_grad_norm=1.0):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: object exposing ``get_loss(images, labels)`` and the usual
            ``nn.Module`` API (``train()``, ``parameters()``).
        dataloader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over ``model.parameters()``.
        device: device each batch is moved to.
        epoch: 0-based epoch index (progress-bar label and W&B step).
        ema: optional EMA updated after every optimizer step.
        use_wandb: log the batch loss to W&B every 10 batches.
        scaler: optional ``torch.amp.GradScaler``; when given, the forward pass
            runs under CUDA autocast and gradients are scaled.
        max_grad_norm: global gradient-norm clipping threshold.

    Returns:
        Batch losses weighted by batch size, divided by the number of samples seen
        (assumes ``get_loss`` returns a per-batch mean -- TODO confirm).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(device)
        labels = labels.to(device)
        batch_size = images.shape[0]

        optimizer.zero_grad()
        if scaler is not None:
            with torch.amp.autocast("cuda"):
                loss = model.get_loss(images, labels)
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = model.get_loss(images, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

        if ema is not None:
            ema.update()

        # One device->host sync per batch instead of one per use.
        loss_value = loss.item()
        total_loss += loss_value * batch_size
        total_samples += batch_size
        pbar.set_postfix({"loss": f"{loss_value:.4f}"})
        if use_wandb and batch_idx % 10 == 0:
            wandb.log({"batch_loss": loss_value, "epoch": epoch, "batch": epoch * len(dataloader) + batch_idx})

    if total_samples == 0:
        raise ValueError("train_epoch: dataloader yielded no samples; check the training split.")
    return total_loss / total_samples
def validate(model, dataloader, device):
    """Return the mean validation loss of ``model`` over ``dataloader``.

    The model is put in eval mode and no gradients are tracked. Each batch loss
    is weighted by its batch size, so a short final batch is not over-counted
    (assumes ``get_loss`` returns a per-batch mean -- TODO confirm).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating"):
            images = images.to(device)
            labels = labels.to(device)
            batch_size = images.shape[0]
            total_loss += model.get_loss(images, labels).item() * batch_size
            total_samples += batch_size
    if total_samples == 0:
        raise ValueError("validate: dataloader yielded no samples; check the validation split.")
    return total_loss / total_samples
def _atomic_torch_save(obj, path):
    """Write ``obj`` with ``torch.save`` via a temp file + rename so an interrupted
    write never leaves a truncated file at ``path``."""
    tmp_path = f"{path}.tmp"
    try:
        torch.save(obj, tmp_path)
        # os.replace overwrites atomically when source and target share a filesystem.
        os.replace(tmp_path, path)
    finally:
        # Only present if torch.save or os.replace failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_checkpoint(model, optimizer, ema, epoch, loss, save_dir, is_best=False, last_improvement_epoch=None,
                    scheduler=None, save_every=20):
    """Write training state into ``save_dir``.

    Always writes ``checkpoint_latest.pt``. It also writes ``best_model.pt`` when
    ``is_best`` is true, and ``checkpoint_epoch_{epoch+1}.pt`` every ``save_every``
    epochs. Every file holds the same dict:

    - ``epoch``, ``loss``
    - ``model_state_dict``, ``optimizer_state_dict``
    - optionally ``ema_shadow``, ``last_improvement_epoch`` and ``scheduler_state_dict``

    Args:
        epoch: 0-based epoch index (file names and messages use ``epoch + 1``).
        loss: value stored under ``"loss"`` (the caller passes this epoch's val loss).
        save_every: periodic snapshot interval in epochs; ``<= 0`` disables snapshots.
    """
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }
    if ema is not None:
        checkpoint["ema_shadow"] = ema.shadow
    if last_improvement_epoch is not None:
        checkpoint["last_improvement_epoch"] = last_improvement_epoch
    if scheduler is not None:
        checkpoint["scheduler_state_dict"] = scheduler.state_dict()

    _atomic_torch_save(checkpoint, os.path.join(save_dir, "checkpoint_latest.pt"))
    if is_best:
        _atomic_torch_save(checkpoint, os.path.join(save_dir, "best_model.pt"))
        print(f"Saved best model at epoch {epoch+1}")
    if save_every > 0 and (epoch + 1) % save_every == 0:
        _atomic_torch_save(checkpoint, os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pt"))
        print(f"Saved checkpoint at epoch {epoch+1}")
def sample_images(model, diffusion, device, save_path, test_labels, ema=None, n_samples=8, epoch=0,
                  use_ddim=True, ddim_steps=50, use_wandb=False):
    """Sample images for a fixed set of conditioning labels and save them as a grid.

    Args:
        model: conditional model, passed unchanged to ``diffusion.sample``.
        diffusion: object exposing ``sample(model, labels=..., ...)``.
        device: device the labels are moved to before sampling.
        save_path: output image path for the grid.
        test_labels: label tensor; its first ``n_samples`` rows condition the samples.
        ema: optional EMA; its shadow weights are used for sampling and the
            training weights are restored afterwards, even if sampling raises.
        n_samples: requested number of samples; clamped to the rows available
            in ``test_labels``.
        epoch: epoch number shown in the title and W&B log.
        use_ddim, ddim_steps: forwarded to ``diffusion.sample`` (with ``eta=0.0``).
        use_wandb: also log the saved image to W&B.

    Titles show the label values as passed in (z-scored when the dataloader
    normalises labels -- TODO confirm). Pixels are drawn with vmin=-1, vmax=1,
    which presumes samples are scaled to [-1, 1].
    """
    labels = test_labels[:n_samples].to(device)
    # The first test batch can be smaller than n_samples (e.g. --batch_size 4);
    # indexing past it would raise IndexError while drawing the grid.
    n_samples = labels.shape[0]
    if n_samples == 0:
        print(f"No conditioning labels available; skipped sampling for {save_path}")
        return

    if ema is not None:
        ema.apply_shadow()
    model.eval()
    try:
        with torch.no_grad():
            samples = diffusion.sample(
                model,
                labels=labels,
                channels=1,
                height=256,
                width=256,
                device=device,
                progress=True,
                use_ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=0.0,
            )
    finally:
        # Never leave the EMA weights in the model that is about to keep training.
        if ema is not None:
            ema.restore()

    images = samples[:, 0].detach().float().cpu().numpy()
    label_rows = labels.cpu().tolist()

    n_cols = min(n_samples, 4)
    n_rows = (n_samples + n_cols - 1) // n_cols
    # squeeze=False always returns a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5 * n_cols, 4.5 * n_rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < n_samples:
            ax.imshow(images[i], vmin=-1, vmax=1)
            ax.set_title(", ".join(f"{v:.2f}" for v in label_rows[i]), fontsize=10)
        ax.axis("off")
    fig.suptitle(f"Generated Samples - Epoch {epoch}", fontsize=14)
    fig.tight_layout()
    fig.savefig(save_path, dpi=150, bbox_inches="tight")
    if use_wandb:
        wandb.log({"generated_samples": wandb.Image(save_path), "epoch": epoch})
    plt.close(fig)
    print(f"Saved samples to {save_path}")
def save_training_args(args, output_dir):
    """Persist the parsed CLI arguments next to the run outputs.

    Two files are written into ``output_dir``:

    - ``args.txt``: one ``key: value`` line per argument.
    - ``args.json``: the same mapping as indented JSON.

    The evaluation script can use either file to rebuild the model.
    """
    arg_map = vars(args)
    txt_path = os.path.join(output_dir, "args.txt")
    json_path = os.path.join(output_dir, "args.json")

    with open(txt_path, "w", encoding="utf-8") as handle:
        handle.writelines(f"{key}: {value}\n" for key, value in arg_map.items())

    with open(json_path, "w", encoding="utf-8") as handle:
        json.dump(arg_map, handle, indent=2)

    print(f"Saved training args to {txt_path} and {json_path}")
def _build_arg_parser():
    """Return the CLI parser for 6-parameter LH conditional-diffusion training."""
    parser = argparse.ArgumentParser(description="Train conditional diffusion (LH 6-parameter)")
    # Model
    parser.add_argument("--label_dim", type=int, default=6)
    parser.add_argument("--base_channels", type=int, default=64)
    parser.add_argument("--channel_multipliers", type=int, nargs="+", default=[1, 2, 4, 8])
    parser.add_argument("--attention_levels", type=int, nargs="+", default=[2, 3])
    parser.add_argument("--dropout", type=float, default=0.1)
    # Diffusion
    parser.add_argument("--timesteps", type=int, default=1500)
    parser.add_argument("--beta_start", type=float, default=1e-4)
    parser.add_argument("--beta_end", type=float, default=0.02)
    parser.add_argument("--schedule_type", type=str, default="linear")
    # Training
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--ema_decay", type=float, default=0.9999)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--early_stop_patience", type=int, default=30)
    parser.add_argument(
        "--use_amp",
        action="store_true",
        default=False,
        help="Enable mixed-precision training (recommended for GPU)",
    )
    # Reproducibility (42 was previously hard-coded)
    parser.add_argument("--seed", type=int, default=42, help="Seed for python, numpy and torch RNGs")
    # Data
    parser.add_argument(
        "--data_dir",
        type=str,
        default=DEFAULT_DATA_DIR,
        help="Directory with *_LH_6.npy and *_labels_LH.npy (same rule as improved repo: e.g. .../LH_data/params_6)",
    )
    parser.add_argument("--normalize_labels", action=argparse.BooleanOptionalAction, default=True)
    # Output
    parser.add_argument("--output_dir", type=str, default="outputs_conditional_6param")
    parser.add_argument("--resume", type=str, default="")
    parser.add_argument(
        "--resume_refresh_scheduler",
        action="store_true",
        help="On resume, rebuild cosine LR scheduler for --epochs (last_epoch=start-1) instead of loading saved scheduler; use when extending training beyond the original epoch count",
    )
    parser.add_argument("--sample_every", type=int, default=10)
    parser.add_argument("--use_ddim", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ddim_steps", type=int, default=50)
    # WandB
    parser.add_argument("--use_wandb", action="store_true", default=False)
    parser.add_argument("--wandb_project", type=str, default="ddpm_cosmology")
    parser.add_argument("--wandb_entity", type=str, default="")
    parser.add_argument("--wandb_run_name", type=str, default="")
    return parser


def _seed_everything(seed):
    """Seed python/numpy/torch RNGs and request deterministic, non-benchmarked cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _save_loss_plot(epochs, losses, path):
    """Plot train/val losses (log y-axis) against 1-based epoch numbers and save to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(epochs, losses["train"], label="Train Loss")
    ax.plot(epochs, losses["val"], label="Val Loss")
    ax.set_yscale("log")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Training Progress")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.savefig(path, dpi=150)
    plt.close(fig)


def _recover_best_val_loss(resume_path, checkpoint):
    """Return the best validation loss reached by the run being resumed.

    ``checkpoint["loss"]`` is the validation loss of the epoch that checkpoint was
    written at. For ``checkpoint_latest.pt`` / ``checkpoint_epoch_*.pt`` that is
    generally not the best value. Using it would let an epoch worse than the true
    best be saved as ``best_model.pt`` and reset the early-stopping counter.

    ``save_checkpoint`` writes ``best_model.pt`` into the same directory, so the
    true best is read from there when it exists.
    """
    best = checkpoint.get("loss", float("inf"))
    best_path = os.path.join(os.path.dirname(os.path.abspath(resume_path)), "best_model.pt")
    if not os.path.isfile(best_path) or os.path.samefile(best_path, resume_path):
        return best
    import pickle  # local: only needed to recognise an unreadable best_model.pt

    try:
        best_checkpoint = torch.load(best_path, map_location="cpu", weights_only=False)
    except (OSError, RuntimeError, EOFError, pickle.UnpicklingError) as exc:
        print(f"Warning: could not read {best_path} ({exc}); using the resumed checkpoint's loss.")
        return best
    return min(best, best_checkpoint.get("loss", float("inf")))


def main():
    """Parse CLI args, build data/model, then train with EMA validation, checkpointing and sampling."""
    args = _build_arg_parser().parse_args()
    _seed_everything(args.seed)

    use_wandb = args.use_wandb and WANDB_AVAILABLE
    if args.use_wandb and not WANDB_AVAILABLE:
        print("Warning: --use_wandb was given but wandb is not installed; continuing without W&B.")
    if use_wandb:
        run_name = args.wandb_run_name or f"conditional_diffusion_{time.strftime('%Y%m%d_%H%M%S')}"
        wandb.init(project=args.wandb_project, entity=args.wandb_entity or None, name=run_name, config=vars(args))
        print(f"W&B run: {run_name}")

    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_dir = f"{args.output_dir}_{timestamp}"
    checkpoint_dir = os.path.join(output_dir, "checkpoints")
    sample_dir = os.path.join(output_dir, "samples")
    for path in (output_dir, checkpoint_dir, sample_dir):
        os.makedirs(path, exist_ok=True)
    save_training_args(args, output_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    scaler = None
    if args.use_amp:
        if torch.cuda.is_available():
            scaler = torch.amp.GradScaler("cuda")
            print("Mixed-precision training enabled (AMP)")
        else:
            print("Warning: --use_amp requires CUDA; training in full precision.")

    print("\nLoading data...")
    train_loader, val_loader, test_loader = get_conditional_dataloaders(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        normalize_labels=args.normalize_labels,
        label_dim=args.label_dim,
    )
    # One fixed batch of test labels conditions every periodic sample grid.
    _, test_labels = next(iter(test_loader))

    print("\nCreating model...")
    unet = ConditionalUNet(
        in_channels=1,
        out_channels=1,
        label_dim=args.label_dim,
        base_channels=args.base_channels,
        channel_multipliers=args.channel_multipliers,
        attention_levels=args.attention_levels,
        dropout=args.dropout,
    )
    diffusion = GaussianDiffusion(
        timesteps=args.timesteps,
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        schedule_type=args.schedule_type,
    )
    model = ConditionalDiffusionModel(unet, diffusion).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    ema = EMA(model, decay=args.ema_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    start_epoch = 0
    best_val_loss = float("inf")
    last_improvement_epoch = -1
    if args.resume:
        print(f"Resuming from {args.resume}")
        # NOTE(review): weights_only=False unpickles arbitrary objects, so only resume
        # from trusted files. save_checkpoint() appears to store only tensors, dicts,
        # lists and scalars, so weights_only=True is likely to work -- verify first.
        checkpoint = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if "ema_shadow" in checkpoint:
            ema.shadow = checkpoint["ema_shadow"]
        else:
            # The EMA above was seeded from freshly initialised weights; re-seed it from
            # the restored ones so validation/sampling do not use random weights.
            ema = EMA(model, decay=args.ema_decay)
        start_epoch = checkpoint["epoch"] + 1
        best_val_loss = _recover_best_val_loss(args.resume, checkpoint)
        # If the key is missing, treat the checkpoint epoch as the last improvement;
        # -1 would count every earlier epoch against --early_stop_patience.
        last_improvement_epoch = checkpoint.get("last_improvement_epoch", start_epoch - 1)
        if args.resume_refresh_scheduler:
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=args.epochs, last_epoch=start_epoch - 1
            )
            print(
                f"Rebuilt LR scheduler for extended run: T_max={args.epochs}, "
                f"resume at epoch {start_epoch + 1} (last_epoch={start_epoch - 1})"
            )
        elif "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        else:
            print("Warning: checkpoint has no scheduler state; the cosine LR schedule restarts from epoch 0.")
        print(f"Best val loss carried over: {best_val_loss:.6f}")

    print("\nStarting training...")
    losses = {"train": [], "val": []}
    loss_epochs = []  # 1-based epoch numbers aligned with the entries of `losses`
    loss_plot_path = os.path.join(output_dir, "losses.png")
    for epoch in range(start_epoch, args.epochs):
        train_loss = train_epoch(model, train_loader, optimizer, device, epoch, ema, use_wandb, scaler=scaler)
        # Validate with the EMA weights (the ones sampled from and saved as
        # ema_shadow), then swap the raw training weights back in.
        ema.apply_shadow()
        val_loss = validate(model, val_loader, device)
        ema.restore()
        losses["train"].append(train_loss)
        losses["val"].append(val_loss)
        loss_epochs.append(epoch + 1)

        # Read the LR before stepping so logs show the rate actually used this epoch.
        epoch_lr = optimizer.param_groups[0]["lr"]
        scheduler.step()
        if use_wandb:
            wandb.log(
                {
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "learning_rate": epoch_lr,
                }
            )
        print(
            f"\nEpoch {epoch+1}/{args.epochs} | Train: {train_loss:.6f} | Val: {val_loss:.6f} | "
            f"LR: {epoch_lr:.6e}"
        )

        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
            last_improvement_epoch = epoch
        save_checkpoint(
            model,
            optimizer,
            ema,
            epoch,
            val_loss,
            checkpoint_dir,
            is_best=is_best,
            last_improvement_epoch=last_improvement_epoch,
            scheduler=scheduler,
        )
        if epoch - last_improvement_epoch >= args.early_stop_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        # --sample_every <= 0 disables periodic sampling (0 used to divide by zero).
        if args.sample_every > 0 and (epoch + 1) % args.sample_every == 0:
            sample_path = os.path.join(sample_dir, f"samples_epoch_{epoch+1}.png")
            sample_images(
                model,
                diffusion,
                device,
                sample_path,
                test_labels,
                ema=ema,
                epoch=epoch + 1,
                use_ddim=args.use_ddim,
                ddim_steps=args.ddim_steps,
                use_wandb=use_wandb,
            )
        if (epoch + 1) % 5 == 0:
            _save_loss_plot(loss_epochs, losses, loss_plot_path)

    # Always write the final curve: the periodic plot misses the last epochs when
    # training stops early or --epochs is not a multiple of 5.
    if loss_epochs:
        _save_loss_plot(loss_epochs, losses, loss_plot_path)

    print(f"\nTraining completed! Best val loss: {best_val_loss:.6f}")
    print(f"Results saved to: {output_dir}")
    if use_wandb:
        wandb.finish()


if __name__ == "__main__":
    main()