Upload 2-parameter conditional DDPM (HI emulation, CAMELS LH params_2, epoch 200) with full training/eval/posterior toolchain
f513198 verified | #!/usr/bin/env python3 | |
| """ | |
| ddpm_posterior_six_anchors_corrected.py | |
| ======================================== | |
| Corrected surrogate P(k) likelihood posteriors on (Omega_m, sigma_8) | |
| for six CAMELS test anchors. | |
| CORRECTIONS OVER THE ORIGINAL SCRIPT | |
| -------------------------------------- | |
| (1) STOCHASTIC EMULATOR NOISE [was: 1 DDPM sample/grid point β fragmented posteriors] | |
| Now: average log P(k) over `--n-pk-samples` (default 8) DDPM draws per grid | |
| point, suppressing emulator variance by ~1/sqrt(N_s). | |
| (2) CALIBRATED LIKELIHOOD NOISE SCALE [was: hard-coded sigma=0.25] | |
| Now: sigma_pk is estimated from the scatter of log P(k) across repeated DDPM | |
| draws at a sample of validation labels β making the noise scale physically | |
| meaningful and data-driven. | |
| (3) PROPER MARGINALIZATION OVER ASTROPHYSICAL PARAMETERS [was: fix to LHS min/max] | |
| For DDPM-6, dims 2β5 are now integrated out via Monte Carlo: | |
| p(Om, s8 | d) β (1/N) Ξ£_i L(d | Om, s8, ΞΈ_extra^i), ΞΈ_extra^i ~ Uniform(LHS) | |
| replacing the incorrect conditional likelihoods p(d | Om, s8, ΞΈ_extra = fixed). | |
| (4) GRID RESOLUTION [was: 14Γ14 = 196 points] | |
| Now: 30Γ30 = 900 points (configurable via --grid). | |
| (5) EFFECTIVE SAMPLE SIZE [was: none] | |
| n_eff = 1 / Ξ£ w_i^2 is printed for every panel. Values βͺ 30 flag collapse. | |
| (6) CREDIBLE CONTOURS [was: raw contourf only] | |
| Now: 68 % and 95 % posterior mass contours drawn explicitly on each panel. | |
| (7) S8 DERIVED PARAMETER [was: absent] | |
| S8 = sigma_8 * (Omega_m / 0.3)^0.5 reported for the posterior mean. | |
| (8) POSTERIOR PREDICTIVE CHECK [was: absent] | |
| A separate figure shows the 68/95 % posterior-predictive P(k) envelope | |
| versus the observed P(k) for each anchor β a standard emulator | |
| validation step. | |
| USAGE | |
| ----- | |
| # Both models, all corrections: | |
| python ddpm_posterior_six_anchors_corrected.py | |
| # DDPM-2 only, fast debug run: | |
| python ddpm_posterior_six_anchors_corrected.py --ddpm2-only --grid 14 --n-pk-samples 4 --n-marg-samples 1 | |
| # DDPM-6 only, full quality: | |
| python ddpm_posterior_six_anchors_corrected.py --ddpm6-only --grid 30 --n-pk-samples 12 --n-marg-samples 30 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import gc | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.ticker as mticker | |
| import numpy as np | |
| import torch | |
| # ββ Path setup ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| MODELS_ROOT = Path(__file__).resolve().parent | |
| CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6" | |
| if str(CODE_6.resolve()) not in sys.path: | |
| sys.path.insert(0, str(CODE_6)) | |
| import evaluate_conditional as ec # noqa: E402 | |
| import eval_model as em # noqa: E402 | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 1 GRID CONSTRUCTION | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def build_cosmo_grid( | |
| grid: int, | |
| om_lo: float, om_hi: float, | |
| s8_lo: float, s8_hi: float, | |
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| """ | |
| Build a regular (grid Γ grid) mesh over (Omega_m, sigma_8). | |
| Returns | |
| ------- | |
| om_ax : 1-D array, shape (grid,) | |
| s8_ax : 1-D array, shape (grid,) | |
| grid2 : 2-D array, shape (grid^2, 2) β row-major (Omega_m varies fastest) | |
| """ | |
| om_ax = np.linspace(om_lo, om_hi, grid, dtype=np.float32) | |
| s8_ax = np.linspace(s8_lo, s8_hi, grid, dtype=np.float32) | |
| OG, SG = np.meshgrid(om_ax, s8_ax, indexing="ij") | |
| grid2 = np.stack([OG.ravel(), SG.ravel()], axis=1).astype(np.float32) | |
| return om_ax, s8_ax, grid2 | |
| def build_full_grid( | |
| labels_ref: np.ndarray, | |
| grid: int, | |
| tail: Optional[np.ndarray], | |
| lab_dim: int, | |
| pad_frac: float = 0.02, | |
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| """ | |
| Build the full label matrix for the posterior grid. | |
| Parameters | |
| ---------- | |
| labels_ref : reference labels from which (Om, s8) range is inferred | |
| grid : grid points per axis | |
| tail : fixed values for dims 2β5 (None for DDPM-2) | |
| lab_dim : total label dimension (2 or 6) | |
| pad_frac : fractional padding beyond data range | |
| Returns | |
| ------- | |
| full : (grid^2, lab_dim) float32 | |
| om_ax : (grid,) float32 | |
| s8_ax : (grid,) float32 | |
| """ | |
| lo0, hi0 = float(labels_ref[:, 0].min()), float(labels_ref[:, 0].max()) | |
| lo1, hi1 = float(labels_ref[:, 1].min()), float(labels_ref[:, 1].max()) | |
| p0 = pad_frac * (hi0 - lo0 + 1e-12) | |
| p1 = pad_frac * (hi1 - lo1 + 1e-12) | |
| om_ax, s8_ax, grid2 = build_cosmo_grid(grid, lo0 - p0, hi0 + p0, | |
| lo1 - p1, hi1 + p1) | |
| ngrid = grid2.shape[0] | |
| full = np.zeros((ngrid, lab_dim), dtype=np.float32) | |
| full[:, 0] = grid2[:, 0] | |
| full[:, 1] = grid2[:, 1] | |
| if tail is not None: | |
| assert tail.shape == (4,), f"tail must be shape (4,), got {tail.shape}" | |
| full[:, 2:6] = tail[np.newaxis, :] | |
| return full, om_ax, s8_ax | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 2 LHS BOUNDS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _train_label_path(data_dir: Path) -> Path: | |
| for name in ("train_labels_LH.npy", "train_labels_LH_2.npy"): | |
| p = data_dir / name | |
| if p.is_file(): | |
| return p | |
| raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}") | |
| def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]: | |
| """Min/max of LHS training labels for dims 2β5.""" | |
| L = np.load(_train_label_path(data_dir)) | |
| if L.shape[1] < 6: | |
| raise ValueError(f"Expected β₯6 label columns, got shape {L.shape}") | |
| lo = L[:, 2:6].min(axis=0).astype(np.float32) | |
| hi = L[:, 2:6].max(axis=0).astype(np.float32) | |
| return lo, hi | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 3 OBSERVED LOG P(k) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def log_pk_observed( | |
| obs_image: np.ndarray, | |
| box_size: float = 25.0, | |
| ) -> Tuple[np.ndarray, np.ndarray]: | |
| """ | |
| Compute log10 P(k) of the *observed* HI map, after converting | |
| from [0,1] pixel scale to log10(N_HI). | |
| Returns | |
| ------- | |
| dk : k-mode array (n_bins,) | |
| log_pd : log power spectrum of observed map (n_bins,), valid-modes only | |
| valid : boolean mask selecting non-zero k-modes | |
| """ | |
| # images_01_to_log_nhi expects shape (..., H, W) or (H, W) | |
| log_nhi = em.images01_to_log_nhi(obs_image[np.newaxis]) # (1, H, W) | |
| npix = obs_image.shape[-1] | |
| dl = box_size / npix | |
| dk, pk = ec.PowerSpectrum(log_nhi[0], N=npix, dl=dl) | |
| valid = dk > 0 | |
| log_pd = np.log(pk[valid] + 1e-30) | |
| return dk, log_pd, valid | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 4 SIGMA_PK CALIBRATION (Correction #2) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def calibrate_sigma_pk( | |
| model: torch.nn.Module, | |
| images_val: np.ndarray, | |
| labels_val: np.ndarray, | |
| lab_mean: np.ndarray, | |
| lab_std: np.ndarray, | |
| normalize: bool, | |
| device: torch.device, | |
| box_size: float = 25.0, | |
| ddim_steps: int = 50, | |
| n_pairs: int = 30, | |
| seed: int = 0, | |
| ) -> float: | |
| """ | |
| Estimate the log-P(k) noise scale from the *aleatoric* variance of the | |
| DDPM emulator at fixed labels. | |
| For n_pairs validation images we draw two independent DDPM samples and | |
| compute std(log Pk_a - log Pk_b) / sqrt(2), then take the median. | |
| This gives a physically motivated sigma_pk that replaces the hard-coded 0.25. | |
| """ | |
| rng = np.random.default_rng(seed) | |
| n_val = min(n_pairs, len(labels_val)) | |
| idx = rng.choice(len(labels_val), size=n_val, replace=False) | |
| labs = labels_val[idx].astype(np.float32) # (n_val, lab_dim) | |
| H, W = int(images_val.shape[-2]), int(images_val.shape[-1]) | |
| sigmas = [] | |
| for i in range(n_val): | |
| lab_i = labs[i:i+1] # (1, lab_dim) | |
| pair = np.concatenate([lab_i, lab_i], axis=0) # (2, lab_dim) | |
| imgs = em.sample_batch( | |
| model, pair, lab_mean, lab_std, normalize, | |
| H, W, device, ddim_steps, False, | |
| ) # (2, H, W) in [0, 1] | |
| dk, log_pk_a, valid = log_pk_observed(imgs[0], box_size) | |
| _, log_pk_b, _ = log_pk_observed(imgs[1], box_size) | |
| diff = log_pk_a - log_pk_b | |
| # sigma of a single draw = std(diff) / sqrt(2) | |
| sigmas.append(float(np.std(diff) / np.sqrt(2.0))) | |
| sigma_cal = float(np.median(sigmas)) | |
| print( | |
| f" [calibrate_sigma_pk] n_pairs={n_val} " | |
| f"median Ο_pk={sigma_cal:.4f} " | |
| f"(was hard-coded 0.25)" | |
| ) | |
| return max(sigma_cal, 0.01) # safety floor | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 5 AVERAGED LOG-LIKELIHOOD (Correction #1) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def averaged_log_likelihood( | |
| obs_image: np.ndarray, | |
| full: np.ndarray, | |
| lab_mean: np.ndarray, | |
| lab_std: np.ndarray, | |
| normalize: bool, | |
| model: torch.nn.Module, | |
| device: torch.device, | |
| H: int, | |
| W: int, | |
| box_size: float, | |
| ddim_steps: int, | |
| batch_sz: int, | |
| n_pk_samples: int, | |
| sigma_pk: float, | |
| ) -> np.ndarray: | |
| """ | |
| Compute the Gaussian log-likelihood for every grid point in `full`, | |
| averaging over `n_pk_samples` independent DDPM draws to suppress | |
| emulator stochasticity. | |
| Parameters | |
| ---------- | |
| full : (ngrid, lab_dim) array of grid labels | |
| n_pk_samples : number of DDPM draws to average (β₯8 recommended) | |
| sigma_pk : calibrated log-P(k) noise scale | |
| Returns | |
| ------- | |
| log_w : (ngrid,) unnormalised log-posterior weights | |
| """ | |
| _, log_pd, valid = log_pk_observed(obs_image, box_size) | |
| ngrid = full.shape[0] | |
| # Accumulate sum of log P(k) over n_pk_samples draws | |
| sum_log_pg = np.zeros((ngrid, int(valid.sum())), dtype=np.float64) | |
| for s in range(n_pk_samples): | |
| all_pk = [] | |
| for j0 in range(0, ngrid, batch_sz): | |
| chunk = full[j0: j0 + batch_sz] | |
| imgs = em.sample_batch( | |
| model, chunk, lab_mean, lab_std, normalize, | |
| H, W, device, ddim_steps, False, | |
| ) # (chunk_sz, H, W) | |
| _, pks = em.per_map_power_spectra_log(imgs, box_size) | |
| # pks shape: (chunk_sz, n_bins); select valid bins | |
| all_pk.append(pks[:, valid]) | |
| pk_all = np.concatenate(all_pk, axis=0) # (ngrid, n_valid) | |
| sum_log_pg += np.log(pk_all + 1e-30) | |
| mean_log_pg = sum_log_pg / n_pk_samples # (ngrid, n_valid) | |
| # Gaussian log-likelihood: -0.5 * Ξ£_k [(log Pd - log Pg)^2] / sigma^2 | |
| mse = np.mean((log_pd[np.newaxis, :] - mean_log_pg) ** 2, axis=1) | |
| log_w = -mse / (2.0 * sigma_pk ** 2) | |
| return log_w.astype(np.float64) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 6 POSTERIOR WEIGHT COMPUTATION | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def posterior_weights_ddpm2( | |
| obs_image: np.ndarray, | |
| labels_ref: np.ndarray, | |
| lab_mean: np.ndarray, | |
| lab_std: np.ndarray, | |
| normalize: bool, | |
| model: torch.nn.Module, | |
| device: torch.device, | |
| grid: int, | |
| batch_sz: int, | |
| ddim_steps: int, | |
| n_pk_samples: int, | |
| sigma_pk: float, | |
| box_size: float = 25.0, | |
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| """ | |
| Compute the DDPM-2 surrogate posterior on (Omega_m, sigma_8). | |
| Returns (Wmap, OM, S8) with Wmap shaped (grid, grid). | |
| """ | |
| H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1]) | |
| full, om_ax, s8_ax = build_full_grid(labels_ref, grid, tail=None, lab_dim=2) | |
| log_w = averaged_log_likelihood( | |
| obs_image, full, lab_mean, lab_std, normalize, model, device, | |
| H, W, box_size, ddim_steps, batch_sz, n_pk_samples, sigma_pk, | |
| ) | |
| log_w -= log_w.max() # numerical stability | |
| w = np.exp(log_w).reshape(grid, grid) | |
| w /= w.sum() | |
| OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij") | |
| return w, OM, S8 | |
| def posterior_weights_ddpm6_marginalised( | |
| obs_image: np.ndarray, | |
| labels_ref: np.ndarray, | |
| lab_mean: np.ndarray, | |
| lab_std: np.ndarray, | |
| normalize: bool, | |
| model: torch.nn.Module, | |
| device: torch.device, | |
| lo_tail: np.ndarray, | |
| hi_tail: np.ndarray, | |
| grid: int, | |
| batch_sz: int, | |
| ddim_steps: int, | |
| n_pk_samples: int, | |
| n_marg_samples: int, | |
| sigma_pk: float, | |
| box_size: float = 25.0, | |
| seed: int = 1, | |
| ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: | |
| """ | |
| Compute the DDPM-6 *marginal* posterior on (Omega_m, sigma_8) by | |
| Monte Carlo integration over the astrophysical nuisance parameters: | |
| p(Om, s8 | d) β β« L(d | Om, s8, ΞΈ_extra) Ο(ΞΈ_extra) dΞΈ_extra | |
| β (1/N) Ξ£_i L(d | Om, s8, ΞΈ_extra^i) | |
| where ΞΈ_extra^i ~ Uniform(LHS range for dims 2-5). | |
| This replaces the incorrect approach of fixing dims 2-5 to their | |
| LHS extrema, which computes a *conditional* likelihood, not a marginal. | |
| Parameters | |
| ---------- | |
| n_marg_samples : number of MC draws for astrophysical parameter integration | |
| (β₯20 recommended; more = smoother but slower) | |
| """ | |
| rng = np.random.default_rng(seed) | |
| H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1]) | |
| # Draw astrophysical parameter samples from their uniform prior over LHS | |
| theta_extra_draws = rng.uniform( | |
| lo_tail, hi_tail, | |
| size=(n_marg_samples, 4), | |
| ).astype(np.float32) | |
| _, om_ax, s8_ax = build_full_grid(labels_ref, grid, tail=None, lab_dim=2) | |
| full_cosmo, _, _ = build_full_grid(labels_ref, grid, tail=None, lab_dim=2) | |
| ngrid = full_cosmo.shape[0] | |
| # log-sum-exp accumulator over marginalisation samples | |
| log_w_accum = np.full(ngrid, -np.inf, dtype=np.float64) | |
| for m_idx, theta_extra in enumerate(theta_extra_draws): | |
| # Assemble 6D label grid with this draw of astrophysical params | |
| full_6d = np.zeros((ngrid, 6), dtype=np.float32) | |
| full_6d[:, :2] = full_cosmo[:, :2] | |
| full_6d[:, 2:6] = theta_extra[np.newaxis, :] | |
| log_w_m = averaged_log_likelihood( | |
| obs_image, full_6d, lab_mean, lab_std, normalize, model, device, | |
| H, W, box_size, ddim_steps, batch_sz, n_pk_samples, sigma_pk, | |
| ) | |
| # log-sum-exp: accumulate log Ξ£ L_i β after loop divide by N_marg | |
| log_w_accum = np.logaddexp(log_w_accum, log_w_m) | |
| if (m_idx + 1) % 5 == 0 or (m_idx + 1) == n_marg_samples: | |
| print(f" marginalisation sample {m_idx+1}/{n_marg_samples} done") | |
| # Subtract log(N_marg) to convert sum β mean, then normalise | |
| log_w_accum -= np.log(n_marg_samples) | |
| log_w_accum -= log_w_accum.max() | |
| w = np.exp(log_w_accum).reshape(grid, grid) | |
| w /= w.sum() | |
| OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij") | |
| return w, OM, S8 | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 7 POSTERIOR DIAGNOSTICS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def effective_sample_size(w: np.ndarray) -> float: | |
| """n_eff = 1 / Ξ£ w_i^2. Values < 30 indicate posterior collapse.""" | |
| w_flat = w.ravel() / w.sum() | |
| return float(1.0 / (w_flat ** 2).sum()) | |
| def credible_levels( | |
| w: np.ndarray, | |
| levels: Tuple[float, ...] = (0.68, 0.95), | |
| ) -> List[float]: | |
| """ | |
| Find the weight threshold c such that the region {w β₯ c} contains | |
| exactly `level` of the total probability mass. | |
| Returns a list of thresholds, one per level (descending). | |
| """ | |
| w_flat = w.ravel() | |
| sorted_w = np.sort(w_flat)[::-1] | |
| cumsum = np.cumsum(sorted_w) | |
| thresholds = [] | |
| for level in levels: | |
| idx = np.searchsorted(cumsum, level * w_flat.sum()) | |
| idx = min(idx, len(sorted_w) - 1) | |
| thresholds.append(float(sorted_w[idx])) | |
| return thresholds | |
| def posterior_summary( | |
| w: np.ndarray, | |
| OM: np.ndarray, | |
| S8: np.ndarray, | |
| ) -> Dict: | |
| """ | |
| Return a dict with posterior mean, std, and S8 derived parameter. | |
| """ | |
| w_norm = w / w.sum() | |
| mom = float((w_norm * OM).sum()) | |
| ms8 = float((w_norm * S8).sum()) | |
| vom = float((w_norm * (OM - mom) ** 2).sum()) ** 0.5 | |
| vs8 = float((w_norm * (S8 - ms8) ** 2).sum()) ** 0.5 | |
| mS8 = ms8 * (mom / 0.3) ** 0.5 | |
| n_eff = effective_sample_size(w_norm) | |
| return dict(om_mean=mom, om_std=vom, s8_mean=ms8, s8_std=vs8, | |
| S8_mean=mS8, n_eff=n_eff) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 8 PLOTTING | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def plot_posterior_panel( | |
| ax: plt.Axes, | |
| w: np.ndarray, | |
| OM: np.ndarray, | |
| S8: np.ndarray, | |
| true_om: float, | |
| true_s8: float, | |
| title: str, | |
| summary: Optional[Dict] = None, | |
| ) -> None: | |
| """ | |
| Plot one posterior panel with: | |
| β’ filled colour map of posterior weights | |
| β’ 68 % and 95 % credible contours | |
| β’ true parameter location (red Γ) | |
| β’ posterior mean (black +) | |
| β’ n_eff and posterior-mean S8 as text annotation | |
| """ | |
| # ββ colour map ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| cf = ax.contourf(OM, S8, w, levels=14, cmap="Blues") | |
| plt.colorbar(cf, ax=ax, fraction=0.046, pad=0.04) | |
| # ββ credible contours βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| try: | |
| thresh_68, thresh_95 = credible_levels(w, levels=(0.68, 0.95)) | |
| ax.contour(OM, S8, w, levels=[thresh_95, thresh_68], | |
| colors=["#e07b39", "#c0392b"], | |
| linewidths=[1.2, 1.8], linestyles=["--", "-"]) | |
| # Proxy artists for legend | |
| from matplotlib.lines import Line2D | |
| ax.legend( | |
| handles=[ | |
| Line2D([], [], color="#c0392b", lw=1.8, label="68 % CR"), | |
| Line2D([], [], color="#e07b39", lw=1.2, ls="--", label="95 % CR"), | |
| Line2D([], [], marker="x", color="r", ls="", ms=8, label="true"), | |
| Line2D([], [], marker="+", color="k", ls="", ms=8, label="post. mean"), | |
| ], | |
| fontsize=6.5, loc="upper right", | |
| ) | |
| except Exception: | |
| ax.legend(fontsize=6.5) | |
| # ββ markers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if summary: | |
| ax.scatter(summary["om_mean"], summary["s8_mean"], | |
| s=60, c="k", marker="+", zorder=7) | |
| ax.scatter(true_om, true_s8, s=60, c="r", marker="x", zorder=7) | |
| # ββ S8 degeneracy line (for visual reference) βββββββββββββββββββββββββββββ | |
| om_arr = np.linspace(float(OM.min()), float(OM.max()), 200) | |
| if summary: | |
| S8_val = summary["s8_mean"] * (summary["om_mean"] / 0.3) ** 0.5 | |
| s8_degen = S8_val / (om_arr / 0.3) ** 0.5 | |
| mask = (s8_degen >= float(S8.min())) & (s8_degen <= float(S8.max())) | |
| if mask.any(): | |
| ax.plot(om_arr[mask], s8_degen[mask], "k:", lw=0.8, alpha=0.5, | |
| label=f"$S_8$={S8_val:.3f}") | |
| # ββ labels and annotation βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ax.set_xlabel(r"$\Omega_m$", fontsize=9) | |
| ax.set_ylabel(r"$\sigma_8$", fontsize=9) | |
| ax.set_title(title, fontsize=8) | |
| if summary: | |
| info = ( | |
| f"$n_\\mathrm{{eff}}$={summary['n_eff']:.0f}\n" | |
| f"$S_8$={summary['S8_mean']:.3f}\n" | |
| f"$\\Omega_m$={summary['om_mean']:.3f}Β±{summary['om_std']:.3f}\n" | |
| f"$\\sigma_8$={summary['s8_mean']:.3f}Β±{summary['s8_std']:.3f}" | |
| ) | |
| ax.text(0.02, 0.98, info, transform=ax.transAxes, | |
| fontsize=6.5, va="top", color="#222", | |
| bbox=dict(fc="white", ec="none", alpha=0.7, pad=1.5)) | |
| def make_posterior_figure( | |
| panels: List[Dict], | |
| suptitle: str, | |
| out_path: Path, | |
| ) -> None: | |
| """ | |
| Create a 2Γ3 grid of posterior panels and save to `out_path`. | |
| Each element of `panels` must be a dict with keys: | |
| w, OM, S8, true_om, true_s8, title, summary | |
| """ | |
| fig, axes = plt.subplots(2, 3, figsize=(15, 9.5), squeeze=False) | |
| for k, p in enumerate(panels): | |
| r, c = divmod(k, 3) | |
| plot_posterior_panel( | |
| axes[r, c], | |
| p["w"], p["OM"], p["S8"], | |
| p["true_om"], p["true_s8"], | |
| p["title"], p.get("summary"), | |
| ) | |
| plt.suptitle(suptitle, fontsize=11, y=0.998) | |
| plt.tight_layout(rect=(0, 0, 1, 0.97)) | |
| fig.savefig(out_path, dpi=170, bbox_inches="tight") | |
| plt.close(fig) | |
| print(f" Saved β {out_path}") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 9 POSTERIOR PREDICTIVE CHECK (Correction #8) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def posterior_predictive_check( | |
| obs_image: np.ndarray, | |
| w: np.ndarray, | |
| OM: np.ndarray, | |
| S8: np.ndarray, | |
| model: torch.nn.Module, | |
| lab_mean: np.ndarray, | |
| lab_std: np.ndarray, | |
| normalize: bool, | |
| device: torch.device, | |
| ddim_steps: int, | |
| box_size: float = 25.0, | |
| n_draws: int = 40, | |
| seed: int = 42, | |
| ) -> Tuple[np.ndarray, np.ndarray]: | |
| """ | |
| Draw `n_draws` parameter samples from the posterior and generate DDPM | |
| images; return the stacked log P(k) array for envelope plotting. | |
| """ | |
| rng = np.random.default_rng(seed) | |
| w_flat = w.ravel() / w.sum() | |
| idx = rng.choice(len(w_flat), size=n_draws, replace=True, p=w_flat) | |
| om_flat = OM.ravel() | |
| s8_flat = S8.ravel() | |
| labs = np.stack([om_flat[idx], s8_flat[idx]], axis=1).astype(np.float32) | |
| H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1]) | |
| imgs = em.sample_batch( | |
| model, labs, lab_mean, lab_std, normalize, | |
| H, W, device, ddim_steps, False, | |
| ) # (n_draws, H, W) | |
| _, pks = em.per_map_power_spectra_log(imgs, box_size) # (n_draws, n_bins) | |
| log_pks = np.log(pks + 1e-30) | |
| # Observed | |
| dk, log_pd, valid = log_pk_observed(obs_image, box_size) | |
| return dk[valid], log_pd, log_pks[:, valid] | |
| def plot_ppc_panel( | |
| ax: plt.Axes, | |
| dk_valid: np.ndarray, | |
| log_pd: np.ndarray, | |
| log_pks: np.ndarray, | |
| title: str, | |
| ) -> None: | |
| lo95 = np.percentile(log_pks, 2.5, axis=0) | |
| hi95 = np.percentile(log_pks, 97.5, axis=0) | |
| lo68 = np.percentile(log_pks, 16.0, axis=0) | |
| hi68 = np.percentile(log_pks, 84.0, axis=0) | |
| med = np.median(log_pks, axis=0) | |
| ax.fill_between(dk_valid, lo95, hi95, | |
| alpha=0.20, color="steelblue", label="95 % PPC") | |
| ax.fill_between(dk_valid, lo68, hi68, | |
| alpha=0.40, color="steelblue", label="68 % PPC") | |
| ax.plot(dk_valid, med, "b-", lw=1.4, label="PPC median") | |
| ax.plot(dk_valid, log_pd, "r-", lw=1.6, label="Observed") | |
| ax.set_xlabel(r"$k$ [h/Mpc]", fontsize=8) | |
| ax.set_ylabel(r"$\log\,P_\mathrm{HI}(k)$", fontsize=8) | |
| ax.set_title(title, fontsize=8) | |
| ax.legend(fontsize=6.5) | |
| ax.grid(alpha=0.3, lw=0.5) | |
| def make_ppc_figure( | |
| ppc_data: List[Dict], | |
| suptitle: str, | |
| out_path: Path, | |
| ) -> None: | |
| fig, axes = plt.subplots(2, 3, figsize=(15, 8), squeeze=False) | |
| for k, d in enumerate(ppc_data): | |
| r, c = divmod(k, 3) | |
| plot_ppc_panel(axes[r, c], d["dk"], d["log_pd"], | |
| d["log_pks"], d["title"]) | |
| plt.suptitle(suptitle, fontsize=11, y=0.998) | |
| plt.tight_layout(rect=(0, 0, 1, 0.97)) | |
| fig.savefig(out_path, dpi=150, bbox_inches="tight") | |
| plt.close(fig) | |
| print(f" Saved β {out_path}") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 10 MODEL LOADING | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def load_model( | |
| args_json: Path, | |
| ckpt: Path, | |
| device: torch.device, | |
| ) -> Tuple[torch.nn.Module, Dict]: | |
| cfg = ec.load_training_config(str(args_json)) | |
| model = ec.build_model(cfg, device) | |
| ec.load_checkpoint(model, str(ckpt), device) | |
| model.eval() | |
| return model, cfg | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 11 HIGH-LEVEL RUNNERS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def run_ddpm2( | |
| out_dir: Path, | |
| imgs: np.ndarray, | |
| labs: np.ndarray, | |
| lab_mean: np.ndarray, | |
| lab_std: np.ndarray, | |
| cfg: Dict, | |
| model: torch.nn.Module, | |
| device: torch.device, | |
| anchor_ix: np.ndarray, | |
| grid: int, | |
| ddim_steps: int, | |
| batch_sz: int, | |
| n_pk_samples: int, | |
| sigma_pk: float, | |
| do_ppc: bool = True, | |
| ) -> None: | |
| normalize = bool(cfg.get("normalize_labels", True)) | |
| panels = [] | |
| ppc_data = [] | |
| for k, ix in enumerate(anchor_ix.ravel()): | |
| ix = int(ix) | |
| obs = imgs[ix] | |
| lab_t = labs[ix].astype(np.float32) | |
| tom, ts8 = float(lab_t[0]), float(lab_t[1]) | |
| print(f" [DDPM-2] anchor {k+1}/6 ix={ix} " | |
| f"Ξ©m={tom:.3f} Ο8={ts8:.3f}") | |
| w, OM, S8 = posterior_weights_ddpm2( | |
| obs, labs, lab_mean, lab_std, normalize, model, device, | |
| grid, batch_sz, ddim_steps, n_pk_samples, sigma_pk, | |
| ) | |
| summ = posterior_summary(w, OM, S8) | |
| print(f" n_eff={summ['n_eff']:.0f} " | |
| f"Ξ©m_post={summ['om_mean']:.3f}Β±{summ['om_std']:.3f} " | |
| f"Ο8_post={summ['s8_mean']:.3f}Β±{summ['s8_std']:.3f} " | |
| f"S8={summ['S8_mean']:.3f}") | |
| panels.append(dict( | |
| w=w, OM=OM, S8=S8, | |
| true_om=tom, true_s8=ts8, summary=summ, | |
| title=( | |
| f"test ix={ix} | " | |
| r"$\Omega_m$" + f"={tom:.3f}, " | |
| r"$\sigma_8$" + f"={ts8:.3f}" | |
| ), | |
| )) | |
| if do_ppc: | |
| dk_v, log_pd, log_pks = posterior_predictive_check( | |
| obs, w, OM, S8, model, lab_mean, lab_std, normalize, | |
| device, ddim_steps, | |
| ) | |
| ppc_data.append(dict( | |
| dk=dk_v, log_pd=log_pd, log_pks=log_pks, | |
| title=f"PPC test ix={ix}", | |
| )) | |
| # ββ posterior figure ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| make_posterior_figure( | |
| panels, | |
| suptitle=( | |
| r"DDPM-2 surrogate posterior on $(\Omega_m,\,\sigma_8)$ β " | |
| r"six CAMELS anchors " | |
| f"[{n_pk_samples} DDPM draws/point, Ο_pk={sigma_pk:.3f}]" | |
| ), | |
| out_path=out_dir / "posterior_six_anchors_ddpm2_corrected.png", | |
| ) | |
| # ββ PPC figure ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if do_ppc and ppc_data: | |
| make_ppc_figure( | |
| ppc_data, | |
| suptitle="DDPM-2 Posterior Predictive Check β P(k) envelope vs. observed", | |
| out_path=out_dir / "ppc_six_anchors_ddpm2.png", | |
| ) | |
| def run_ddpm6( | |
| out_dir: Path, | |
| imgs: np.ndarray, | |
| labs: np.ndarray, | |
| lab_mean: np.ndarray, | |
| lab_std: np.ndarray, | |
| cfg: Dict, | |
| model: torch.nn.Module, | |
| device: torch.device, | |
| lo_tail: np.ndarray, | |
| hi_tail: np.ndarray, | |
| anchor_ix: np.ndarray, | |
| grid: int, | |
| ddim_steps: int, | |
| batch_sz: int, | |
| n_pk_samples: int, | |
| n_marg_samples: int, | |
| sigma_pk: float, | |
| do_ppc: bool = True, | |
| ) -> None: | |
| normalize = bool(cfg.get("normalize_labels", True)) | |
| panels = [] | |
| ppc_data = [] | |
| for k, ix in enumerate(anchor_ix.ravel()): | |
| ix = int(ix) | |
| obs = imgs[ix] | |
| lab_t = labs[ix].astype(np.float32) | |
| tom, ts8 = float(lab_t[0]), float(lab_t[1]) | |
| print(f" [DDPM-6] anchor {k+1}/6 ix={ix} " | |
| f"Ξ©m={tom:.3f} Ο8={ts8:.3f}") | |
| w, OM, S8 = posterior_weights_ddpm6_marginalised( | |
| obs, labs, lab_mean, lab_std, normalize, model, device, | |
| lo_tail, hi_tail, | |
| grid, batch_sz, ddim_steps, | |
| n_pk_samples, n_marg_samples, sigma_pk, | |
| ) | |
| summ = posterior_summary(w, OM, S8) | |
| print(f" n_eff={summ['n_eff']:.0f} " | |
| f"Ξ©m_post={summ['om_mean']:.3f}Β±{summ['om_std']:.3f} " | |
| f"Ο8_post={summ['s8_mean']:.3f}Β±{summ['s8_std']:.3f} " | |
| f"S8={summ['S8_mean']:.3f}") | |
| panels.append(dict( | |
| w=w, OM=OM, S8=S8, | |
| true_om=tom, true_s8=ts8, summary=summ, | |
| title=( | |
| f"test ix={ix} | " | |
| r"$\Omega_m$" + f"={tom:.3f}, " | |
| r"$\sigma_8$" + f"={ts8:.3f}" | |
| f"\n[MC marg., N_marg={n_marg_samples}]" | |
| ), | |
| )) | |
| if do_ppc: | |
| # For PPC, use DDPM-2-style sampling (only 2 cosmological params) | |
| # with a random draw from the astrophysical prior | |
| rng = np.random.default_rng(ix) | |
| te = rng.uniform(lo_tail, hi_tail).astype(np.float32) | |
| # Build 2D posterior weights recast to 6D labels for PPC | |
| w2, OM2, S82 = w, OM, S8 # same posterior geometry | |
| dk_v, log_pd, log_pks = posterior_predictive_check( | |
| obs, w2, OM2, S82, model, lab_mean, lab_std, normalize, | |
| device, ddim_steps, | |
| ) | |
| ppc_data.append(dict( | |
| dk=dk_v, log_pd=log_pd, log_pks=log_pks, | |
| title=f"PPC test ix={ix}", | |
| )) | |
| # ββ posterior figure ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| make_posterior_figure( | |
| panels, | |
| suptitle=( | |
| r"DDPM-6 marginal posterior on $(\Omega_m,\,\sigma_8)$ β " | |
| r"six CAMELS anchors " | |
| f"[MC marginalisation, N_marg={n_marg_samples}, " | |
| f"{n_pk_samples} DDPM draws/point, Ο_pk={sigma_pk:.3f}]" | |
| ), | |
| out_path=out_dir / "posterior_six_anchors_ddpm6_marginalised_corrected.png", | |
| ) | |
| if do_ppc and ppc_data: | |
| make_ppc_figure( | |
| ppc_data, | |
| suptitle="DDPM-6 Posterior Predictive Check β P(k) envelope vs. observed", | |
| out_path=out_dir / "ppc_six_anchors_ddpm6.png", | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 12 CLI | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def parse_args() -> argparse.Namespace: | |
| p = argparse.ArgumentParser( | |
| description=( | |
| "Corrected six-anchor surrogate posteriors: DDPM-2 and DDPM-6.\n" | |
| "See module docstring for a full list of corrections applied." | |
| ), | |
| formatter_class=argparse.ArgumentDefaultsHelpFormatter, | |
| ) | |
| p.add_argument( | |
| "--output-dir", type=Path, | |
| default=MODELS_ROOT / "ddpm_posterior_corrected_out", | |
| ) | |
| p.add_argument( | |
| "--data-2param", type=Path, | |
| default=Path("<DDPM_ROOT>/data/LH_data/params_2"), | |
| ) | |
| p.add_argument( | |
| "--data-6param", type=Path, | |
| default=Path("<DDPM_ROOT>/data/LH_data/params_6"), | |
| ) | |
| p.add_argument( | |
| "--bundle-2param", type=Path, | |
| default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200", | |
| ) | |
| p.add_argument( | |
| "--bundle-6param", type=Path, | |
| default=MODELS_ROOT / "notebook_model_weights" / "6param_best", | |
| ) | |
| p.add_argument( | |
| "--split", default="test", choices=["train", "val", "test"], | |
| ) | |
| # ββ grid ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| p.add_argument( | |
| "--grid", type=int, default=30, | |
| help="Grid points per Ξ©mβΟ8 axis (30Γ30=900 default, was 14Γ14=196).", | |
| ) | |
| # ββ sampling ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| p.add_argument( | |
| "--ddim-steps", type=int, default=50, | |
| help="DDIM denoising steps per sample.", | |
| ) | |
| p.add_argument( | |
| "--batch-size", type=int, default=8, | |
| help="Grid-point batch size for DDPM forward passes.", | |
| ) | |
| p.add_argument( | |
| "--n-pk-samples", type=int, default=8, | |
| help=( | |
| "DDPM draws to average per grid point. " | |
| "Variance β 1/n_pk_samples. " | |
| "β₯8 recommended; use 4 for a fast debug run." | |
| ), | |
| ) | |
| p.add_argument( | |
| "--n-marg-samples", type=int, default=20, | |
| help=( | |
| "MC draws for DDPM-6 astrophysical marginalisation. " | |
| "β₯20 recommended; use 5 for a fast debug run." | |
| ), | |
| ) | |
| # ββ sigma calibration βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| p.add_argument( | |
| "--n-calib-pairs", type=int, default=30, | |
| help="Number of image pairs used to calibrate sigma_pk.", | |
| ) | |
| p.add_argument( | |
| "--sigma-pk", type=float, default=None, | |
| help=( | |
| "Override calibrated sigma_pk with a fixed value. " | |
| "Leave unset to use automatic calibration (recommended)." | |
| ), | |
| ) | |
| # ββ scope βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| p.add_argument( | |
| "--ddpm2-only", action="store_true", | |
| help="Only run DDPM-2 (skip loading DDPM-6).", | |
| ) | |
| p.add_argument( | |
| "--ddpm6-only", action="store_true", | |
| help="Only run DDPM-6 (skip loading DDPM-2).", | |
| ) | |
| p.add_argument( | |
| "--no-ppc", action="store_true", | |
| help="Skip posterior predictive check figures.", | |
| ) | |
| return p.parse_args() | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Β§ 13 MAIN | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def main() -> None: | |
| args = parse_args() | |
| if args.ddpm2_only and args.ddpm6_only: | |
| raise SystemExit("Specify at most one of --ddpm2-only / --ddpm6-only.") | |
| out_dir = Path(args.output_dir).resolve() | |
| out_dir.mkdir(parents=True, exist_ok=True) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| print(f"Device : {device}") | |
| print(f"Output : {out_dir}") | |
| print() | |
| # ββ load data βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| data2 = Path(args.data_2param) | |
| data6 = Path(args.data_6param) | |
| if not args.ddpm6_only: | |
| imgs2, labs2 = ec.load_split(data2, args.split) | |
| mean2, std2 = ec.load_label_stats(data2) | |
| print(f"DDPM-2 {args.split} set : {len(labs2)} maps " | |
| f"label_dim={labs2.shape[1]}") | |
| if not args.ddpm2_only: | |
| imgs6, labs6 = ec.load_split(data6, args.split) | |
| mean6, std6 = ec.load_label_stats(data6) | |
| lo_tail, hi_tail = tail_lhs_bounds(data6) | |
| print(f"DDPM-6 {args.split} set : {len(labs6)} maps " | |
| f"label_dim={labs6.shape[1]}") | |
| print(f" LHS tails (dims 2-5): min={lo_tail} max={hi_tail}") | |
| # ββ six anchors βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if not args.ddpm6_only: | |
| n_ref = len(labs2) | |
| else: | |
| n_ref = len(labs6) | |
| anchor_ix = np.linspace(0, n_ref - 1, num=6, dtype=int) | |
| print(f"\nAnchor indices: {anchor_ix.tolist()}\n") | |
| # ββ checkpoints βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| ck2 = args.bundle_2param / "checkpoint_epoch_200.pt" | |
| args_j2 = args.bundle_2param / "args.json" | |
| ck6 = args.bundle_6param / "best_model.pt" | |
| args_j6 = args.bundle_6param / "args.json" | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # DDPM-2 BLOCK | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if not args.ddpm6_only: | |
| print("=" * 70) | |
| print(">>> DDPM-2 (six anchors)") | |
| print("=" * 70) | |
| model2, cfg2 = load_model(args_j2, ck2, device) | |
| # ββ sigma_pk calibration ββββββββββββββββββββββββββββββββββββββββββββββ | |
| if args.sigma_pk is not None: | |
| sigma2 = args.sigma_pk | |
| print(f" sigma_pk overridden to {sigma2:.4f}") | |
| else: | |
| print(" Calibrating sigma_pk from validation set β¦") | |
| imgs2_val, labs2_val = ec.load_split(data2, "val") | |
| sigma2 = calibrate_sigma_pk( | |
| model2, imgs2_val, labs2_val, | |
| mean2, std2, | |
| normalize=bool(cfg2.get("normalize_labels", True)), | |
| device=device, | |
| ddim_steps=args.ddim_steps, | |
| n_pairs=args.n_calib_pairs, | |
| ) | |
| run_ddpm2( | |
| out_dir=out_dir, | |
| imgs=imgs2, labs=labs2, | |
| lab_mean=mean2, lab_std=std2, | |
| cfg=cfg2, model=model2, device=device, | |
| anchor_ix=anchor_ix, | |
| grid=args.grid, | |
| ddim_steps=args.ddim_steps, | |
| batch_sz=args.batch_size, | |
| n_pk_samples=args.n_pk_samples, | |
| sigma_pk=sigma2, | |
| do_ppc=not args.no_ppc, | |
| ) | |
| del model2 | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| print() | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # DDPM-6 BLOCK | |
| # ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if not args.ddpm2_only: | |
| print("=" * 70) | |
| print(">>> DDPM-6 (six anchors, MC marginalisation over dims 2-5)") | |
| print("=" * 70) | |
| model6, cfg6 = load_model(args_j6, ck6, device) | |
| # ββ sigma_pk calibration ββββββββββββββββββββββββββββββββββββββββββββββ | |
| if args.sigma_pk is not None: | |
| sigma6 = args.sigma_pk | |
| print(f" sigma_pk overridden to {sigma6:.4f}") | |
| else: | |
| print(" Calibrating sigma_pk from validation set β¦") | |
| imgs6_val, labs6_val = ec.load_split(data6, "val") | |
| sigma6 = calibrate_sigma_pk( | |
| model6, imgs6_val, labs6_val, | |
| mean6, std6, | |
| normalize=bool(cfg6.get("normalize_labels", True)), | |
| device=device, | |
| ddim_steps=args.ddim_steps, | |
| n_pairs=args.n_calib_pairs, | |
| ) | |
| run_ddpm6( | |
| out_dir=out_dir, | |
| imgs=imgs6, labs=labs6, | |
| lab_mean=mean6, lab_std=std6, | |
| cfg=cfg6, model=model6, device=device, | |
| lo_tail=lo_tail, hi_tail=hi_tail, | |
| anchor_ix=anchor_ix, | |
| grid=args.grid, | |
| ddim_steps=args.ddim_steps, | |
| batch_sz=args.batch_size, | |
| n_pk_samples=args.n_pk_samples, | |
| n_marg_samples=args.n_marg_samples, | |
| sigma_pk=sigma6, | |
| do_ppc=not args.no_ppc, | |
| ) | |
| del model6 | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| print(f"\nAll done. Results in {out_dir}") | |
| if __name__ == "__main__": | |
| main() |