#!/usr/bin/env python3 """ End-to-end parameter inference for the conditional DDPM stack: training (noise-prediction MSE / ELBO surrogate), checkpointing, conditional sampling, evaluation-style metrics, and optional **VLB-based cosmological parameter constraints** following Mudur et al. (2023). Reference (parameter inference via conditional diffusion VLB): Mudur, Cuesta-Lazaro & Finkbeiner, "Cosmological Field Emulation and Parameter Inference with Diffusion Models", arXiv:2312.07534 (2023). https://arxiv.org/abs/2312.07534 They train a DDPM (Ho et al. 2020) on log density fields conditioned on (Omega_m, sigma_8), then evaluate VLB terms L_t(x_0 | theta_eval) on a grid in parameter space. The dominant term is L_0 = -log p_phi(x_0 | x_1, theta) with x_1 ~ q(x_1|x_0). They form -2 Delta ln L_hat ~ 2(L_0 - min L_0) and map marginals to approximate posteriors (68% intervals on a grid). This script implements the **L_0 approximation** (their primary reported setup) using the existing GaussianDiffusion reverse mean/variance at timestep index t=1. Full multi-t VLB sums are left as a documented extension. Note: train_conditional.py exposes hyperparameters via argparse (no separate Config dataclass). This script mirrors those fields and uses the same training utilities (EMA, AMP, grad clip inside train_epoch). """ from __future__ import annotations import argparse import json import logging import math import os import random import sys import time from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np import torch import torch.optim as optim from torch.utils.data import DataLoader from dataset_conditional import get_conditional_dataloaders from diffusion_conditional import ConditionalDiffusionModel, GaussianDiffusion from evaluate_conditional import ( build_model, calculate_pdf_batch, calculate_power_spectrum_batch, from_model_output, load_checkpoint, load_label_stats, load_split, load_training_config, prepare_labels_for_model, ) from train_conditional import ( EMA, save_checkpoint, save_training_args, train_epoch, validate, ) from unet_conditional import ConditionalUNet def _setup_logging(log_path: Optional[Path] = None) -> logging.Logger: log = logging.getLogger("parameter_inference_conditional") log.handlers.clear() log.setLevel(logging.INFO) fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s") sh = logging.StreamHandler(sys.stdout) sh.setFormatter(fmt) log.addHandler(sh) if log_path is not None: fh = logging.FileHandler(log_path, encoding="utf-8") fh.setFormatter(fmt) log.addHandler(fh) return log def set_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False def _infer_spatial_size(loader: DataLoader) -> Tuple[int, int]: img0, _ = loader.dataset[0] if img0.dim() == 3: _, h, w = img0.shape else: h, w = img0.shape[-2], img0.shape[-1] return int(h), int(w) def save_conditioned_sample_grid( model: ConditionalDiffusionModel, diffusion: GaussianDiffusion, labels: torch.Tensor, device: torch.device, save_path: Path, *, channels: int, height: int, width: int, ema: Optional[EMA], use_ddim: bool, ddim_steps: int, title: str = "Conditional samples", ) -> None: """Save a grid of DDPM/DDIM samples conditioned on label vectors (same idea as train_conditional.sample_images, spatial size from data).""" if ema is not None: ema.apply_shadow() unet = model.unet unet.eval() labels = labels.to(device) n_samples = labels.shape[0] with torch.no_grad(): samples = diffusion.sample( model, labels=labels, channels=channels, height=height, width=width, device=device, progress=False, use_ddim=use_ddim, ddim_steps=ddim_steps, eta=0.0, ) if ema is not None: ema.restore() n_cols = min(n_samples, 4) n_rows = (n_samples + n_cols - 1) // n_cols fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5 * n_cols, 4.5 * n_rows)) if n_rows == 1 and n_cols == 1: axes = np.array([[axes]]) elif n_rows == 1: axes = axes[np.newaxis, :] elif n_cols == 1: axes = axes[:, np.newaxis] for i in range(n_rows * n_cols): ax = axes[i // n_cols, i % n_cols] if i < n_samples: img = samples[i, 0].cpu().numpy() label_vals = labels[i].cpu().tolist() label_str = ", ".join(f"{v:.3f}" for v in label_vals) ax.imshow(img, cmap="gray", vmin=-1, vmax=1) ax.set_title(label_str, fontsize=10) ax.axis("off") plt.suptitle(title, fontsize=14) plt.tight_layout() save_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close() logging.getLogger("parameter_inference_conditional").info("Saved sample grid to %s", save_path) def _final_metrics_log( real_np: np.ndarray, gen_np: np.ndarray, log: logging.Logger, ) -> Dict[str, float]: """Compute lightweight distributional metrics (PDF / P(k) curve L2 on binned means).""" _, mean_pdf_r, _ = calculate_pdf_batch(real_np) bc, mean_pdf_g, _ = calculate_pdf_batch(gen_np) pdf_mse = float(np.mean((mean_pdf_r - mean_pdf_g) ** 2)) dk, mean_pk_r, _ = calculate_power_spectrum_batch(real_np) _, mean_pk_g, _ = calculate_power_spectrum_batch(gen_np) k_min = 1 pk_mse = float(np.mean((mean_pk_r[k_min:] - mean_pk_g[k_min:]) ** 2)) log.info("Final metric | PDF mean MSE (density bins): %.6e", pdf_mse) log.info("Final metric | P(k) mean MSE (k>0 bins): %.6e", pk_mse) return { "pdf_mean_mse": pdf_mse, "pk_mean_mse": pk_mse, "pdf_bin_centers": float(bc.size), "pk_bins": float(dk.size), } # --- VLB / L0 parameter inference (Mudur et al. 2023, arXiv:2312.07534) --- _LOG2PI = math.log(2.0 * math.pi) def _gaussian_nll_spatial_sum(x: torch.Tensor, mean: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor: """Per-batch-element NLL for diagonal Gaussian; x, mean same shape; log_var broadcastable.""" while log_var.dim() < x.dim(): log_var = log_var.unsqueeze(-1) inv = torch.exp(-log_var) nll_pix = 0.5 * ((x - mean) ** 2 * inv + log_var + _LOG2PI) return nll_pix.view(nll_pix.shape[0], -1).sum(dim=1) @torch.no_grad() def estimate_l0_nll_batch( model: ConditionalDiffusionModel, diffusion: GaussianDiffusion, x0: torch.Tensor, labels_norm: torch.Tensor, *, n_seeds: int, base_seed: int, ) -> torch.Tensor: """ Monte-Carlo average of L_0 = -log p_theta(x_0 | x_1, theta) with x_1 ~ q(x_1 | x_0) at diffusion index t=1 (lightly noised latent). """ device = x0.device b = x0.shape[0] if diffusion.timesteps < 3: raise ValueError("VLB L0 requires diffusion.timesteps >= 3 (need t=1).") t1 = torch.ones(b, device=device, dtype=torch.long) acc = torch.zeros(b, device=device) model.eval() for s in range(n_seeds): torch.manual_seed(int(base_seed + s)) if device.type == "cuda": torch.cuda.manual_seed_all(int(base_seed + s)) noise = torch.randn(x0.shape, device=device, dtype=x0.dtype) x1 = diffusion.q_sample(x0, t1, noise=noise) mean, _pv, log_var, _ = diffusion.p_mean_variance(model, x1, t1, labels_norm, clip_denoised=True) acc += _gaussian_nll_spatial_sum(x0, mean, log_var) return acc / float(n_seeds) def _build_theta_grid( theta_true: np.ndarray, half_width: float, prior_lo: np.ndarray, prior_hi: np.ndarray, n_per_dim: int, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """2D grid in *physical* label space (same units as .npy labels), CAMELS-style bounds.""" g0 = np.linspace( max(theta_true[0] - half_width, prior_lo[0]), min(theta_true[0] + half_width, prior_hi[0]), n_per_dim, dtype=np.float64, ) g1 = np.linspace( max(theta_true[1] - half_width, prior_lo[1]), min(theta_true[1] + half_width, prior_hi[1]), n_per_dim, dtype=np.float64, ) G0, G1 = np.meshgrid(g0, g1, indexing="ij") pts = np.stack([G0.ravel(), G1.ravel()], axis=1) return pts, g0, g1 def _delta_chi2_contour_levels_2d() -> List[float]: """Approximate Delta chi^2 thresholds for 68%, 95%, 99.7% (2 dof), Mudur-style contours.""" return [2.30, 5.99, 11.82] def _shortest_mass_interval(x: np.ndarray, w: np.ndarray, mass: float = 0.68) -> Tuple[float, float]: """Shortest interval on sorted x containing `mass` of normalized weights w.""" order = np.argsort(x) xs = x[order] ws = w[order].astype(np.float64) ws = ws / (ws.sum() + 1e-30) c = np.concatenate([[0.0], np.cumsum(ws)]) n = len(xs) best_lo, best_hi = float(xs[0]), float(xs[-1]) best_w = float("inf") for i in range(n): for j in range(i + 1, n + 1): if c[j] - c[i] >= mass - 1e-9: lo, hi = float(xs[i]), float(xs[j - 1]) if hi - lo < best_w: best_w = hi - lo best_lo, best_hi = lo, hi break return best_lo, best_hi def _vlb_posterior_summaries( L0: np.ndarray, g0: np.ndarray, g1: np.ndarray, ) -> Dict[str, Any]: """Convert L0 grid to unnormalized likelihood exp(-(L0-min)), marginals, MAP, 68% intervals.""" dchi2 = 2.0 * (L0 - L0.min()) log_like = -0.5 * dchi2 w = np.exp(log_like - log_like.max()) w = w / (w.sum() + 1e-30) n0, n1 = len(g0), len(g1) W = w.reshape(n0, n1) m0 = W.sum(axis=1) m1 = W.sum(axis=0) m0 = m0 / (m0.sum() + 1e-30) m1 = m1 / (m1.sum() + 1e-30) map_i, map_j = np.unravel_index(int(np.argmax(W)), W.shape) theta_map = (float(g0[map_i]), float(g1[map_j])) int0 = _shortest_mass_interval(g0, m0, 0.68) int1 = _shortest_mass_interval(g1, m1, 0.68) return { "delta_chi2": dchi2.reshape(n0, n1).tolist(), "theta_map_omega_m": theta_map[0], "theta_map_sigma8": theta_map[1], "marginal_68_omega_m": list(int0), "marginal_68_sigma8": list(int1), } def save_vlb_corner_figure( g0: np.ndarray, g1: np.ndarray, L0: np.ndarray, theta_true: np.ndarray, out_path: Path, *, names: Tuple[str, str] = (r"$\Omega_{\rm m}$", r"$\sigma_8$"), ) -> None: """2D contours of Delta = 2(L0 - min L0) with truth cross; marginals via KDE-free histogram of grid.""" n0, n1 = len(g0), len(g1) D = (2.0 * (L0 - L0.min())).reshape(n0, n1) # x-axis: sigma_8 (g1), y-axis: Omega_m (g0); Z[i,j] at (g0[i], g1[j]) G0_2d, G1_2d = np.meshgrid(g0, g1, indexing="ij") fig = plt.figure(figsize=(7.0, 6.8)) from matplotlib.gridspec import GridSpec gs = GridSpec(2, 2, figure=fig, width_ratios=[4, 1.1], height_ratios=[1, 4], wspace=0.12, hspace=0.12) ax_j = fig.add_subplot(gs[1, 0]) ax_mx = fig.add_subplot(gs[0, 0], sharex=ax_j) ax_my = fig.add_subplot(gs[1, 1], sharey=ax_j) ax_mx.tick_params(labelleft=False, labelbottom=False) ax_my.tick_params(labelleft=False, labelbottom=False) cf = ax_j.contourf(G1_2d, G0_2d, D, levels=28, cmap="Greys", alpha=0.9) for lev in _delta_chi2_contour_levels_2d(): ax_j.contour(G1_2d, G0_2d, D, levels=[lev], colors="C0", linewidths=1.2) ax_j.axhline(theta_true[0], color="0.35", lw=0.8, ls="--") ax_j.axvline(theta_true[1], color="0.35", lw=0.8, ls="--") ax_j.scatter([theta_true[1]], [theta_true[0]], marker="x", s=80, c="crimson", zorder=9, linewidths=2) ax_j.set_xlabel(names[1]) ax_j.set_ylabel(names[0]) for lbl in ax_j.get_xticklabels(): lbl.set_rotation(45) lbl.set_ha("right") fig.colorbar(cf, ax=ax_j, fraction=0.046, pad=0.02, label=r"$2\,[L_0 - \min L_0]$ (Mudur et al.\ proxy)") W = np.exp(-0.5 * (L0 - L0.min())) W = W.reshape(n0, n1) m_omega = W.sum(axis=1) m_sigma = W.sum(axis=0) ax_mx.plot(g1, m_sigma / (m_sigma.max() + 1e-30), color="0.2", lw=1.5) ax_my.plot(m_omega / (m_omega.max() + 1e-30), g0, color="0.2", lw=1.5) out_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(out_path, dpi=180, bbox_inches="tight", facecolor="white") plt.close() def numpy_field_to_x0_tensor(img_01: np.ndarray, device: torch.device) -> torch.Tensor: """[H,W] or [1,H,W] float in [0,1] -> [1,1,H,W] in [-1,1] as used in training.""" t = torch.from_numpy(np.asarray(img_01, dtype=np.float32)) if t.dim() == 2: t = t.unsqueeze(0) t = t * 2.0 - 1.0 return t.unsqueeze(0).to(device) def run_vlb_parameter_inference( args: argparse.Namespace, log: logging.Logger, *, output_dir: Optional[Path] = None, checkpoint_path: Optional[str] = None, training_args_path: Optional[str] = None, ) -> None: """ Mudur et al. (2023) style grid evaluation of L_0 on held-out fields. """ device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")) log.info("VLB inference | device=%s", device) ta = training_args_path or args.training_args ck = checkpoint_path or args.checkpoint if ta is None or not os.path.isfile(str(ta)): raise FileNotFoundError("VLB mode requires --training_args (args.json from training).") if ck is None or not os.path.isfile(str(ck)): raise FileNotFoundError("VLB mode requires --checkpoint.") config = load_training_config(str(ta)) model = build_model(config, device) load_checkpoint(model, str(ck), device) diffusion = model.diffusion data_dir = Path(args.data_dir) label_mean, label_std = load_label_stats(data_dir) images, labels_phys = load_split(data_dir, args.vlb_split) n_fields = min(args.vlb_n_fields, len(images)) rng = np.random.default_rng(args.seed) if n_fields < len(images): pick = rng.choice(len(images), size=n_fields, replace=False) else: pick = np.arange(n_fields) prior_lo = np.array([args.vlb_prior_omega_m[0], args.vlb_prior_sigma8[0]], dtype=np.float64) prior_hi = np.array([args.vlb_prior_omega_m[1], args.vlb_prior_sigma8[1]], dtype=np.float64) out_root = Path(output_dir or args.vlb_output_dir) out_root.mkdir(parents=True, exist_ok=True) all_rows: List[Dict[str, Any]] = [] for k, idx in enumerate(pick): x0 = numpy_field_to_x0_tensor(images[idx], device) truth = labels_phys[idx].astype(np.float64) grid_pts, g0, g1 = _build_theta_grid(truth, args.vlb_half_width, prior_lo, prior_hi, args.vlb_n_grid) n_pts = grid_pts.shape[0] L0_accum = np.zeros(n_pts, dtype=np.float64) for start in range(0, n_pts, args.vlb_chunk_size): end = min(start + args.vlb_chunk_size, n_pts) chunk = grid_pts[start:end] lt = prepare_labels_for_model(chunk, label_mean, label_std).to(device) xrep = x0.expand(end - start, -1, -1, -1) L0_b = estimate_l0_nll_batch( model, diffusion, xrep, lt, n_seeds=args.vlb_l0_seeds, base_seed=args.seed + k * 10007 + start, ) L0_accum[start:end] = L0_b.detach().cpu().numpy() summ = _vlb_posterior_summaries(L0_accum, g0, g1) summ.update( { "field_index": int(idx), "theta_true_omega_m": float(truth[0]), "theta_true_sigma8": float(truth[1]), } ) all_rows.append(summ) fig_path = out_root / f"vlb_corner_field_{k}_idx{idx}.png" save_vlb_corner_figure(g0, g1, L0_accum, truth, fig_path) log.info( "VLB field %d | MAP (Om,s8)=(%.4f,%.4f) true=(%.4f,%.4f) | 68%% marg Om %s s8 %s", k, summ["theta_map_omega_m"], summ["theta_map_sigma8"], truth[0], truth[1], summ["marginal_68_omega_m"], summ["marginal_68_sigma8"], ) with open(out_root / "vlb_inference_summary.json", "w", encoding="utf-8") as f: json.dump(all_rows, f, indent=2) log.info("Wrote VLB summary to %s", out_root / "vlb_inference_summary.json") def run_training(args: argparse.Namespace, log: logging.Logger) -> str: device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")) log.info("Device: %s", device) use_amp = bool(args.use_amp) and device.type == "cuda" scaler = torch.amp.GradScaler("cuda") if use_amp else None if use_amp: log.info("Mixed precision (torch.amp.GradScaler + autocast in train_epoch) enabled.") timestamp = time.strftime("%Y%m%d_%H%M%S") output_dir = f"{args.output_dir}_{timestamp}" os.makedirs(output_dir, exist_ok=True) os.makedirs(os.path.join(output_dir, "checkpoints"), exist_ok=True) os.makedirs(os.path.join(output_dir, "samples"), exist_ok=True) log_path = Path(output_dir) / "training.log" _setup_logging(log_path) save_training_args(args, output_dir) pin_memory = bool(args.pin_memory) and device.type == "cuda" log.info("Loading dataloaders from %s (pin_memory=%s)", args.data_dir, pin_memory) train_loader, val_loader, test_loader = get_conditional_dataloaders( data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=pin_memory, normalize_labels=args.normalize_labels, ) _, test_labels_tensor = next(iter(test_loader)) h, w = _infer_spatial_size(train_loader) channels = train_loader.dataset[0][0].shape[0] if train_loader.dataset[0][0].dim() == 3 else 1 log.info("Spatial size HxW=%dx%d, channels=%d", h, w, channels) log.info("Building ConditionalUNet + GaussianDiffusion (T=%d, schedule=%s)", args.timesteps, args.schedule_type) unet = ConditionalUNet( in_channels=channels, out_channels=channels, label_dim=args.label_dim, base_channels=args.base_channels, channel_multipliers=args.channel_multipliers, attention_levels=args.attention_levels, dropout=args.dropout, ) diffusion = GaussianDiffusion( timesteps=args.timesteps, beta_start=args.beta_start, beta_end=args.beta_end, schedule_type=args.schedule_type, ) model = ConditionalDiffusionModel(unet, diffusion).to(device) n_params = sum(p.numel() for p in model.parameters()) log.info("Trainable parameters: %s", f"{n_params:,}") optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01) ema = EMA(model, decay=args.ema_decay) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) start_epoch = 0 best_val_loss = float("inf") last_improvement_epoch = -1 if args.resume: log.info("Resuming from %s", args.resume) checkpoint = torch.load(args.resume, map_location=device, weights_only=False) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) if "ema_shadow" in checkpoint: ema.shadow = checkpoint["ema_shadow"] if "scheduler_state_dict" in checkpoint: scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) start_epoch = int(checkpoint["epoch"]) + 1 best_val_loss = float(checkpoint.get("loss", float("inf"))) last_improvement_epoch = int(checkpoint.get("last_improvement_epoch", -1)) losses_train: list[float] = [] losses_val: list[float] = [] for epoch in range(start_epoch, args.epochs): train_loss = train_epoch( model, train_loader, optimizer, device, epoch, ema=ema, use_wandb=False, scaler=scaler ) if ema is not None: ema.apply_shadow() val_loss = validate(model, val_loader, device) if ema is not None: ema.restore() losses_train.append(train_loss) losses_val.append(val_loss) scheduler.step() log.info( "Epoch %d/%d | train_loss=%.6f | val_loss=%.6f | lr=%.6e", epoch + 1, args.epochs, train_loss, val_loss, optimizer.param_groups[0]["lr"], ) is_best = val_loss < best_val_loss if is_best: best_val_loss = val_loss last_improvement_epoch = epoch save_checkpoint( model, optimizer, ema, epoch, val_loss, os.path.join(output_dir, "checkpoints"), is_best=is_best, last_improvement_epoch=last_improvement_epoch, scheduler=scheduler, ) if epoch - last_improvement_epoch >= args.early_stop_patience: log.info("Early stopping at epoch %d", epoch + 1) break if (epoch + 1) % args.sample_every == 0: sample_path = Path(output_dir) / "samples" / f"samples_epoch_{epoch+1}.png" save_conditioned_sample_grid( model, diffusion, test_labels_tensor[: args.n_preview_samples], device, sample_path, channels=channels, height=h, width=w, ema=ema, use_ddim=args.use_ddim, ddim_steps=args.ddim_steps, title=f"Generated samples — epoch {epoch+1}", ) if (epoch + 1) % 5 == 0: plt.figure(figsize=(10, 5)) plt.plot(losses_train, label="Train") plt.plot(losses_val, label="Val") plt.yscale("log") plt.xlabel("Epoch") plt.ylabel("Loss") plt.title("Training / validation noise-prediction loss") plt.legend() plt.grid(True, alpha=0.3) plt.savefig(Path(output_dir) / "losses.png", dpi=150) plt.close() log.info("Training finished. Best validation loss: %.6f", best_val_loss) # --- Post-training: best checkpoint + conditional grid + scalar metrics --- best_ckpt = Path(output_dir) / "checkpoints" / "best_model.pt" if not best_ckpt.is_file(): best_ckpt = Path(output_dir) / "checkpoints" / "checkpoint_latest.pt" args_json = Path(output_dir) / "args.json" config = load_training_config(str(args_json)) eval_model = build_model(config, device) load_checkpoint(eval_model, str(best_ckpt), device) eval_diffusion = eval_model.diffusion grid_path = Path(output_dir) / "generated_samples_conditional.png" save_conditioned_sample_grid( eval_model, eval_diffusion, test_labels_tensor[: args.n_preview_samples], device, grid_path, channels=channels, height=h, width=w, ema=None, use_ddim=args.use_ddim, ddim_steps=args.ddim_steps, title="Post-training conditional samples (EMA weights if present in checkpoint)", ) data_dir = Path(args.data_dir) try: label_mean, label_std = load_label_stats(data_dir) images_test, labels_test = load_split(data_dir, "test") n_metric = min(args.metric_num_samples, len(images_test)) idx = np.random.choice(len(images_test), n_metric, replace=False) real_slice = images_test[idx] labels_slice = labels_test[idx] labels_t = prepare_labels_for_model(labels_slice, label_mean, label_std).to(device) gen_list = [] bs = min(args.metric_batch_size, n_metric) for i in range(0, n_metric, bs): lt = labels_t[i : i + bs] with torch.no_grad(): g = eval_model.sample( labels=lt, channels=channels, height=h, width=w, device=device, progress=False, use_ddim=args.use_ddim, ddim_steps=args.ddim_steps, eta=0.0, ) gen_list.append(from_model_output(g)) gen_np = np.concatenate(gen_list, axis=0) metrics = _final_metrics_log(real_slice, gen_np, log) with open(Path(output_dir) / "final_metrics.json", "w", encoding="utf-8") as f: json.dump( { "best_val_loss": best_val_loss, "checkpoint": str(best_ckpt), **{k: v for k, v in metrics.items() if isinstance(v, (int, float))}, }, f, indent=2, ) except FileNotFoundError as e: log.warning("Skipping final PDF/P(k) metrics (data not found): %s", e) summary_path = Path(output_dir) / "run_summary.txt" with open(summary_path, "w", encoding="utf-8") as f: f.write(f"output_dir: {output_dir}\n") f.write(f"best_val_loss: {best_val_loss}\n") f.write(f"best_checkpoint: {best_ckpt}\n") f.write(f"generated_grid: {grid_path}\n") log.info("Wrote run summary to %s", summary_path) if getattr(args, "run_vlb_after_train", False): vlb_dir = Path(output_dir) / args.vlb_output_subdir log.info("Running Mudur et al. VLB grid inference (post-train) -> %s", vlb_dir) run_vlb_parameter_inference( args, log, output_dir=vlb_dir, checkpoint_path=str(best_ckpt), training_args_path=str(args_json), ) return output_dir def run_inference(args: argparse.Namespace, log: logging.Logger) -> None: device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")) log.info("Device: %s", device) checkpoint_path = args.checkpoint training_args_path = args.training_args if training_args_path is None or not os.path.isfile(training_args_path): candidates = list(Path(".").rglob("args.json")) + list(Path(".").rglob("args.txt")) if not candidates: raise FileNotFoundError( "Provide --training_args pointing to args.json (or args.txt) from a training run." ) training_args_path = str(max(candidates, key=lambda p: p.stat().st_mtime)) log.info("Auto-selected training args: %s", training_args_path) if checkpoint_path is None or not os.path.isfile(checkpoint_path): ckpts = list(Path(".").rglob("checkpoints/best_model.pt")) if not ckpts: ckpts = list(Path(".").rglob("checkpoints/checkpoint_latest.pt")) if not ckpts: raise FileNotFoundError("Provide --checkpoint or train first (no best_model.pt found).") checkpoint_path = str(max(ckpts, key=lambda p: p.stat().st_mtime)) log.info("Auto-selected checkpoint: %s", checkpoint_path) config = load_training_config(training_args_path) model = build_model(config, device) load_checkpoint(model, checkpoint_path, device) diffusion = model.diffusion pin_memory = bool(args.pin_memory) and device.type == "cuda" _, _, test_loader = get_conditional_dataloaders( data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=pin_memory, normalize_labels=config.get("normalize_labels", True), ) _, labels_tensor = next(iter(test_loader)) h, w = _infer_spatial_size(test_loader) ch = test_loader.dataset[0][0].shape[0] if test_loader.dataset[0][0].dim() == 3 else 1 out_path = Path(args.inference_output) / "generated_samples_conditional.png" save_conditioned_sample_grid( model, diffusion, labels_tensor[: args.n_preview_samples], device, out_path, channels=ch, height=h, width=w, ema=None, use_ddim=args.use_ddim, ddim_steps=args.ddim_steps, title="Inference — conditional samples", ) log.info("Inference complete. Grid: %s", out_path) def build_argparser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( description="Conditional DDPM: train, sample, and VLB-based cosmo inference (Mudur et al. 2023)" ) p.add_argument( "--mode", type=str, choices=["train", "inference", "vlb"], required=True, help="train | inference (samples) | vlb (L0 grid on held-out fields, arXiv:2312.07534)", ) p.add_argument("--device", type=str, default="", help="cuda | cpu (empty = auto)") p.add_argument("--seed", type=int, default=42) # Model (matches train_conditional.py) p.add_argument("--label_dim", type=int, default=2, help="Conditioning vector dimension (e.g. Omega_m, sigma_8).") p.add_argument("--base_channels", type=int, default=64) p.add_argument("--channel_multipliers", type=int, nargs="+", default=[1, 2, 4, 8]) p.add_argument("--attention_levels", type=int, nargs="+", default=[2, 3]) p.add_argument("--dropout", type=float, default=0.1) # Diffusion p.add_argument("--timesteps", type=int, default=1500, help="Forward process length T (beta schedule discretization).") p.add_argument("--beta_start", type=float, default=1e-4) p.add_argument("--beta_end", type=float, default=0.02) p.add_argument("--schedule_type", type=str, default="linear", choices=["linear", "cosine"]) # Training p.add_argument("--epochs", type=int, default=100) p.add_argument("--batch_size", type=int, default=8) p.add_argument("--lr", type=float, default=2e-4) p.add_argument("--ema_decay", type=float, default=0.9999) p.add_argument("--num_workers", type=int, default=4) p.add_argument("--early_stop_patience", type=int, default=30) p.add_argument("--use_amp", action="store_true", default=False) p.add_argument("--pin_memory", action=argparse.BooleanOptionalAction, default=True) # Data p.add_argument("--data_dir", type=str, default="./data/params_2") p.add_argument("--normalize_labels", action=argparse.BooleanOptionalAction, default=True) # Output / checkpointing p.add_argument("--output_dir", type=str, default="outputs_conditional") p.add_argument("--resume", type=str, default="") p.add_argument("--sample_every", type=int, default=10) p.add_argument("--use_ddim", action=argparse.BooleanOptionalAction, default=True) p.add_argument("--ddim_steps", type=int, default=50) p.add_argument("--n_preview_samples", type=int, default=8, help="Grid size for conditional previews.") p.add_argument("--metric_num_samples", type=int, default=64, help="Samples for post-train PDF/P(k) metrics.") p.add_argument("--metric_batch_size", type=int, default=8) # Inference-only p.add_argument("--checkpoint", type=str, default=None) p.add_argument("--training_args", type=str, default=None, help="Path to args.json or args.txt from a train run.") p.add_argument("--inference_output", type=str, default="inference_outputs", help="Directory for inference artifacts.") # Mudur et al. (2023) VLB / L0 grid inference (also usable after training) p.add_argument( "--run_vlb_after_train", action="store_true", help="After training, run L0 grid parameter inference on held-out fields (writes under vlb_output_subdir).", ) p.add_argument("--vlb_output_subdir", type=str, default="vlb_posterior", help="Subfolder under training output_dir for VLB plots.") p.add_argument("--vlb_output_dir", type=str, default="vlb_inference_out", help="Output directory when --mode vlb.") p.add_argument("--vlb_split", type=str, default="test", choices=["train", "val", "test"]) p.add_argument("--vlb_n_fields", type=int, default=4, help="Number of random fields to evaluate.") p.add_argument("--vlb_n_grid", type=int, default=32, help="Grid points per parameter (paper uses 50; smaller is faster).") p.add_argument( "--vlb_half_width", type=float, default=0.1, help="Half-width of grid in each physical parameter (paper: ±0.1 clipped to CAMELS priors).", ) p.add_argument( "--vlb_prior_omega_m", type=float, nargs=2, default=[0.1, 0.5], metavar=("LO", "HI"), help="Prior range for Omega_m (physical units, matches Mudur et al. CMD priors).", ) p.add_argument( "--vlb_prior_sigma8", type=float, nargs=2, default=[0.6, 1.0], metavar=("LO", "HI"), help="Prior range for sigma_8 (physical units).", ) p.add_argument("--vlb_l0_seeds", type=int, default=3, help="MC seeds for x1 ~ q(x1|x0) in L0 (cosmic variance proxy).") p.add_argument("--vlb_chunk_size", type=int, default=32, help="Batch size for grid points on GPU.") return p def main() -> None: parser = build_argparser() args = parser.parse_args() set_seed(args.seed) log = _setup_logging() log.info("parameter_inference_conditional.py | mode=%s", args.mode) if args.mode == "train": run_training(args, log) elif args.mode == "inference": os.makedirs(args.inference_output, exist_ok=True) run_inference(args, log) else: os.makedirs(args.vlb_output_dir, exist_ok=True) run_vlb_parameter_inference(args, log) if __name__ == "__main__": main()