Commit eb725f8 (verified) by collins909
1 Parent(s): c46900a

Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain

Files changed (34)
  1. README.md +45 -4
  2. cross_model/README.md +17 -0
  3. cross_model/check_poster_env.py +78 -0
  4. cross_model/compare_posterior_inference.py +699 -0
  5. cross_model/ddpm_posterior_corrected.py +867 -0
  6. cross_model/poster.py +1112 -0
  7. cross_model/run_compare_posterior.sh +52 -0
  8. cross_model/run_vlb_inference_1000grid.sh +81 -0
  9. cross_model/run_vlb_inference_200grid.sh +78 -0
  10. cross_model/scripts/compare_ddpm_models.py +855 -0
  11. cross_model/scripts/compare_ddpm_training_curves.py +45 -0
  12. cross_model/scripts/ddpm_figure6_integration.py +271 -0
  13. cross_model/scripts/ddpm_posterior_six_anchors.py +451 -0
  14. cross_model/scripts/ddpm_triangle_integration.py +194 -0
  15. cross_model/scripts/figure6_2409_style.py +157 -0
  16. cross_model/scripts/run_ddpm_comparison.sh +66 -0
  17. cross_model/scripts/run_ddpm_figure6.sh +27 -0
  18. cross_model/scripts/run_ddpm_figure6_suite.py +315 -0
  19. cross_model/scripts/run_ddpm_posterior_corrected.sh +58 -0
  20. cross_model/scripts/run_ddpm_posterior_six_anchors.sh +52 -0
  21. cross_model/scripts/run_poster.sh +53 -0
  22. cross_model/scripts/run_posterior_inference.sh +74 -0
  23. cross_model/scripts/run_triangle_ddpm_both.sh +75 -0
  24. cross_model/scripts/sigma_contour_utils.py +29 -0
  25. cross_model/scripts/triangle_plot_posterior.py +128 -0
  26. cross_model/submit_vlb_1000grid.py +106 -0
  27. scripts/shell/evaluate_conditional_lh6.sh +61 -0
  28. scripts/shell/plot_r2_cosmology_lhs.sh +72 -0
  29. scripts/shell/train_conditional_lh6.sh +60 -0
  30. src/eval_model.py +86 -0
  31. src/figure9_posterior.py +33 -0
  32. src/plot_r2_cosmology_lhs.py +316 -0
  33. src/posterior_inference.py +895 -0
  34. src/train_conditional.py +447 -0
README.md CHANGED
@@ -26,17 +26,58 @@ This is the **best-validation checkpoint** from the training run under
 
 ## Files in this repo
 
+**Top level**
+
 | File | Purpose |
 |------|---------|
 | `model.pt` | PyTorch checkpoint (state-dict for `ConditionalDiffusionModel`) |
 | `args.json` / `args.txt` | Training hyper-parameters and U-Net configuration |
 | `config.json` | Architecture summary (for Hub discoverability) |
-| `src/unet_conditional.py` | `ConditionalUNet` module |
-| `src/diffusion_conditional.py` | `GaussianDiffusion` (DDPM + DDIM) and the wrapping `ConditionalDiffusionModel` |
-| `src/dataset_conditional.py` | Helper for loading CAMELS LH data + label normalisation stats |
-| `src/evaluate_conditional.py` | Reference evaluation pipeline (samples + metrics) |
 | `inference_example.py` | Runnable example: downloads weights and generates a sample |
 
+**`src/` — per-model Python**
+
+| File | Purpose |
+|------|---------|
+| `train_conditional.py` | Training entry point (`label_dim=6`, mixed-precision) |
+| `evaluate_conditional.py` | Held-out evaluation: samples + metrics |
+| `eval_model.py` | Lightweight evaluation helper used by the figure scripts |
+| `posterior_inference.py` | Full posterior-inference pipeline (likelihood / sampling) |
+| `figure9_posterior.py` | Paper figure 9 (posterior triangle for the 6-param model) |
+| `plot_r2_cosmology_lhs.py` | Latin-hypercube R² map (μ, σ vs cosmology) |
+| `unet_conditional.py` | `ConditionalUNet` module |
+| `diffusion_conditional.py` | `GaussianDiffusion` (DDPM + DDIM) and the wrapping `ConditionalDiffusionModel` |
+| `dataset_conditional.py` | CAMELS LH dataset loader + label normalisation |
+
+**`scripts/shell/` — SLURM launchers**
+
+| File | Purpose |
+|------|---------|
+| `train_conditional_lh6.sh` | Submit a training job (`label_dim=6`) |
+| `evaluate_conditional_lh6.sh` | Submit evaluation against the held-out test split |
+| `plot_r2_cosmology_lhs.sh` | Generate the R² cosmology figure |
+
+**`cross_model/` — posterior + comparison scripts that use BOTH models**
+
+| File | Purpose |
+|------|---------|
+| `compare_posterior_inference.py` (+ `run_compare_posterior.sh`) | End-to-end posterior comparison between 2-param and 6-param emulators |
+| `ddpm_posterior_corrected.py` (+ `scripts/run_ddpm_posterior_corrected.sh`) | Corrected DDPM posterior inference |
+| `poster.py` / `check_poster_env.py` (+ `scripts/run_poster.sh`) | Posterior orchestration and environment check |
+| `submit_vlb_1000grid.py` / `run_vlb_inference_*.sh` | Variational-lower-bound grid inference (200 / 1000 grid) |
+| `scripts/compare_ddpm_models.py` (+ `run_ddpm_comparison.sh`) | DDPM-2 vs DDPM-6 comparison figures |
+| `scripts/ddpm_posterior_six_anchors.py` (+ `run_ddpm_posterior_six_anchors.sh`) | Six-anchor posterior visualisation |
+| `scripts/ddpm_figure6_integration.py`, `figure6_2409_style.py`, `run_ddpm_figure6_suite.py` (+ `run_ddpm_figure6.sh`) | Figure 6 generation pipeline |
+| `scripts/ddpm_triangle_integration.py`, `triangle_plot_posterior.py` (+ `run_triangle_ddpm_both.sh`) | Triangle-plot posterior figures |
+| `scripts/sigma_contour_utils.py` | Confidence-contour helper used by the figure scripts |
+| `scripts/compare_ddpm_training_curves.py` | Parses SLURM logs for combined train/val loss plots |
+| `cross_model/README.md` | How to point these scripts at locally-downloaded weights/data |
+
+These cross-model scripts default to the original cluster paths (e.g.
+`/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6`). After downloading
+this repo, supply `--bundle-2param`, `--bundle-6param`, `--data-2param`,
+`--data-6param` to override.
+
 ## Architecture
 
 Conditional U-Net + Gaussian diffusion process. Hyper-parameters (taken from
cross_model/README.md ADDED
@@ -0,0 +1,17 @@
+# Cross-model scripts
+
+These posterior-inference and comparison scripts use BOTH the
+2-parameter and 6-parameter DDPM checkpoints. Their default
+paths assume the original cluster layout:
+
+    Models/
+      2param_DDPM_HI_Emulation/   <- code
+      6param_ddpm_hi_lh6/         <- code
+      notebook_model_weights/
+        2param_epoch200/          <- checkpoint + args.json
+        6param_best/              <- checkpoint + args.json
+
+When running these scripts from a local download of this HF repo,
+pass `--bundle-2param`, `--bundle-6param`, `--data-2param`,
+`--data-6param` (etc.) to point at the locations where you placed
+the weights and the CAMELS LH data. See each script's `--help`.
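Note on the override flags: the four options named in this README are defined in `compare_posterior_inference.py` in this commit. The sketch below only assembles a command line for that script; the `downloads/...` directories are hypothetical placeholders, not paths from the repository.

```python
import shlex
from pathlib import Path

# Hypothetical local layout; replace with wherever the weights and data were placed.
local_root = Path("downloads")

cmd = [
    "python", "cross_model/compare_posterior_inference.py",
    "--bundle-2param", str(local_root / "weights" / "2param_epoch200"),
    "--bundle-6param", str(local_root / "weights" / "6param_best"),
    "--data-2param", str(local_root / "data" / "params_2"),
    "--data-6param", str(local_root / "data" / "params_6"),
]

# Print a copy-pasteable shell command instead of launching the job.
print(shlex.join(cmd))
```

As committed, that script expects `checkpoint_epoch_200.pt` and `args.json` inside the 2-param bundle and `best_model.pt` and `args.json` inside the 6-param bundle. The top-level checkpoint in this repo is listed as `model.pt`, so a copy or rename may be needed.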
cross_model/check_poster_env.py ADDED
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+"""One-shot env check for poster.py — logs NDJSON for debug session."""
+import json
+import os
+import time
+
+_LOG = "/scratch/mrpcol001/Diffusion_job/Models/.cursor/debug-a1359c.log"
+
+
+def _log(hypothesis_id: str, location: str, message: str, data: dict) -> None:
+    payload = {
+        "sessionId": "a1359c",
+        "runId": os.environ.get("DEBUG_POSTER_RUN", "pre-fix"),
+        "hypothesisId": hypothesis_id,
+        "location": location,
+        "message": message,
+        "data": data,
+        "timestamp": int(time.time() * 1000),
+    }
+    os.makedirs(os.path.dirname(_LOG), exist_ok=True)
+    with open(_LOG, "a", encoding="utf-8") as f:
+        f.write(json.dumps(payload) + "\n")
+
+
+def main() -> None:
+    # region agent log
+    root = "/scratch/mrpcol001/Diffusion_job/Models"
+    cwd = os.getcwd()
+    poster_path = os.path.join(root, "poster.py")
+    poster_ci = os.path.join(root, "Poster.py")
+    _log(
+        "H2",
+        "check_poster_env.py:main",
+        "cwd vs expected Models root",
+        {"cwd": cwd, "root": root, "cwd_equals_root": os.path.abspath(cwd) == os.path.abspath(root)},
+    )
+    _log(
+        "H1",
+        "check_poster_env.py:main",
+        "poster.py presence",
+        {
+            "poster_py_exists": os.path.isfile(poster_path),
+            "poster_path": poster_path,
+            "size_if_exists": os.path.getsize(poster_path) if os.path.isfile(poster_path) else None,
+        },
+    )
+    _log(
+        "H3",
+        "check_poster_env.py:main",
+        "case variant",
+        {"Poster_py_exists": os.path.isfile(poster_ci)},
+    )
+    try:
+        names = sorted(os.listdir(root))
+    except OSError as e:
+        names = []
+        list_err = str(e)
+    else:
+        list_err = None
+    poster_like = [n for n in names if "poster" in n.lower()]
+    _log(
+        "H4",
+        "check_poster_env.py:main",
+        "Models directory poster-related names",
+        {"list_error": list_err, "poster_like_filenames": poster_like, "total_entries": len(names)},
+    )
+    _log(
+        "H5",
+        "check_poster_env.py:main",
+        "alternate runnable scripts hint",
+        {"scripts_dir_exists": os.path.isdir(os.path.join(root, "scripts"))},
+    )
+    # endregion agent log
+    print("check_poster_env: wrote logs to", _LOG)
+
+
+if __name__ == "__main__":
+    main()
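Note on the log format: `_log` appends one JSON object per line (NDJSON), so the file can be read back line by line. A minimal sketch of that round trip follows. It writes to a temporary file rather than the hard-coded cluster path, and the two records are made-up examples that reuse only the `hypothesisId`, `message` and `data` keys.

```python
import json
import os
import tempfile

# Hypothetical stand-in for the hard-coded cluster log path.
log_path = os.path.join(tempfile.mkdtemp(), "debug-example.log")

records = [
    {"hypothesisId": "H1", "message": "poster.py presence",
     "data": {"poster_py_exists": False}},
    {"hypothesisId": "H2", "message": "cwd vs expected Models root",
     "data": {"cwd_equals_root": True}},
]

# Append mode plus a trailing newline gives one JSON object per line.
with open(log_path, "a", encoding="utf-8") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")

# Reading back: parse each non-empty line independently.
with open(log_path, encoding="utf-8") as f:
    parsed = [json.loads(line) for line in f if line.strip()]

by_hypothesis = {rec["hypothesisId"]: rec["data"] for rec in parsed}
```

Because each line parses on its own, a partly written log still yields every complete record before the truncation point.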
cross_model/compare_posterior_inference.py ADDED
@@ -0,0 +1,699 @@
+#!/usr/bin/env python3
+"""
+compare_posterior_inference.py
+==============================
+Side-by-side corrected surrogate posteriors on (Omega_m, sigma_8) for
+DDPM-2 and DDPM-6 using a JOINT P(k) + log-N_HI PDF Gaussian likelihood.
+
+Why a joint summary statistic?
+------------------------------
+P(k) alone leaves Omega_m and sigma_8 strongly degenerate, which is why
+prior runs returned sigma_8 ~ 0.80 +/- 0.12 regardless of truth. The
+column-density PDF carries information that P(k) misses (it is sensitive
+to non-Gaussian amplitude features), so combining the two breaks the
+degeneracy along the S_8 direction.
+
+What this script does
+---------------------
+1. Loads both DDPM-2 (epoch 200) and DDPM-6 (best) bundles from
+   notebook_model_weights/.
+2. Calibrates per-model sigma_pk and sigma_pdf from validation-set DDPM
+   pair scatter (no hard-coded noise).
+3. For each anchor in the test split:
+     - DDPM-2: direct (Omega_m, sigma_8) grid posterior.
+     - DDPM-6: Monte-Carlo marginalisation over the four astrophysical
+       nuisance parameters with a uniform LHS prior.
+   Both posteriors use the JOINT P(k) + PDF likelihood.
+4. Saves per-anchor posterior arrays as .npz so plots can be re-rendered
+   cheaply later.
+5. Emits a single comparison figure: rows = anchors, columns =
+   DDPM-2 | DDPM-6, with 68/95% credible contours, true value, posterior
+   mean, and posterior summary annotation.
+
+Defaults
+--------
+    --grid 30 --ddim-steps 50 --batch-size 8
+    --n-pk-samples 8 --n-marg-samples 20 --n-anchors 4
+
+Quick smoke test
+----------------
+    python compare_posterior_inference.py --grid 16 --n-pk-samples 4 \\
+        --n-marg-samples 5 --n-anchors 2
+"""
+from __future__ import annotations
+
+import argparse
+import gc
+import sys
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import matplotlib
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+MODELS_ROOT = Path(__file__).resolve().parent
+CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
+if str(CODE_6.resolve()) not in sys.path:
+    sys.path.insert(0, str(CODE_6))
+
+import evaluate_conditional as ec  # noqa: E402
+import eval_model as em  # noqa: E402
+
+
+# =============================================================================
+# 1. SUMMARY STATISTICS (P(k) and log-N_HI PDF, per map)
+# =============================================================================
+
+def per_map_log_pk(
+    imgs: np.ndarray,
+    box_size: float = 25.0,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """log P(k) for a batch of [0,1] maps. Returns (dk, valid_mask, log_pks).
+
+    log_pks has shape (N, valid_mask.sum()).
+    """
+    dk, pks = em.per_map_power_spectra_log(imgs, box_size)
+    valid = dk > 0
+    return dk, valid, np.log(pks[:, valid] + 1e-30)
+
+
+def per_map_log_pdf(
+    imgs: np.ndarray,
+    log_nhi_min: float = 14.0,
+    log_nhi_max: float = 22.0,
+    n_bins: int = 100,
+) -> Tuple[np.ndarray, np.ndarray]:
+    """log column-density PDF for a batch of [0,1] maps.
+
+    Returns (bin_centers, log_pdfs). log_pdfs has shape (N, n_bins-1).
+    """
+    imgs01 = np.clip(imgs, 0.0, 1.0)
+    bin_edges = np.linspace(log_nhi_min, log_nhi_max, n_bins)
+    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
+    pdfs = []
+    for img in imgs01:
+        vals = log_nhi_min + (log_nhi_max - log_nhi_min) * img.reshape(-1)
+        hist, _ = np.histogram(vals, bins=bin_edges, density=True)
+        pdfs.append(hist)
+    return bin_centers, np.log(np.stack(pdfs) + 1e-30)
+
+
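Note on `per_map_log_pdf`: each pixel is rescaled from [0, 1] onto the `[log_nhi_min, log_nhi_max]` axis and histogrammed with `n_bins` edges, which gives `n_bins - 1` bins. The sketch below reproduces that step on a random toy map instead of a CAMELS field; the helper name `toy_log_pdf` is hypothetical.

```python
import numpy as np


def toy_log_pdf(img, lo=14.0, hi=22.0, n_bins=100):
    """Histogram one [0, 1] map on the log N_HI axis; returns (centers, log_pdf)."""
    edges = np.linspace(lo, hi, n_bins)          # n_bins edges -> n_bins - 1 bins
    centers = 0.5 * (edges[:-1] + edges[1:])
    vals = lo + (hi - lo) * np.clip(img, 0.0, 1.0).reshape(-1)
    hist, _ = np.histogram(vals, bins=edges, density=True)
    return centers, np.log(hist + 1e-30)         # empty bins become log(1e-30)


rng = np.random.default_rng(0)
centers, log_pdf = toy_log_pdf(rng.random((64, 64)))

# With density=True the histogram integrates to one over the binned range.
bin_width = centers[1] - centers[0]
total_mass = float(np.sum(np.exp(log_pdf)) * bin_width)
```

Empty bins are floored at log(1e-30), roughly -69. A bin that is populated in one map and empty in another therefore contributes a very large squared difference to any downstream mean-squared comparison.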
+# =============================================================================
+# 2. CALIBRATION OF sigma_pk AND sigma_pdf
+# =============================================================================
+
+def calibrate_summary_sigmas(
+    model: torch.nn.Module,
+    images_val: np.ndarray,
+    labels_val: np.ndarray,
+    lab_mean: np.ndarray,
+    lab_std: np.ndarray,
+    normalize: bool,
+    device: torch.device,
+    box_size: float = 25.0,
+    ddim_steps: int = 50,
+    n_pairs: int = 30,
+    seed: int = 0,
+) -> Tuple[float, float]:
+    """Estimate sigma_pk and sigma_pdf from DDPM aleatoric scatter.
+
+    For each of n_pairs validation labels, draw two independent DDPM samples
+    and measure std(summary_a - summary_b) / sqrt(2). The median over pairs
+    gives a robust per-draw noise scale.
+    """
+    rng = np.random.default_rng(seed)
+    n_val = min(n_pairs, len(labels_val))
+    idx = rng.choice(len(labels_val), size=n_val, replace=False)
+    H, W = int(images_val.shape[-2]), int(images_val.shape[-1])
+
+    sig_pk: List[float] = []
+    sig_pdf: List[float] = []
+
+    for i in idx:
+        lab_pair = np.repeat(labels_val[i:i + 1], 2, axis=0).astype(np.float32)
+        pair = em.sample_batch(
+            model, lab_pair, lab_mean, lab_std, normalize,
+            H, W, device, ddim_steps, False,
+        )
+        _, _, lpk = per_map_log_pk(pair, box_size)
+        _, lpdf = per_map_log_pdf(pair)
+        sig_pk.append(float(np.std(lpk[0] - lpk[1]) / np.sqrt(2.0)))
+        sig_pdf.append(float(np.std(lpdf[0] - lpdf[1]) / np.sqrt(2.0)))
+
+    s_pk = max(float(np.median(sig_pk)), 0.01)
+    s_pdf = max(float(np.median(sig_pdf)), 0.01)
+    print(f" sigma_pk median over {n_val} pairs = {s_pk:.4f}")
+    print(f" sigma_pdf median over {n_val} pairs = {s_pdf:.4f}")
+    return s_pk, s_pdf
+
+
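Note on the pair-difference estimator: if two draws share the same mean and carry independent noise of variance sigma^2, their difference has variance 2 sigma^2, which is why the standard deviation is divided by sqrt(2). The sketch below checks this on synthetic Gaussian noise. The sine curve stands in for a noise-free summary and is purely an assumption; no DDPM sampling is involved.

```python
import numpy as np

rng = np.random.default_rng(0)
true_sigma = 0.3
n_bins, n_pairs = 200, 30
signal = np.sin(np.linspace(0.0, 3.0, n_bins))   # stand-in for a noise-free summary

estimates = []
for _ in range(n_pairs):
    a = signal + rng.normal(0.0, true_sigma, n_bins)
    b = signal + rng.normal(0.0, true_sigma, n_bins)
    # The shared signal cancels in a - b; Var(a - b) = 2 * sigma^2.
    estimates.append(np.std(a - b) / np.sqrt(2.0))

# Median over pairs, as in calibrate_summary_sigmas, to damp outlier pairs.
sigma_hat = float(np.median(estimates))
```

The estimate measures only draw-to-draw scatter at a fixed label. Any systematic offset between emulated and true summaries is not captured by it.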
+# =============================================================================
+# 3. JOINT LOG-LIKELIHOOD (P(k) + PDF, averaged over DDPM stochasticity)
+# =============================================================================
+
+def joint_log_likelihood(
+    obs: np.ndarray,
+    full_labels: np.ndarray,
+    lab_mean: np.ndarray,
+    lab_std: np.ndarray,
+    normalize: bool,
+    model: torch.nn.Module,
+    device: torch.device,
+    H: int,
+    W: int,
+    box_size: float,
+    ddim_steps: int,
+    batch_sz: int,
+    n_pk_samples: int,
+    sigma_pk: float,
+    sigma_pdf: float,
+) -> np.ndarray:
+    """Average log P(k) and log PDF over n_pk_samples DDPM draws per grid pt.
+
+    Returns log L = log L_pk + log L_pdf at each grid point (shape (ngrid,)).
+    """
+    # Observed summaries
+    _, valid_pk, log_pk_obs = per_map_log_pk(obs[np.newaxis], box_size)
+    log_pk_obs = log_pk_obs[0]  # (n_valid_pk,)
+    _, log_pdf_obs = per_map_log_pdf(obs[np.newaxis])
+    log_pdf_obs = log_pdf_obs[0]  # (n_pdf_bins,)
+
+    ngrid = full_labels.shape[0]
+    sum_lpk = np.zeros((ngrid, log_pk_obs.size), dtype=np.float64)
+    sum_lpdf = np.zeros((ngrid, log_pdf_obs.size), dtype=np.float64)
+
+    for _s in range(n_pk_samples):
+        for j0 in range(0, ngrid, batch_sz):
+            chunk = full_labels[j0: j0 + batch_sz]
+            imgs = em.sample_batch(
+                model, chunk, lab_mean, lab_std, normalize,
+                H, W, device, ddim_steps, False,
+            )
+            _, _, lpk = per_map_log_pk(imgs, box_size)
+            _, lpdf = per_map_log_pdf(imgs)
+            sum_lpk[j0: j0 + len(chunk)] += lpk
+            sum_lpdf[j0: j0 + len(chunk)] += lpdf
+
+    mean_lpk = sum_lpk / n_pk_samples
+    mean_lpdf = sum_lpdf / n_pk_samples
+
+    mse_pk = np.mean((log_pk_obs[None, :] - mean_lpk) ** 2, axis=1)
+    mse_pdf = np.mean((log_pdf_obs[None, :] - mean_lpdf) ** 2, axis=1)
+
+    return -mse_pk / (2.0 * sigma_pk ** 2) - mse_pdf / (2.0 * sigma_pdf ** 2)
+
+
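Note on the return value of `joint_log_likelihood`: the final expression can be checked on small hand-made summaries without running the emulator. The sketch below applies the same algebra to three hypothetical grid points. The function name `toy_joint_loglike` and all input values are illustrative assumptions.

```python
import numpy as np


def toy_joint_loglike(obs_pk, obs_pdf, model_pk, model_pdf, sigma_pk, sigma_pdf):
    """Same algebra as the return statement above, on precomputed summaries."""
    mse_pk = np.mean((obs_pk[None, :] - model_pk) ** 2, axis=1)
    mse_pdf = np.mean((obs_pdf[None, :] - model_pdf) ** 2, axis=1)
    return -mse_pk / (2.0 * sigma_pk ** 2) - mse_pdf / (2.0 * sigma_pdf ** 2)


obs_pk = np.array([1.0, 2.0, 3.0])
obs_pdf = np.array([0.5, 0.5])

# Three hypothetical grid points: exact match, small offset, large offset.
model_pk = np.stack([obs_pk, obs_pk + 0.1, obs_pk + 1.0])
model_pdf = np.stack([obs_pdf, obs_pdf + 0.1, obs_pdf + 1.0])

log_l = toy_joint_loglike(obs_pk, obs_pdf, model_pk, model_pdf, 0.2, 0.2)
```

Because the squared residuals are averaged over bins rather than summed, each term equals the usual independent-Gaussian chi-squared divided by the number of bins. The resulting posterior is therefore broader than a sum-based likelihood with the same sigma would give.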
+# =============================================================================
+# 4. GRIDS AND PER-MODEL POSTERIORS
+# =============================================================================
+
+def cosmo_grid_axes(
+    labels_ref: np.ndarray, grid: int, pad_frac: float = 0.02,
+) -> Tuple[np.ndarray, np.ndarray]:
+    lo0, hi0 = float(labels_ref[:, 0].min()), float(labels_ref[:, 0].max())
+    lo1, hi1 = float(labels_ref[:, 1].min()), float(labels_ref[:, 1].max())
+    p0 = pad_frac * (hi0 - lo0 + 1e-12)
+    p1 = pad_frac * (hi1 - lo1 + 1e-12)
+    om_ax = np.linspace(lo0 - p0, hi0 + p0, grid, dtype=np.float32)
+    s8_ax = np.linspace(lo1 - p1, hi1 + p1, grid, dtype=np.float32)
+    return om_ax, s8_ax
+
+
+def posterior_ddpm2(
+    obs: np.ndarray,
+    labels_ref: np.ndarray,
+    lab_mean: np.ndarray, lab_std: np.ndarray,
+    normalize: bool,
+    model: torch.nn.Module, device: torch.device,
+    grid: int, batch_sz: int, ddim_steps: int,
+    n_pk_samples: int,
+    sigma_pk: float, sigma_pdf: float,
+    box_size: float = 25.0,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    om_ax, s8_ax = cosmo_grid_axes(labels_ref, grid)
+    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
+    full = np.stack([OM.ravel(), S8.ravel()], axis=1).astype(np.float32)
+    H, W = int(obs.shape[-2]), int(obs.shape[-1])
+
+    log_w = joint_log_likelihood(
+        obs, full, lab_mean, lab_std, normalize, model, device,
+        H, W, box_size, ddim_steps, batch_sz,
+        n_pk_samples, sigma_pk, sigma_pdf,
+    )
+    log_w -= log_w.max()
+    w = np.exp(log_w).reshape(grid, grid)
+    w /= w.sum()
+    return w, OM, S8, om_ax, s8_ax
+
+
+def posterior_ddpm6(
+    obs: np.ndarray,
+    labels_ref: np.ndarray,
+    lab_mean: np.ndarray, lab_std: np.ndarray,
+    normalize: bool,
+    model: torch.nn.Module, device: torch.device,
+    lo_tail: np.ndarray, hi_tail: np.ndarray,
+    grid: int, batch_sz: int, ddim_steps: int,
+    n_pk_samples: int, n_marg_samples: int,
+    sigma_pk: float, sigma_pdf: float,
+    box_size: float = 25.0,
+    seed: int = 1,
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+    rng = np.random.default_rng(seed)
+    om_ax, s8_ax = cosmo_grid_axes(labels_ref, grid)
+    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
+    ngrid = OM.size
+    H, W = int(obs.shape[-2]), int(obs.shape[-1])
+
+    log_acc = np.full(ngrid, -np.inf, dtype=np.float64)
+    for m in range(n_marg_samples):
+        theta_extra = rng.uniform(lo_tail, hi_tail).astype(np.float32)
+        full_6 = np.zeros((ngrid, 6), dtype=np.float32)
+        full_6[:, 0] = OM.ravel()
+        full_6[:, 1] = S8.ravel()
+        full_6[:, 2:6] = theta_extra
+        log_w = joint_log_likelihood(
+            obs, full_6, lab_mean, lab_std, normalize, model, device,
+            H, W, box_size, ddim_steps, batch_sz,
+            n_pk_samples, sigma_pk, sigma_pdf,
+        )
+        log_acc = np.logaddexp(log_acc, log_w)
+        if (m + 1) % 5 == 0 or (m + 1) == n_marg_samples:
+            print(f" DDPM-6 marg draw {m + 1}/{n_marg_samples}")
+
+    log_acc -= np.log(n_marg_samples)
+    log_acc -= log_acc.max()
+    w = np.exp(log_acc).reshape(grid, grid)
+    w /= w.sum()
+    return w, OM, S8, om_ax, s8_ax
+
+
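Note on the marginalisation in `posterior_ddpm6`: the running `np.logaddexp` followed by subtracting `log(n_marg_samples)` is a log-mean-exp, which averages likelihoods over nuisance draws without leaving log space. The sketch below uses made-up log-likelihood values, chosen very negative on purpose, to show that it matches a shifted reference while naive exponentiation underflows.

```python
import numpy as np

rng = np.random.default_rng(1)
n_marg, n_grid = 20, 9

# Hypothetical per-draw log-likelihoods, very negative so a naive exp() underflows.
log_w = rng.normal(-800.0, 5.0, size=(n_marg, n_grid))

log_acc = np.full(n_grid, -np.inf)
for m in range(n_marg):
    log_acc = np.logaddexp(log_acc, log_w[m])    # running log-sum-exp
log_acc -= np.log(n_marg)                        # log of the Monte Carlo mean

post = np.exp(log_acc - log_acc.max())
post /= post.sum()

# Reference: shift by the global maximum before exponentiating, then average.
ref = np.exp(log_w - log_w.max()).mean(axis=0)
ref /= ref.sum()

# Naive route: every term underflows to zero in float64.
naive = np.exp(log_w).mean(axis=0)
```

In the committed function each draw of `theta_extra` is a plain `rng.uniform` sample between the training-label minima and maxima, shared by every grid point in that pass. The accumulated weights are therefore a Monte Carlo estimate whose noise shrinks as `n_marg_samples` grows.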
+# =============================================================================
+# 5. POSTERIOR DIAGNOSTICS
+# =============================================================================
+
+def credible_levels(
+    w: np.ndarray, levels: Tuple[float, ...] = (0.68, 0.95),
+) -> List[float]:
+    sorted_w = np.sort(w.ravel())[::-1]
+    cumsum = np.cumsum(sorted_w)
+    out = []
+    total = float(sorted_w.sum())
+    for L in levels:
+        idx = int(np.searchsorted(cumsum, L * total))
+        idx = min(idx, len(sorted_w) - 1)
+        out.append(float(sorted_w[idx]))
+    return out
+
+
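Note on `credible_levels`: the function sorts cell weights in descending order and returns the weight at which the cumulative mass first reaches each level, so the region `w >= threshold` is a highest-density credible region. The sketch below restates that logic under the hypothetical name `hpd_thresholds` and checks the enclosed mass on an isotropic Gaussian grid, which is an assumed test case rather than a real posterior.

```python
import numpy as np


def hpd_thresholds(w, levels=(0.68, 0.95)):
    """Density thresholds whose super-level sets enclose the requested mass."""
    sorted_w = np.sort(w.ravel())[::-1]
    cumsum = np.cumsum(sorted_w)
    out = []
    for level in levels:
        idx = int(np.searchsorted(cumsum, level * cumsum[-1]))
        idx = min(idx, sorted_w.size - 1)
        out.append(float(sorted_w[idx]))
    return out


# Assumed test case: unit Gaussian on a 201 x 201 grid, normalised to sum to one.
x = np.linspace(-4.0, 4.0, 201)
X, Y = np.meshgrid(x, x, indexing="ij")
w = np.exp(-0.5 * (X ** 2 + Y ** 2))
w /= w.sum()

thr68, thr95 = hpd_thresholds(w)
mass68 = float(w[w >= thr68].sum())
mass95 = float(w[w >= thr95].sum())
```

The 95% threshold is the lower density of the two. This matches the increasing order `[thr95, thr68]` that `plot_panel` passes to `ax.contour`.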
+def posterior_summary(
+    w: np.ndarray, OM: np.ndarray, S8: np.ndarray,
+) -> Dict[str, float]:
+    w = w / w.sum()
+    mom = float((w * OM).sum())
+    ms8 = float((w * S8).sum())
+    sm = float(np.sqrt((w * (OM - mom) ** 2).sum()))
+    ss = float(np.sqrt((w * (S8 - ms8) ** 2).sum()))
+    S8_map = S8 * (OM / 0.3) ** 0.5
+    mS8 = float((w * S8_map).sum())
+    sS8 = float(np.sqrt((w * (S8_map - mS8) ** 2).sum()))
+    n_eff = float(1.0 / (w.ravel() ** 2).sum())
+    return dict(
+        om_mean=mom, om_std=sm, s8_mean=ms8, s8_std=ss,
+        S8_mean=mS8, S8_std=sS8, n_eff=n_eff,
+    )
+
+
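Note on `posterior_summary`: in that function the array named `S8` holds the sigma_8 grid, while `S8_map` holds the derived parameter S_8 = sigma_8 * sqrt(Omega_m / 0.3). The value `n_eff = 1 / sum(w^2)` counts how many grid cells effectively carry the posterior mass. The sketch below evaluates both quantities on two limiting weight maps; the grid ranges are arbitrary assumptions chosen for round numbers.

```python
import numpy as np

om = np.linspace(0.1, 0.5, 41)
s8 = np.linspace(0.6, 1.0, 41)
OM, S8 = np.meshgrid(om, s8, indexing="ij")

# Uniform weights: every grid cell counts equally, so n_eff equals the cell count.
w_flat = np.full(OM.shape, 1.0 / OM.size)
n_eff_flat = 1.0 / float((w_flat ** 2).sum())

# All mass on one cell (Omega_m = 0.3, sigma_8 = 0.8): n_eff collapses to 1.
w_peak = np.zeros(OM.shape)
w_peak[20, 20] = 1.0
n_eff_peak = 1.0 / float((w_peak ** 2).sum())

# Derived parameter S_8 = sigma_8 * sqrt(Omega_m / 0.3), averaged under the weights.
S8_derived = S8 * (OM / 0.3) ** 0.5
S8_at_peak = float((w_peak * S8_derived).sum())
```

A small `n_eff` on a real posterior suggests that the grid is too coarse to resolve the peak, in which case the quoted means and standard deviations are dominated by a handful of cells.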
+# =============================================================================
+# 6. PLOTTING
+# =============================================================================
+
+def plot_panel(
+    ax,
+    w: np.ndarray, OM: np.ndarray, S8: np.ndarray,
+    true_om: float, true_s8: float,
+    title: str,
+    summary: Dict[str, float],
+    cmap: str,
+) -> None:
+    cf = ax.contourf(OM, S8, w, levels=16, cmap=cmap)
+    plt.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)
+
+    try:
+        thr68, thr95 = credible_levels(w, levels=(0.68, 0.95))
+        ax.contour(
+            OM, S8, w, levels=[thr95, thr68],
+            colors=["#e07b39", "#c0392b"],
+            linewidths=[1.0, 1.6],
+            linestyles=["--", "-"],
+        )
+    except Exception:
+        pass
+
+    ax.scatter(summary["om_mean"], summary["s8_mean"],
+               s=70, c="k", marker="+", zorder=8, lw=2.0)
+    ax.scatter(true_om, true_s8,
+               s=70, c="r", marker="x", zorder=8, lw=2.0)
+
+    info = (
+        f"$\\Omega_m$ = {summary['om_mean']:.3f} $\\pm$ {summary['om_std']:.3f}\n"
+        f"$\\sigma_8$ = {summary['s8_mean']:.3f} $\\pm$ {summary['s8_std']:.3f}\n"
+        f"$S_8$ = {summary['S8_mean']:.3f} $\\pm$ {summary['S8_std']:.3f}\n"
+        f"$n_{{\\rm eff}}$ = {summary['n_eff']:.0f}"
+    )
+    ax.text(
+        0.02, 0.02, info, transform=ax.transAxes, fontsize=7,
+        va="bottom", color="#111",
+        bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.78),
+    )
+
+    ax.set_xlabel(r"$\Omega_m$", fontsize=9)
+    ax.set_ylabel(r"$\sigma_8$", fontsize=9)
+    ax.set_title(title, fontsize=8.5)
+
+
+def make_comparison_figure(
+    per_anchor: List[Dict],
+    suptitle: str,
+    out_path: Path,
+) -> None:
+    n = len(per_anchor)
+    fig, axes = plt.subplots(n, 2, figsize=(11, 4.5 * n), squeeze=False)
+    for i, p in enumerate(per_anchor):
+        plot_panel(
+            axes[i][0],
+            p["w2"], p["OM"], p["S8"],
+            p["true_om"], p["true_s8"],
+            (
+                f"DDPM-2 | ix={p['ix']} | "
+                rf"$\Omega_m$={p['true_om']:.3f}, $\sigma_8$={p['true_s8']:.3f}"
+            ),
+            p["summ2"], cmap="Blues",
+        )
+        plot_panel(
+            axes[i][1],
+            p["w6"], p["OM"], p["S8"],
+            p["true_om"], p["true_s8"],
+            (
+                f"DDPM-6 (MC marg.) | ix={p['ix']} | "
+                rf"$\Omega_m$={p['true_om']:.3f}, $\sigma_8$={p['true_s8']:.3f}"
+            ),
+            p["summ6"], cmap="Greens",
+        )
+
+    from matplotlib.lines import Line2D
+    legend_h = [
+        Line2D([], [], marker="x", color="r", ls="", ms=8, label="True"),
+        Line2D([], [], marker="+", color="k", ls="", ms=8, label="Posterior mean"),
+        Line2D([], [], color="#c0392b", lw=1.6, label="68% CR"),
+        Line2D([], [], color="#e07b39", lw=1.0, ls="--", label="95% CR"),
+    ]
+    fig.legend(
+        handles=legend_h, loc="upper center", ncol=4, fontsize=8.5,
+        bbox_to_anchor=(0.5, 0.998), frameon=False,
+    )
+    plt.suptitle(suptitle, fontsize=11, y=0.992)
+    plt.tight_layout(rect=(0, 0, 1, 0.97))
+    fig.savefig(out_path, dpi=160, bbox_inches="tight")
+    plt.close(fig)
+    print(f"Saved -> {out_path}")
+
+
+# =============================================================================
+# 7. MODEL LOADING
+# =============================================================================
+
+def load_model(
+    args_json: Path, ckpt: Path, device: torch.device,
+) -> Tuple[torch.nn.Module, Dict]:
+    cfg = ec.load_training_config(str(args_json))
+    model = ec.build_model(cfg, device)
+    ec.load_checkpoint(model, str(ckpt), device)
+    model.eval()
+    return model, cfg
+
+
+def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
+    for name in ("train_labels_LH.npy", "train_labels_LH_2.npy"):
+        p = data_dir / name
+        if p.is_file():
+            L = np.load(p)
+            if L.shape[1] < 6:
+                raise ValueError(
+                    f"Expected >=6 label columns in {p}, got {L.shape}"
+                )
+            return (
+                L[:, 2:6].min(axis=0).astype(np.float32),
+                L[:, 2:6].max(axis=0).astype(np.float32),
+            )
+    raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}")
+
+
+# =============================================================================
+# 8. CLI
+# =============================================================================
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser(
+        description=(
+            "DDPM-2 vs DDPM-6 corrected posteriors on (Omega_m, sigma_8) "
+            "with a JOINT P(k) + log-N_HI PDF Gaussian likelihood. "
+            "sigma_pk and sigma_pdf are calibrated from DDPM aleatoric noise."
+        ),
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    p.add_argument(
+        "--output-dir", type=Path,
+        default=MODELS_ROOT / "ddpm_posterior_compare_pk_pdf_out",
+    )
+    p.add_argument(
+        "--data-2param", type=Path,
+        default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"),
+    )
+    p.add_argument(
+        "--data-6param", type=Path,
+        default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"),
+    )
+    p.add_argument(
+        "--bundle-2param", type=Path,
+        default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
+    )
+    p.add_argument(
+        "--bundle-6param", type=Path,
+        default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
+    )
+    p.add_argument("--split", default="test", choices=["train", "val", "test"])
+    p.add_argument(
+        "--n-anchors", type=int, default=4,
+        help="Number of evenly-spaced test fields. 4 = compact figure; 6 = wider sweep.",
+    )
+    p.add_argument("--grid", type=int, default=30)
+    p.add_argument("--ddim-steps", type=int, default=50)
+    p.add_argument("--batch-size", type=int, default=8)
+    p.add_argument(
+        "--n-pk-samples", type=int, default=8,
+        help="DDPM draws averaged per grid point. Variance ~ 1/N. >=8 recommended.",
+    )
+    p.add_argument(
+        "--n-marg-samples", type=int, default=20,
+        help="MC draws over astrophysical params for DDPM-6. >=20 recommended.",
+    )
+    p.add_argument(
+        "--n-calib-pairs", type=int, default=30,
+        help="Validation pairs used to calibrate sigma_pk and sigma_pdf.",
+    )
+    p.add_argument(
+        "--sigma-pk", type=float, default=None,
+        help="Override calibrated sigma_pk (applied to BOTH models).",
+    )
+    p.add_argument(
+        "--sigma-pdf", type=float, default=None,
+        help="Override calibrated sigma_pdf (applied to BOTH models).",
+    )
+    p.add_argument("--seed", type=int, default=0)
+    return p.parse_args()
+
+
+# =============================================================================
+# 9. MAIN
+# =============================================================================
+
+def main() -> None:
+    args = parse_args()
+    out_dir = args.output_dir.resolve()
+    out_dir.mkdir(parents=True, exist_ok=True)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    print(f"Device : {device}")
+    print(f"Output : {out_dir}")
+    print()
+
+    # ── Data ─────────────────────────────────────────────────────────────
+    imgs2, labs2 = ec.load_split(args.data_2param, args.split)
+    mean2, std2 = ec.load_label_stats(args.data_2param)
+    imgs6, labs6 = ec.load_split(args.data_6param, args.split)
+    mean6, std6 = ec.load_label_stats(args.data_6param)
+    lo_tail, hi_tail = tail_lhs_bounds(args.data_6param)
+
+    print(f"DDPM-2 {args.split}: {len(labs2)} maps label_dim={labs2.shape[1]}")
+    print(f"DDPM-6 {args.split}: {len(labs6)} maps label_dim={labs6.shape[1]}")
+    print(f" LHS tails (dims 2-5): min={lo_tail} max={hi_tail}")
+
+    # Pick anchors evenly spread across the test split. Ranges of the two
+    # test splits may differ in length, so cap to the smaller.
+    n_pool = min(len(labs2), len(labs6))
+    n_anchors = max(2, min(args.n_anchors, n_pool))
+    anchor_ix = np.linspace(0, n_pool - 1, n_anchors, dtype=int)
+    print(f"Anchor indices: {anchor_ix.tolist()}")
+    print()
+
+    # ── Models ───────────────────────────────────────────────────────────
+    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
+    aj2 = args.bundle_2param / "args.json"
+    ck6 = args.bundle_6param / "best_model.pt"
+    aj6 = args.bundle_6param / "args.json"
+
+    print("Loading DDPM-2 ...")
+    model2, cfg2 = load_model(aj2, ck2, device)
+    norm2 = bool(cfg2.get("normalize_labels", True))
+    print("Loading DDPM-6 ...")
+    model6, cfg6 = load_model(aj6, ck6, device)
+    norm6 = bool(cfg6.get("normalize_labels", True))
+    print()
+
+    # ── Calibrate sigma_pk and sigma_pdf per model ───────────────────────
+    if args.sigma_pk is not None and args.sigma_pdf is not None:
+        s2_pk = s6_pk = float(args.sigma_pk)
+        s2_pdf = s6_pdf = float(args.sigma_pdf)
+        print(f"sigma_pk overridden = {s2_pk:.4f}")
+        print(f"sigma_pdf overridden = {s2_pdf:.4f}")
+    else:
+        print("Calibrating DDPM-2 noise scales ...")
+        v2_imgs, v2_labs = ec.load_split(args.data_2param, "val")
+        s2_pk, s2_pdf = calibrate_summary_sigmas(
+            model2, v2_imgs, v2_labs, mean2, std2, norm2, device,
+            ddim_steps=args.ddim_steps,
+            n_pairs=args.n_calib_pairs, seed=args.seed,
+        )
+        del v2_imgs, v2_labs
+        gc.collect()
+
+        print("Calibrating DDPM-6 noise scales ...")
+        v6_imgs, v6_labs = ec.load_split(args.data_6param, "val")
+        s6_pk, s6_pdf = calibrate_summary_sigmas(
+            model6, v6_imgs, v6_labs, mean6, std6, norm6, device,
+            ddim_steps=args.ddim_steps,
+            n_pairs=args.n_calib_pairs, seed=args.seed + 7,
+        )
+        del v6_imgs, v6_labs
+        gc.collect()
+
+        if args.sigma_pk is not None:
+            s2_pk = s6_pk = float(args.sigma_pk)
+        if args.sigma_pdf is not None:
+            s2_pdf = s6_pdf = float(args.sigma_pdf)
+
+    print()
+    print(f"DDPM-2: sigma_pk={s2_pk:.4f} sigma_pdf={s2_pdf:.4f}")
+    print(f"DDPM-6: sigma_pk={s6_pk:.4f} sigma_pdf={s6_pdf:.4f}")
602
+ print()
603
+
604
+ # ── Per-anchor inference ─────────────────────────────────────────────
605
+ per_anchor: List[Dict] = []
606
+ for k, ix in enumerate(anchor_ix):
607
+ ix = int(ix)
608
+ obs2 = imgs2[ix]
609
+ obs6 = imgs6[ix]
610
+ # DDPM-2 labels carry (Omega_m, sigma_8) directly; DDPM-6 labels[ix]
611
+ # may differ in row ordering between the two splits, so we report the
612
+ # truth from each respective split when forming panel titles.
613
+ true_om = float(labs2[ix, 0])
614
+ true_s8 = float(labs2[ix, 1])
615
+ true_om6 = float(labs6[ix, 0])
616
+ true_s86 = float(labs6[ix, 1])
617
+
618
+ print(
619
+ f"[{k + 1}/{n_anchors}] ix={ix} "
620
+ f"DDPM-2 truth (Om={true_om:.3f}, s8={true_s8:.3f}) "
621
+ f"DDPM-6 truth (Om={true_om6:.3f}, s8={true_s86:.3f})"
622
+ )
623
+
624
+ print(" DDPM-2 posterior ...")
625
+ w2, OM, S8, om_ax, s8_ax = posterior_ddpm2(
626
+ obs2, labs2, mean2, std2, norm2, model2, device,
627
+ args.grid, args.batch_size, args.ddim_steps,
628
+ args.n_pk_samples, s2_pk, s2_pdf,
629
+ )
630
+ summ2 = posterior_summary(w2, OM, S8)
631
+ print(
632
+ f" Om={summ2['om_mean']:.3f}+/-{summ2['om_std']:.3f} "
633
+ f"s8={summ2['s8_mean']:.3f}+/-{summ2['s8_std']:.3f} "
634
+ f"S8={summ2['S8_mean']:.3f}+/-{summ2['S8_std']:.3f} "
635
+ f"n_eff={summ2['n_eff']:.0f}"
636
+ )
637
+
638
+ print(" DDPM-6 posterior (MC marginalisation) ...")
639
+ w6, OM6, S8_6, om_ax6, s8_ax6 = posterior_ddpm6(
640
+ obs6, labs6, mean6, std6, norm6, model6, device,
641
+ lo_tail, hi_tail,
642
+ args.grid, args.batch_size, args.ddim_steps,
643
+ args.n_pk_samples, args.n_marg_samples, s6_pk, s6_pdf,
644
+ seed=args.seed + 1 + k,
645
+ )
646
+ summ6 = posterior_summary(w6, OM6, S8_6)
647
+ print(
648
+ f" Om={summ6['om_mean']:.3f}+/-{summ6['om_std']:.3f} "
649
+ f"s8={summ6['s8_mean']:.3f}+/-{summ6['s8_std']:.3f} "
650
+ f"S8={summ6['S8_mean']:.3f}+/-{summ6['S8_std']:.3f} "
651
+ f"n_eff={summ6['n_eff']:.0f}"
652
+ )
653
+
654
+ per_anchor.append(dict(
655
+ ix=ix,
656
+ true_om=true_om, true_s8=true_s8,
657
+ w2=w2, w6=w6, OM=OM, S8=S8,
658
+ summ2=summ2, summ6=summ6,
659
+ ))
660
+ np.savez_compressed(
661
+ out_dir / f"posterior_ix{ix:03d}.npz",
662
+ w_ddpm2=w2, w_ddpm6=w6,
663
+ om_ax=om_ax, s8_ax=s8_ax, om_ax_ddpm6=om_ax6, s8_ax_ddpm6=s8_ax6,
664
+ true_om_ddpm2=true_om, true_s8_ddpm2=true_s8,
665
+ true_om_ddpm6=true_om6, true_s8_ddpm6=true_s86,
666
+ sigma_pk_ddpm2=s2_pk, sigma_pdf_ddpm2=s2_pdf,
667
+ sigma_pk_ddpm6=s6_pk, sigma_pdf_ddpm6=s6_pdf,
668
+ n_pk_samples=args.n_pk_samples,
669
+ n_marg_samples=args.n_marg_samples,
670
+ ddim_steps=args.ddim_steps,
671
+ )
672
+ gc.collect()
673
+ if torch.cuda.is_available():
674
+ torch.cuda.empty_cache()
675
+
676
+ # ── Comparison figure ────────────────────────────────────────────────
677
+ suptitle = (
678
+ r"Corrected joint-likelihood posteriors $(\Omega_m,\,\sigma_8)$"
679
+ " · "
680
+ r"$P(k)$ + $\log\,N_{\rm HI}$ PDF"
681
+ "\n"
682
+ f"DDPM-2 $\\sigma_{{Pk}}$={s2_pk:.3f}, $\\sigma_{{PDF}}$={s2_pdf:.3f}"
683
+ " · "
684
+ f"DDPM-6 (MC marg., $N_{{\\rm marg}}$={args.n_marg_samples}) "
685
+ f"$\\sigma_{{Pk}}$={s6_pk:.3f}, $\\sigma_{{PDF}}$={s6_pdf:.3f}"
686
+ " · "
687
+ f"grid {args.grid}², {args.n_pk_samples} DDPM draws/pt, "
688
+ f"DDIM {args.ddim_steps} steps"
689
+ )
690
+ make_comparison_figure(
691
+ per_anchor, suptitle,
692
+ out_dir / "compare_posterior_ddpm2_vs_ddpm6.png",
693
+ )
694
+
695
+ print(f"\nDone. Artifacts in {out_dir}")
696
+
697
+
698
+ if __name__ == "__main__":
699
+ main()
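
The per-anchor `posterior_ix*.npz` files written by `main()` above can be read back and summarised offline. The following is a minimal sketch, not part of the commit: the helper name `summarise_grid_posterior` is hypothetical, the weights are assumed to be indexed as `weights[i, j]` for `(om_ax[i], s8_ax[j])` (the `indexing="ij"` convention, which should be checked against `posterior_ddpm2`), and the file it loads is a synthetic stand-in rather than a real result.

```python
import tempfile
from pathlib import Path

import numpy as np


def summarise_grid_posterior(weights, om_ax, s8_ax):
    """Mean, std and effective sample size of a gridded (Omega_m, sigma_8) posterior.

    Hypothetical helper. Assumes weights[i, j] belongs to (om_ax[i], s8_ax[j]).
    """
    w = np.asarray(weights, dtype=np.float64)
    w = w / w.sum()  # renormalise defensively
    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
    om_mean = float((w * OM).sum())
    s8_mean = float((w * S8).sum())
    om_std = float(np.sqrt((w * (OM - om_mean) ** 2).sum()))
    s8_std = float(np.sqrt((w * (S8 - s8_mean) ** 2).sum()))
    n_eff = float(1.0 / (w ** 2).sum())  # Kish effective number of grid points
    return dict(om_mean=om_mean, om_std=om_std,
                s8_mean=s8_mean, s8_std=s8_std, n_eff=n_eff)


# Synthetic stand-in for one saved anchor file (illustrative values only).
om_ax = np.linspace(0.1, 0.5, 21)
s8_ax = np.linspace(0.6, 1.0, 21)
OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
w = np.exp(-0.5 * (((OM - 0.3) / 0.05) ** 2 + ((S8 - 0.8) / 0.05) ** 2))
w /= w.sum()

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "posterior_ix000.npz"
    # Same key names as the np.savez_compressed call in main().
    np.savez_compressed(path, w_ddpm2=w, om_ax=om_ax, s8_ax=s8_ax)
    with np.load(path) as data:
        summ = summarise_grid_posterior(
            data["w_ddpm2"], data["om_ax"], data["s8_ax"]
        )

print(summ)
```

For a real run, replace the synthetic block with `np.load(out_dir / "posterior_ix000.npz")` and pass `w_ddpm6` in place of `w_ddpm2` to summarise the 6-parameter model.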
cross_model/ddpm_posterior_corrected.py ADDED
@@ -0,0 +1,867 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Corrected Surrogate P(k) Bayesian Posteriors on (Omega_m, sigma_8).
4
+
5
+ ═══════════════════════════════════════════════════════════════════
6
+ THEORETICAL CORRECTIONS FROM REVIEW
7
+ ═══════════════════════════════════════════════════════════════════
8
+ 1. STOCHASTIC LIKELIHOOD AVERAGING
9
+ - Single DDPM sample per grid point → n_ddpm_samples averaged draws
10
+ - Reduces emulator scatter (std) by 1/sqrt(N), suppressing spurious multimodality
11
+
12
+ 2. CALIBRATED LIKELIHOOD NOISE SCALE
13
+ - Hard-coded sigma=0.25 → sigma_pk estimated from validation-set scatter
14
+ - Uses aleatoric uncertainty of the DDPM emulator at fixed theta
15
+
16
+ 3. PROPER MC MARGINALISATION (DDPM-6)
17
+ - Fixed dims 2-5 to extremes → Monte Carlo integral over prior
18
+ - p(Om,s8|d) = integral L(d|Om,s8,theta_extra) pi(theta_extra) dtheta_extra
19
+ - Approximated by uniform draws from the LHS training range
20
+
21
+ 4. HIGHER GRID RESOLUTION
22
+ - 14x14 → 30x30 (900 pts), with optional adaptive refinement
23
+
24
+ 5. VISUALISATION OF PRIOR, LIKELIHOOD, AND POSTERIOR SEPARATELY
25
+ - Shows all three distributions per anchor per model
26
+ - Includes 68%/95% credible contours
27
+ - Includes S8 = sigma_8*(Om/0.3)^0.5 derived parameter
28
+ - Includes effective sample size and posterior predictive check
29
+
30
+ 6. PRIOR-POSTERIOR COMPARISON
31
+ - Explicit uniform prior overlaid on posterior for each anchor
32
+
33
+ USAGE
34
+ ─────
35
+ # Run both models with full corrections
36
+ python ddpm_posterior_corrected.py
37
+
38
+ # DDPM-2 only, faster run
39
+ python ddpm_posterior_corrected.py --ddpm2-only --grid 20 --n-ddpm-samples 4
40
+
41
+ # DDPM-6 only with full marginalisation
42
+ python ddpm_posterior_corrected.py --ddpm6-only --n-marg-samples 30 --n-ddpm-samples 8
43
+ """
44
+
45
+ from __future__ import annotations
46
+
47
+ import argparse
48
+ import gc
49
+ import sys
50
+ import warnings
51
+ from pathlib import Path
52
+ from typing import Dict, List, Optional, Tuple
53
+
54
+ import matplotlib
55
+ matplotlib.use("Agg")
56
+ import matplotlib.gridspec as gridspec
57
+ import matplotlib.pyplot as plt
58
+ import numpy as np
59
+ import torch
60
+
61
+ # Resolve code + data relative to Models/ (same directory as this file)
62
+ MODELS_ROOT = Path(__file__).resolve().parent
63
+ CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
64
+ if str(CODE_6.resolve()) not in sys.path:
65
+ sys.path.insert(0, str(CODE_6))
66
+
67
+ import evaluate_conditional as ec # noqa: E402
68
+ import eval_model as em # noqa: E402
69
+ from figure9_posterior import log_pk_observed # noqa: E402
70
+
71
+
72
+ # ══════════════════════════════════════════════════════════════════════════════
73
+ # 1. DATA UTILITIES
74
+ # ══════════════════════════════════════════════════════════════════════════════
75
+
76
+ def _train_label_path(data_dir: Path) -> Path:
77
+ for name in ("train_labels_LH.npy", "train_labels_LH_2.npy"):
78
+ p = data_dir / name
79
+ if p.is_file():
80
+ return p
81
+ raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}")
82
+
83
+
84
+ def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
85
+ """
86
+ Min/max of the LHS training distribution for label dims 2-5.
87
+ These define the UNIFORM PRIOR for the astrophysical nuisance parameters.
88
+ """
89
+ L = np.load(_train_label_path(data_dir))
90
+ if L.shape[1] < 6:
91
+ raise ValueError(f"Expected >= 6 label columns, got {L.shape}")
92
+ lo = L[:, 2:6].min(axis=0).astype(np.float32)
93
+ hi = L[:, 2:6].max(axis=0).astype(np.float32)
94
+ return lo, hi
95
+
96
+
97
+ def cosmo_prior_bounds(labels_split: np.ndarray) -> Tuple[float, float, float, float]:
98
+ """Return (om_lo, om_hi, s8_lo, s8_hi) as the min/max of columns 0-1 of `labels_split`."""
99
+ om_lo = float(labels_split[:, 0].min())
100
+ om_hi = float(labels_split[:, 0].max())
101
+ s8_lo = float(labels_split[:, 1].min())
102
+ s8_hi = float(labels_split[:, 1].max())
103
+ return om_lo, om_hi, s8_lo, s8_hi
104
+
105
+
106
+ # ══════════════════════════════════════════════════════════════════════════════
107
+ # 2. LIKELIHOOD CALIBRATION
108
+ # ══════════════════════════════════════════════════════════════════════════════
109
+
110
+ def calibrate_sigma_pk(
111
+ data_dir: Path,
112
+ model: torch.nn.Module,
113
+ lab_mean: np.ndarray,
114
+ lab_std: np.ndarray,
115
+ normalize: bool,
116
+ H: int,
117
+ W: int,
118
+ device: torch.device,
119
+ ddim_steps: int,
120
+ n_pairs: int = 60,
121
+ rng_seed: int = 0,
122
+ ) -> float:
123
+ """
124
+ Estimate sigma_pk = aleatoric std of log P(k) from DDPM stochasticity.
125
+
126
+ For n_pairs randomly chosen validation-set labels, draw 2 DDPM samples
127
+ and measure std(log Pk_a - log Pk_b) / sqrt(2). The median over pairs
128
+ gives a robust noise floor for the likelihood.
129
+
130
+ This replaces the hard-coded sigma = 0.25 with a data-driven estimate.
131
+ """
132
+ print(f" Calibrating sigma_pk from {n_pairs} validation-set pairs ...")
133
+ images_val, labels_val = ec.load_split(data_dir, "val")
134
+ n_val = len(labels_val)
135
+ rng = np.random.default_rng(rng_seed)
136
+ idx = rng.choice(n_val, size=min(n_pairs, n_val), replace=False)
137
+
138
+ sigmas = []
139
+ for i in idx:
140
+ lab = np.repeat(labels_val[i : i + 1], 2, axis=0).astype(np.float32)
141
+ pair = em.sample_batch(
142
+ model, lab, lab_mean, lab_std, normalize,
143
+ H, W, device, ddim_steps, False,
144
+ )
145
+ _, pk_pair = em.per_map_power_spectra_log(pair, 25.0)
146
+ valid = pk_pair[0] > 0
147
+ log_pk_pair = np.log(pk_pair[:, valid] + 1e-30)
148
+ diff_std = float(np.std(log_pk_pair[0] - log_pk_pair[1])) / np.sqrt(2.0)
149
+ sigmas.append(diff_std)
150
+
151
+ sigma_cal = float(np.median(sigmas))
152
+ sigma_cal = max(sigma_cal, 0.05) # lower-bound: prevent degenerate likelihoods
153
+ print(f" Calibrated sigma_pk = {sigma_cal:.4f} (was hard-coded 0.25)")
154
+ return sigma_cal
155
+
156
+
157
+ # ══════════════════════════════════════════════════════════════════════════════
158
+ # 3. GRID CONSTRUCTION
159
+ # ══════════════════════════════════════════════════════════════════════════════
160
+
161
+ def build_cosmo_axes(
162
+ labels_split: np.ndarray, grid: int, pad_frac: float = 0.02
163
+ ) -> Tuple[np.ndarray, np.ndarray]:
164
+ """Return (om_ax, s8_ax): `grid` equally spaced points spanning the label range, padded by `pad_frac` per side."""
165
+ om_lo, om_hi, s8_lo, s8_hi = cosmo_prior_bounds(labels_split)
166
+ pad0 = pad_frac * (om_hi - om_lo + 1e-12)
167
+ pad1 = pad_frac * (s8_hi - s8_lo + 1e-12)
168
+ om_ax = np.linspace(om_lo - pad0, om_hi + pad0, grid)
169
+ s8_ax = np.linspace(s8_lo - pad1, s8_hi + pad1, grid)
170
+ return om_ax, s8_ax
171
+
172
+
173
+ def build_full_grid(
174
+ om_ax: np.ndarray,
175
+ s8_ax: np.ndarray,
176
+ tail: Optional[np.ndarray] = None,
177
+ lab_dim: int = 2,
178
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
179
+ """
180
+ Build a flattened (grid^2, lab_dim) array for a sweep over (Om, s8).
181
+ If tail is provided (shape (4,)), dims 2-5 are fixed to those values.
182
+ Returns (full_labels, OM_meshgrid, S8_meshgrid).
183
+ """
184
+ OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
185
+ ngrid = OM.size
186
+ out = np.zeros((ngrid, lab_dim), dtype=np.float32)
187
+ out[:, 0] = OM.ravel()
188
+ out[:, 1] = S8.ravel()
189
+ if tail is not None:
190
+ assert tail.shape == (4,), f"Expected tail shape (4,), got {tail.shape}"
191
+ out[:, 2:6] = tail[np.newaxis, :]
192
+ return out, OM, S8
193
+
194
+
195
+ # ══════════════════════════════════════════════════════════════════════════════
196
+ # 4. LOG-LIKELIHOOD (AVERAGED OVER DDPM STOCHASTICITY)
197
+ # ══════════════════════════════════════════════════════════════════════════════
198
+
199
+ def compute_log_likelihood(
200
+ obs: np.ndarray,
201
+ full: np.ndarray,
202
+ lab_mean: np.ndarray,
203
+ lab_std: np.ndarray,
204
+ normalize: bool,
205
+ model: torch.nn.Module,
206
+ H: int,
207
+ W: int,
208
+ device: torch.device,
209
+ ddim_steps: int,
210
+ batch_sz: int,
211
+ n_ddpm_samples: int,
212
+ sigma_pk: float,
213
+ ) -> np.ndarray:
214
+ """
215
+ Return log-likelihood array of shape (ngrid,), where each entry is:
216
+
217
+ ln L(d | theta_i) ≈ - (1 / (2*sigma^2)) * mean_k[ (log Pk_obs - mean_j[log Pk_gen_j])^2 ]
218
+
219
+ The mean over j=1..n_ddpm_samples suppresses DDPM stochasticity.
220
+
221
+ Parameters
222
+ ----------
223
+ sigma_pk : calibrated noise scale on log P(k) — NOT hard-coded.
224
+ n_ddpm_samples : number of independent DDPM draws to average per grid pt.
225
+ """
226
+ ngrid = full.shape[0]
227
+
228
+ npix = int(obs.shape[-1])
229
+ dl = 25.0 / npix
230
+ logf = em.images01_to_log_nhi(obs)
231
+ dk, _ = ec.PowerSpectrum(logf, N=npix, dl=dl)
232
+ valid = dk > 0
233
+ # log_pk_observed returns values only at dk > 0 (same length as valid.sum())
234
+ log_pd = log_pk_observed(obs, 25.0, dk)
235
+
236
+ accumulated = []
237
+ for _s in range(n_ddpm_samples):
238
+ sample_log_pk = []
239
+ for j0 in range(0, ngrid, batch_sz):
240
+ chunk = full[j0: j0 + batch_sz]
241
+ imgs = em.sample_batch(
242
+ model, chunk, lab_mean, lab_std, normalize,
243
+ H, W, device, ddim_steps, False,
244
+ )
245
+ _, pkc = em.per_map_power_spectra_log(imgs, 25.0)
246
+ sample_log_pk.append(np.log(pkc[:, valid] + 1e-30))
247
+ accumulated.append(np.concatenate(sample_log_pk, axis=0)) # (ngrid, nk)
248
+
249
+ mean_log_pg = np.mean(accumulated, axis=0) # (ngrid, nk)
250
+
251
+ mse = np.mean((log_pd[np.newaxis, :] - mean_log_pg) ** 2, axis=1) # (ngrid,)
252
+ log_like = -mse / (2.0 * sigma_pk ** 2)
253
+ return log_like
254
+
255
+
256
+ # ══════════════════════════════════════════════════════════════════════════════
257
+ # 5. MARGINALISATION OVER ASTROPHYSICAL PARAMETERS (DDPM-6)
258
+ # ══════════════════════════════════════════════════════════════════════════════
259
+
260
+ def marginal_log_likelihood_ddpm6(
261
+ obs: np.ndarray,
262
+ om_ax: np.ndarray,
263
+ s8_ax: np.ndarray,
264
+ lo_tail: np.ndarray,
265
+ hi_tail: np.ndarray,
266
+ lab_mean: np.ndarray,
267
+ lab_std: np.ndarray,
268
+ normalize: bool,
269
+ model: torch.nn.Module,
270
+ H: int,
271
+ W: int,
272
+ device: torch.device,
273
+ ddim_steps: int,
274
+ batch_sz: int,
275
+ n_ddpm_samples: int,
276
+ n_marg_samples: int,
277
+ sigma_pk: float,
278
+ rng_seed: int = 42,
279
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
280
+ """
281
+ Correct 2D marginal log-likelihood over (Om, s8) for the 6-param DDPM.
282
+
283
+ Implements Monte Carlo integration:
284
+ ln L_marg(d | Om, s8) = log[ (1/N) Σ_i L(d | Om, s8, θ_extra^i) ]
285
+ = log-sum-exp [ ln L(d | Om, s8, θ_extra^i) ] - ln N
286
+
287
+ where θ_extra^i ~ Uniform(lo_tail, hi_tail) [the prior over dims 2-5]
288
+
289
+ Returns (log_like_marginal, OM, S8) with log_like_marginal shaped (ngrid,).
290
+ """
291
+ rng = np.random.default_rng(rng_seed)
292
+ theta_extras = rng.uniform(
293
+ lo_tail, hi_tail, size=(n_marg_samples, 4)
294
+ ).astype(np.float32)
295
+
296
+ ngrid = len(om_ax) * len(s8_ax)
297
+ OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
298
+
299
+ log_like_accumulator = np.full(ngrid, -np.inf, dtype=np.float64)
300
+
301
+ for mc_i, theta_extra in enumerate(theta_extras):
302
+ print(f" MC marginalisation draw {mc_i+1}/{n_marg_samples} ...", end="\r", flush=True)
303
+ full, _, _ = build_full_grid(om_ax, s8_ax, tail=theta_extra, lab_dim=6)
304
+ lnL_i = compute_log_likelihood(
305
+ obs, full, lab_mean, lab_std, normalize, model,
306
+ H, W, device, ddim_steps, batch_sz, n_ddpm_samples, sigma_pk,
307
+ )
308
+ log_like_accumulator = np.logaddexp(log_like_accumulator, lnL_i)
309
+
310
+ log_like_marginal = log_like_accumulator - np.log(n_marg_samples)
311
+ print(flush=True)
312
+ return log_like_marginal, OM, S8
313
+
314
+
315
+ # ══════════════════════════════════════════════════════════════════════════════
316
+ # 6. POSTERIOR COMPUTATION & DIAGNOSTICS
317
+ # ══════════════════════════════════════════════════════════════════════════════
318
+
319
+ def log_like_to_posterior(
320
+ log_like: np.ndarray,
321
+ grid: int,
322
+ ) -> Tuple[np.ndarray, float]:
323
+ """Flat prior on grid → normalized weights + n_eff."""
324
+ log_like = log_like - log_like.max() # numerical stability
325
+ weights = np.exp(log_like).reshape(grid, grid)
326
+ weights /= weights.sum()
327
+ n_eff = 1.0 / float(np.sum(weights ** 2))
328
+ return weights, n_eff
329
+
330
+
331
+ def credible_contour_levels(
332
+ weights: np.ndarray, credible_levels=(0.68, 0.95)
333
+ ) -> List[float]:
334
+ """Highest-density-style thresholds containing `level` of total mass."""
335
+ flat = weights.ravel()
336
+ total_mass = float(flat.sum())
337
+ sorted_desc = np.sort(flat)[::-1]
338
+ cumsum = np.cumsum(sorted_desc)
339
+ thresholds = []
340
+ for cl in credible_levels:
341
+ target = cl * total_mass
342
+ idx = int(np.searchsorted(cumsum, target))
343
+ idx = min(idx, len(sorted_desc) - 1)
344
+ thresholds.append(float(sorted_desc[idx]))
345
+ return thresholds
346
+
347
+
348
+ def posterior_summary(weights: np.ndarray, OM: np.ndarray, S8: np.ndarray) -> Dict:
349
+ """Posterior mean, std, and S8 = sigma_8*(Om/0.3)^0.5 statistics."""
350
+ mom = float((weights * OM).sum())
351
+ ms8 = float((weights * S8).sum())
352
+ var_om = float((weights * (OM - mom)**2).sum())
353
+ var_s8 = float((weights * (S8 - ms8)**2).sum())
354
+ S8_map = S8 * (OM / 0.3) ** 0.5
355
+ mS8 = float((weights * S8_map).sum())
356
+ var_S8 = float((weights * (S8_map - mS8)**2).sum())
357
+ return dict(
358
+ om_mean=mom, om_std=np.sqrt(var_om),
359
+ s8_mean=ms8, s8_std=np.sqrt(var_s8),
360
+ S8_mean=mS8, S8_std=np.sqrt(var_S8),
361
+ )
362
+
363
+
364
+ # ══════════════════════════════════════════════════════════════════════════════
365
+ # 7. POSTERIOR PREDICTIVE CHECK
366
+ # ══════════════════════════════════════════════════════════════════════════════
367
+
368
+ def posterior_predictive_check(
369
+ obs: np.ndarray,
370
+ weights: np.ndarray,
371
+ OM: np.ndarray,
372
+ S8: np.ndarray,
373
+ model: torch.nn.Module,
374
+ lab_mean: np.ndarray,
375
+ lab_std: np.ndarray,
376
+ normalize: bool,
377
+ H: int,
378
+ W: int,
379
+ device: torch.device,
380
+ ddim_steps: int,
381
+ n_draws: int = 30,
382
+ rng_seed: int = 7,
383
+ lab_dim: int = 2,
384
+ lo_tail: Optional[np.ndarray] = None,
385
+ hi_tail: Optional[np.ndarray] = None,
386
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
387
+ """
388
+ Draw parameters from the posterior and generate maps.
389
+ Returns (k_valid, log_pk_obs_valid, ppc_lo68, ppc_hi68, ppc_lo95, ppc_hi95),
390
+ all aligned to dk > 0 bins only.
391
+ """
392
+ rng = np.random.default_rng(rng_seed)
393
+ flat_w = weights.ravel()
394
+ flat_om = OM.ravel()
395
+ flat_s8 = S8.ravel()
396
+ idx = rng.choice(len(flat_w), size=n_draws, replace=True, p=flat_w)
397
+
398
+ npix = int(obs.shape[-1])
399
+ dl = 25.0 / npix
400
+ logf = em.images01_to_log_nhi(obs)
401
+ dk, _ = ec.PowerSpectrum(logf, N=npix, dl=dl)
402
+ valid = dk > 0
403
+ log_pk_obs = log_pk_observed(obs, 25.0, dk)
404
+
405
+ ppc_log_pks = []
406
+ for i in idx:
407
+ if lab_dim == 2:
408
+ theta = np.array([[flat_om[i], flat_s8[i]]], dtype=np.float32)
409
+ else:
410
+ assert lo_tail is not None and hi_tail is not None
411
+ te = rng.uniform(lo_tail, hi_tail).astype(np.float32)
412
+ theta = np.array([[flat_om[i], flat_s8[i], *te]], dtype=np.float32)
413
+ img = em.sample_batch(
414
+ model, theta, lab_mean, lab_std, normalize,
415
+ H, W, device, ddim_steps, False,
416
+ )
417
+ _, pkc = em.per_map_power_spectra_log(img, 25.0)
418
+ ppc_log_pks.append(np.log(pkc[0, valid] + 1e-30))
419
+
420
+ ppc_arr = np.array(ppc_log_pks) # (n_draws, n_valid)
421
+ return (
422
+ dk[valid],
423
+ log_pk_obs,
424
+ np.percentile(ppc_arr, 16, axis=0),
425
+ np.percentile(ppc_arr, 84, axis=0),
426
+ np.percentile(ppc_arr, 2.5, axis=0),
427
+ np.percentile(ppc_arr, 97.5, axis=0),
428
+ )
429
+
430
+
431
+ # ══════════════════════════════════════════════════════════════════════════════
432
+ # 8. VISUALISATION
433
+ # ══════════════════════════════════════════════════════════════════════════════
434
+
435
+ CMAP_PRIOR = "Greys"
436
+ CMAP_LIKE = "YlOrRd"
437
+ CMAP_POST = "Blues"
438
+
439
+
440
+ def _uniform_prior(OM: np.ndarray, S8: np.ndarray) -> np.ndarray:
441
+ """Flat prior: uniform weight over the grid."""
442
+ prior = np.ones_like(OM)
443
+ return prior / prior.sum()
444
+
445
+
446
+ def plot_prior_likelihood_posterior_panel(
447
+ fig,
448
+ gs_row,
449
+ weights: np.ndarray,
450
+ log_like: np.ndarray,
451
+ OM: np.ndarray,
452
+ S8: np.ndarray,
453
+ true_om: float,
454
+ true_s8: float,
455
+ anchor_ix: int,
456
+ summary: Dict,
457
+ n_eff: float,
458
+ title_suffix: str = "",
459
+ ) -> None:
460
+ """
461
+ Plot three side-by-side panels for one anchor:
462
+ [0] Uniform prior [1] Normalised likelihood [2] Posterior
463
+ Each panel shows 68% / 95% credible contours where applicable.
464
+ """
465
+ grid = weights.shape[0]
466
+ prior = _uniform_prior(OM, S8)
467
+
468
+ like = np.exp(log_like - log_like.max()).reshape(grid, grid)
469
+ like /= like.sum()
470
+
471
+ panels = [
472
+ (prior, CMAP_PRIOR, r"Uniform Prior $\pi(\Omega_m, \sigma_8)$"),
473
+ (like, CMAP_LIKE, r"Normalised Likelihood $\mathcal{L}(\mathbf{d}|\theta)$"),
474
+ (weights, CMAP_POST, r"Posterior $p(\theta|\mathbf{d})$"),
475
+ ]
476
+
477
+ for col, (Wmap, cmap, label) in enumerate(panels):
478
+ ax = fig.add_subplot(gs_row[col])
479
+ cf = ax.contourf(OM, S8, Wmap, levels=14, cmap=cmap)
480
+ plt.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)
481
+
482
+ if col > 0:
483
+ lvls = sorted(credible_contour_levels(Wmap))  # contour() requires increasing levels: [95%, 68%]
484
+ try:
485
+ ax.contour(OM, S8, Wmap, levels=lvls,
486
+ colors=["cyan", "white"],
487
+ linewidths=[0.6, 1.0],
488
+ linestyles=["dashed", "solid"])
489
+ except Exception:
490
+ pass
491
+
492
+ ax.scatter(true_om, true_s8, s=70, c="red",
493
+ marker="x", zorder=8, linewidths=2.0, label="True")
494
+ if col == 2:
495
+ ax.scatter(summary["om_mean"], summary["s8_mean"],
496
+ s=80, c="black", marker="+",
497
+ zorder=8, linewidths=2.0, label="Post. mean")
498
+ ax.legend(fontsize=7, loc="upper right")
499
+
500
+ if col == 2:
501
+ txt = (
502
+ f"$\\Omega_m$: {summary['om_mean']:.3f} ± {summary['om_std']:.3f}\n"
503
+ f"$\\sigma_8$: {summary['s8_mean']:.3f} ± {summary['s8_std']:.3f}\n"
504
+ f"$S_8$: {summary['S8_mean']:.3f} ± {summary['S8_std']:.3f}\n"
505
+ f"$n_\\mathrm{{eff}}$: {n_eff:.0f}"
506
+ )
507
+ ax.text(0.02, 0.02, txt, transform=ax.transAxes,
508
+ fontsize=6.5, va="bottom", color="#111",
509
+ bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.75))
510
+
511
+ ax.set_xlabel(r"$\Omega_m$", fontsize=9)
512
+ ax.set_ylabel(r"$\sigma_8$", fontsize=9)
513
+ ax.set_title(
514
+ f"ix={anchor_ix} | {label}{title_suffix}",
515
+ fontsize=8, pad=4
516
+ )
517
+
518
+
519
+ def plot_ppc_panel(
520
+ ax,
521
+ k_valid: np.ndarray,
522
+ log_pk_obs: np.ndarray,
523
+ ppc_lo68: np.ndarray,
524
+ ppc_hi68: np.ndarray,
525
+ ppc_lo95: np.ndarray,
526
+ ppc_hi95: np.ndarray,
527
+ anchor_ix: int,
528
+ ) -> None:
529
+ """Posterior predictive check on log P(k) at valid k bins."""
530
+ ax.fill_between(k_valid, ppc_lo95, ppc_hi95,
531
+ alpha=0.18, color="steelblue", label="95% PPC")
532
+ ax.fill_between(k_valid, ppc_lo68, ppc_hi68,
533
+ alpha=0.40, color="steelblue", label="68% PPC")
534
+ ax.plot(k_valid, log_pk_obs, "k-", lw=1.8, label="Observed")
535
+ ax.set_xscale("log")
536
+ ax.set_xlabel(r"$k$ [h/Mpc]", fontsize=8)
537
+ ax.set_ylabel(r"$\log P_\mathrm{HI}(k)$", fontsize=8)
538
+ ax.set_title(f"Posterior Predictive Check (ix={anchor_ix})", fontsize=8)
539
+ ax.legend(fontsize=7)
540
+ ax.grid(alpha=0.25)
541
+
542
+
543
+
544
+ def plot_s8_marginal(
545
+ ax, weights: np.ndarray, OM: np.ndarray, S8: np.ndarray,
546
+ true_om: float, true_s8: float, anchor_ix: int,
547
+ ) -> None:
548
+ """1D marginal of the derived S8 = sigma_8*(Om/0.3)^0.5 parameter."""
549
+ S8_map = S8 * (OM / 0.3) ** 0.5
550
+ true_S8 = true_s8 * (true_om / 0.3) ** 0.5
551
+ S8_flat = S8_map.ravel()
552
+ w_flat = weights.ravel()
553
+ s8_bins = np.linspace(S8_flat.min(), S8_flat.max(), 40)
554
+ hist = np.zeros(len(s8_bins) - 1)
555
+ for i, (lo, hi) in enumerate(zip(s8_bins[:-1], s8_bins[1:])):
556
+ mask = (S8_flat >= lo) & ((S8_flat < hi) if i < len(hist) - 1 else (S8_flat <= hi))  # last bin keeps the max
557
+ hist[i] = w_flat[mask].sum()
558
+ hist /= hist.sum() * (s8_bins[1] - s8_bins[0])
559
+ centers = 0.5 * (s8_bins[:-1] + s8_bins[1:])
560
+ ax.bar(centers, hist, width=(s8_bins[1] - s8_bins[0]),
561
+ color="steelblue", alpha=0.7, label="Posterior")
562
+ ax.axvline(true_S8, color="red", lw=1.5, ls="--", label=f"True $S_8$={true_S8:.3f}")
563
+ ax.set_xlabel(r"$S_8 = \sigma_8(\Omega_m/0.3)^{0.5}$", fontsize=8)
564
+ ax.set_ylabel("Prob. density", fontsize=8)
565
+ ax.set_title(f"$S_8$ marginal (ix={anchor_ix})", fontsize=8)
566
+ ax.legend(fontsize=7)
567
+ ax.grid(alpha=0.25)
568
+
569
+
570
+ # ══════════════════════════════════════════════════════════════════════════════
571
+ # 9. PER-MODEL RUNNER
572
+ # ══════════════════════════════════════════════════════════════════════════════
573
+
574
+ def run_model(
575
+ out_dir: Path,
576
+ *,
577
+ model_name: str,
578
+ model: torch.nn.Module,
579
+ cfg: Dict,
580
+ images: np.ndarray,
581
+ labels: np.ndarray,
582
+ lab_mean: np.ndarray,
583
+ lab_std: np.ndarray,
584
+ anchor_ix: np.ndarray,
585
+ grid: int,
586
+ ddim_steps: int,
587
+ batch_sz: int,
588
+ n_ddpm_samples: int,
589
+ n_marg_samples: int,
590
+ sigma_pk: float,
591
+ lo_tail: Optional[np.ndarray] = None,
592
+ hi_tail: Optional[np.ndarray] = None,
593
+ do_ppc: bool = True,
594
+ ) -> None:
595
+ normalize = bool(cfg.get("normalize_labels", True))
596
+ H, W = int(images.shape[-2]), int(images.shape[-1])
597
+ lab_dim = 6 if model_name == "DDPM-6" else 2
598
+ n_anchors = len(anchor_ix.ravel())
599
+ is_6param = lab_dim == 6
600
+ device = next(model.parameters()).device
601
+
602
+ fig1 = plt.figure(figsize=(17, 4.8 * n_anchors))
603
+ outer_gs = gridspec.GridSpec(n_anchors, 1, figure=fig1, hspace=0.55)
604
+
605
+ fig2, axes2 = plt.subplots(
606
+ n_anchors, 2, figsize=(12, 4.5 * n_anchors), squeeze=False
607
+ )
608
+
609
+ om_ax, s8_ax = build_cosmo_axes(labels, grid)
610
+
611
+ for k, ix in enumerate(anchor_ix.ravel()):
612
+ obs = images[ix]
613
+ lab_t = labels[ix].astype(np.float32)
614
+ true_om, true_s8 = float(lab_t[0]), float(lab_t[1])
615
+
616
+ print(f"\n[{model_name}] Anchor ix={ix} "
617
+ f"(Ωm={true_om:.3f}, σ8={true_s8:.3f})")
618
+
619
+ if is_6param:
620
+ print(" MC marginalisation over dims 2-5 ...")
621
+ log_like, OM, S8 = marginal_log_likelihood_ddpm6(
622
+ obs, om_ax, s8_ax, lo_tail, hi_tail,
623
+ lab_mean, lab_std, normalize, model,
624
+ H, W, device=device,
625
+ ddim_steps=ddim_steps, batch_sz=batch_sz,
626
+ n_ddpm_samples=n_ddpm_samples,
627
+ n_marg_samples=n_marg_samples,
628
+ sigma_pk=sigma_pk,
629
+ )
630
+ else:
631
+ print(f" Computing log-likelihood ({n_ddpm_samples} DDPM draws per point) ...")
632
+ full, OM, S8 = build_full_grid(om_ax, s8_ax, tail=None, lab_dim=2)
633
+ log_like = compute_log_likelihood(
634
+ obs, full, lab_mean, lab_std, normalize, model,
635
+ H, W, device, ddim_steps, batch_sz, n_ddpm_samples, sigma_pk,
636
+ )
637
+
638
+ weights, n_eff = log_like_to_posterior(log_like, grid)
639
+ summary = posterior_summary(weights, OM, S8)
640
+
641
+ print(f" n_eff = {n_eff:.1f} / {grid**2} grid points")
642
+ print(f" Ωm posterior: {summary['om_mean']:.3f} ± {summary['om_std']:.3f} "
643
+ f"(true: {true_om:.3f})")
644
+ print(f" σ8 posterior: {summary['s8_mean']:.3f} ± {summary['s8_std']:.3f} "
645
+ f"(true: {true_s8:.3f})")
646
+ print(f" S8 posterior: {summary['S8_mean']:.3f} ± {summary['S8_std']:.3f}")
647
+
648
+ if n_eff < 20:
649
+ warnings.warn(
650
+ f"n_eff={n_eff:.1f} is very low for ix={ix}. "
651
+ "Increase n_ddpm_samples or grid resolution.",
652
+ stacklevel=2,
653
+ )
654
+
655
+ inner_gs = gridspec.GridSpecFromSubplotSpec(
656
+ 1, 3, subplot_spec=outer_gs[k], wspace=0.38
657
+ )
658
+ marg_note = " (marginalised)" if is_6param else ""
659
+ plot_prior_likelihood_posterior_panel(
660
+ fig1, inner_gs, weights, log_like.reshape(grid, grid),
661
+ OM, S8, true_om, true_s8, ix, summary, n_eff,
662
+ title_suffix=marg_note,
663
+ )
664
+
665
+ if do_ppc:
666
+ dk_v, log_pk_obs, plo68, phi68, plo95, phi95 = posterior_predictive_check(
667
+ obs, weights, OM, S8, model,
668
+ lab_mean, lab_std, normalize,
669
+ H, W, device, ddim_steps,
670
+ n_draws=20, lab_dim=lab_dim,
671
+ lo_tail=lo_tail, hi_tail=hi_tail,
672
+ )
673
+ plot_ppc_panel(
674
+ axes2[k, 0], dk_v, log_pk_obs, plo68, phi68, plo95, phi95, ix,
675
+ )
676
+ else:
677
+ axes2[k, 0].axis("off")
678
+ axes2[k, 0].text(
679
+ 0.5, 0.5, "PPC disabled",
680
+ ha="center", va="center", transform=axes2[k, 0].transAxes,
681
+ )
682
+
683
+ plot_s8_marginal(axes2[k, 1], weights, OM, S8, true_om, true_s8, ix)
684
+
685
+ fig1.suptitle(
686
+ f"{model_name} — Prior · Likelihood · Posterior on "
687
+ r"$(\Omega_m,\,\sigma_8)$ — six CAMELS anchors"
688
+ f"\n[sigma_pk={sigma_pk:.3f}, n_ddpm={n_ddpm_samples}, grid={grid}×{grid}"
689
+ + (f", n_marg={n_marg_samples}]" if is_6param else "]"),
690
+ fontsize=12, y=1.001,
691
+ )
692
+ p1 = out_dir / f"{model_name.replace('-', '')}_prior_likelihood_posterior.png"
693
+ fig1.savefig(p1, dpi=160, bbox_inches="tight")
694
+ plt.close(fig1)
695
+ print(f"\nSaved → {p1}")
696
+
697
+ fig2.suptitle(
698
+ f"{model_name} — Posterior Predictive Checks & $S_8$ Marginals",
699
+ fontsize=12, y=1.002,
700
+ )
701
+ fig2.tight_layout()
702
+ p2 = out_dir / f"{model_name.replace('-', '')}_ppc_s8.png"
703
+ fig2.savefig(p2, dpi=160, bbox_inches="tight")
704
+ plt.close(fig2)
705
+ print(f"Saved → {p2}")
706
+
707
+
708
+ def load_model(args_json: Path, ckpt: Path, device: torch.device):
709
+ cfg = ec.load_training_config(str(args_json))
710
+ model = ec.build_model(cfg, device)
711
+ ec.load_checkpoint(model, str(ckpt), device)
712
+ model.eval()
713
+ return model, cfg
714
+
715
+
716
+ def parse_args() -> argparse.Namespace:
717
+ p = argparse.ArgumentParser(
718
+ description=(
719
+ "Corrected Bayesian posteriors on (Ωm, σ8): averaged DDPM likelihood, "
720
+ "calibrated sigma_pk, DDPM-6 MC marginalisation, prior/likelihood/posterior plots."
721
+ ),
722
+ formatter_class=argparse.RawDescriptionHelpFormatter,
723
+ )
724
+ p.add_argument("--output-dir", type=Path,
725
+ default=MODELS_ROOT / "ddpm_posterior_corrected_fullviz_out")
726
+ p.add_argument("--data-2param", type=Path,
727
+ default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"))
728
+ p.add_argument("--data-6param", type=Path,
729
+ default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"))
730
+ p.add_argument("--bundle-2param", type=Path,
731
+ default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200")
732
+ p.add_argument("--bundle-6param", type=Path,
733
+ default=MODELS_ROOT / "notebook_model_weights" / "6param_best")
734
+ p.add_argument("--split", type=str, default="test",
735
+ choices=["train", "val", "test"])
736
+ p.add_argument("--grid", type=int, default=30)
737
+ p.add_argument("--ddim-steps", type=int, default=50)
738
+ p.add_argument("--batch-size", type=int, default=8)
739
+ p.add_argument("--n-ddpm-samples", type=int, default=8)
740
+ p.add_argument("--n-marg-samples", type=int, default=20)
741
+ p.add_argument("--sigma-pk", type=float, default=None, help="Fixed sigma_pk used for both models; if omitted, sigma_pk is calibrated per model.")
742
+ p.add_argument("--n-calib-pairs", type=int, default=60)
743
+ p.add_argument("--no-ppc", action="store_true")
744
+ p.add_argument("--ddpm2-only", action="store_true")
745
+ p.add_argument("--ddpm6-only", action="store_true")
746
+ return p.parse_args()
747
+
748
+
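The manual check in `main()` for `--ddpm2-only` together with `--ddpm6-only` could also be delegated to `argparse`. The following is a minimal sketch, not the script's own parser: `build_parser` is a hypothetical helper, and only the two flag names are taken from the code above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical reduced parser: only the two model-selection flags."""
    p = argparse.ArgumentParser(description="Model selection flags (sketch).")
    # argparse rejects any command line that sets both flags of one group.
    group = p.add_mutually_exclusive_group()
    group.add_argument("--ddpm2-only", action="store_true")
    group.add_argument("--ddpm6-only", action="store_true")
    return p


args = build_parser().parse_args(["--ddpm6-only"])
print(args.ddpm2_only, args.ddpm6_only)
```

With this layout the error message is produced by `argparse` itself, so the explicit `raise SystemExit(...)` would become redundant.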
749
+ def main() -> None:
750
+ args = parse_args()
751
+
752
+ if args.ddpm2_only and args.ddpm6_only:
753
+ raise SystemExit("Use at most one of --ddpm2-only / --ddpm6-only.")
754
+
755
+ out_dir = Path(args.output_dir).resolve()
756
+ out_dir.mkdir(parents=True, exist_ok=True)
757
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
758
+ print(f"Device: {device}")
759
+ print(f"Output directory: {out_dir}")
760
+
761
+ data2 = Path(args.data_2param)
762
+ data6 = Path(args.data_6param)
763
+
764
+ imgs2, lab2 = ec.load_split(data2, args.split)
765
+ imgs6, lab6 = ec.load_split(data6, args.split)
766
+
767
+ n = min(len(lab2), len(lab6))
768
+ anchor_ix = np.linspace(0, n - 1, num=6, dtype=int)
769
+ print(f"Anchor indices: {anchor_ix.tolist()}")
770
+
771
+ lo_tail, hi_tail = tail_lhs_bounds(data6)
772
+ print(f"LHS tails (dims 2-5): min={lo_tail} max={hi_tail}")
773
+
774
+ mean2, std2 = ec.load_label_stats(data2)
775
+ mean6, std6 = ec.load_label_stats(data6)
776
+
777
+ if not args.ddpm6_only:
778
+ print("\n" + "═" * 60)
779
+ print(">>> DDPM-2 (corrected posteriors, six anchors)")
780
+ print("═" * 60)
781
+ ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
782
+ args_json_2 = args.bundle_2param / "args.json"
783
+ model2, cfg2 = load_model(args_json_2, ck2, device)
784
+
785
+ if args.sigma_pk is not None:
786
+ sigma2 = args.sigma_pk
787
+ print(f" Using user-supplied sigma_pk = {sigma2:.4f}")
788
+ else:
789
+ sigma2 = calibrate_sigma_pk(
790
+ data2, model2, mean2, std2,
791
+ bool(cfg2.get("normalize_labels", True)),
792
+ int(imgs2.shape[-2]), int(imgs2.shape[-1]),
793
+ device, args.ddim_steps, args.n_calib_pairs,
794
+ )
795
+
796
+ run_model(
797
+ out_dir,
798
+ model_name="DDPM-2",
799
+ model=model2,
800
+ cfg=cfg2,
801
+ images=imgs2,
802
+ labels=lab2,
803
+ lab_mean=mean2,
804
+ lab_std=std2,
805
+ anchor_ix=anchor_ix,
806
+ grid=args.grid,
807
+ ddim_steps=args.ddim_steps,
808
+ batch_sz=args.batch_size,
809
+ n_ddpm_samples=args.n_ddpm_samples,
810
+ n_marg_samples=args.n_marg_samples,
811
+ sigma_pk=sigma2,
812
+ do_ppc=not args.no_ppc,
813
+ )
814
+ del model2
815
+ gc.collect()
816
+ if torch.cuda.is_available():
817
+ torch.cuda.empty_cache()
818
+
819
+ if not args.ddpm2_only:
820
+ print("\n" + "═" * 60)
821
+ print(">>> DDPM-6 (corrected posteriors + MC marginalisation, six anchors)")
822
+ print("═" * 60)
823
+ ck6 = args.bundle_6param / "best_model.pt"
824
+ args_json_6 = args.bundle_6param / "args.json"
825
+ model6, cfg6 = load_model(args_json_6, ck6, device)
826
+
827
+ if args.sigma_pk is not None:
828
+ sigma6 = args.sigma_pk
829
+ print(f" Using user-supplied sigma_pk = {sigma6:.4f}")
830
+ else:
831
+ sigma6 = calibrate_sigma_pk(
832
+ data6, model6, mean6, std6,
833
+ bool(cfg6.get("normalize_labels", True)),
834
+ int(imgs6.shape[-2]), int(imgs6.shape[-1]),
835
+ device, args.ddim_steps, args.n_calib_pairs,
836
+ )
837
+
838
+ run_model(
839
+ out_dir,
840
+ model_name="DDPM-6",
841
+ model=model6,
842
+ cfg=cfg6,
843
+ images=imgs6,
844
+ labels=lab6,
845
+ lab_mean=mean6,
846
+ lab_std=std6,
847
+ anchor_ix=anchor_ix,
848
+ grid=args.grid,
849
+ ddim_steps=args.ddim_steps,
850
+ batch_sz=args.batch_size,
851
+ n_ddpm_samples=args.n_ddpm_samples,
852
+ n_marg_samples=args.n_marg_samples,
853
+ sigma_pk=sigma6,
854
+ lo_tail=lo_tail,
855
+ hi_tail=hi_tail,
856
+ do_ppc=not args.no_ppc,
857
+ )
858
+ del model6
859
+ gc.collect()
860
+ if torch.cuda.is_available():
861
+ torch.cuda.empty_cache()
862
+
863
+ print(f"\nAll outputs saved to: {out_dir}")
864
+
865
+
866
+ if __name__ == "__main__":
867
+ main()
cross_model/poster.py ADDED
@@ -0,0 +1,1112 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ ddpm_posterior_six_anchors_corrected.py
4
+ ========================================
5
+ Corrected surrogate P(k) likelihood posteriors on (Omega_m, sigma_8)
6
+ for six CAMELS test anchors.
7
+
8
+ CORRECTIONS OVER THE ORIGINAL SCRIPT
9
+ --------------------------------------
10
+ (1) STOCHASTIC EMULATOR NOISE [was: 1 DDPM sample/grid point → fragmented posteriors]
11
+ Now: average log P(k) over `--n-pk-samples` (default 8) DDPM draws per grid
12
+ point, suppressing emulator variance by ~1/sqrt(N_s).
13
+
14
+ (2) CALIBRATED LIKELIHOOD NOISE SCALE [was: hard-coded sigma=0.25]
15
+ Now: sigma_pk is estimated from the scatter of log P(k) across repeated DDPM
16
+ draws at a sample of validation labels — making the noise scale physically
17
+ meaningful and data-driven.
18
+
19
+ (3) PROPER MARGINALIZATION OVER ASTROPHYSICAL PARAMETERS [was: fix to LHS min/max]
20
+ For DDPM-6, dims 2–5 are now integrated out via Monte Carlo:
21
+ p(Om, s8 | d) ∝ (1/N) Σ_i L(d | Om, s8, θ_extra^i), θ_extra^i ~ Uniform(LHS)
22
+ replacing the incorrect conditional likelihoods p(d | Om, s8, θ_extra = fixed).
23
+
24
+ (4) GRID RESOLUTION [was: 14×14 = 196 points]
25
+ Now: 30×30 = 900 points (configurable via --grid).
26
+
27
+ (5) EFFECTIVE SAMPLE SIZE [was: none]
28
+ n_eff = 1 / Σ w_i^2 is printed for every panel. Values ≪ 30 flag collapse.
29
+
30
+ (6) CREDIBLE CONTOURS [was: raw contourf only]
31
+ Now: 68 % and 95 % posterior mass contours drawn explicitly on each panel.
32
+
33
+ (7) S8 DERIVED PARAMETER [was: absent]
34
+ S8 = sigma_8 * (Omega_m / 0.3)^0.5 reported for the posterior mean.
35
+
36
+ (8) POSTERIOR PREDICTIVE CHECK [was: absent]
37
+ A separate figure shows the 68/95 % posterior-predictive P(k) envelope
38
+ versus the observed P(k) for each anchor — a standard emulator
39
+ validation step.
40
+
41
+ USAGE
42
+ -----
43
+ # Both models, all corrections:
44
+ python ddpm_posterior_six_anchors_corrected.py
45
+
46
+ # DDPM-2 only, fast debug run:
47
+ python ddpm_posterior_six_anchors_corrected.py --ddpm2-only --grid 14 --n-pk-samples 4 --n-marg-samples 1
48
+
49
+ # DDPM-6 only, full quality:
50
+ python ddpm_posterior_six_anchors_corrected.py --ddpm6-only --grid 30 --n-pk-samples 12 --n-marg-samples 30
51
+ """
52
+
53
+ from __future__ import annotations
54
+
55
+ import argparse
56
+ import gc
57
+ import sys
58
+ from pathlib import Path
59
+ from typing import Dict, List, Optional, Tuple
60
+
61
+ import matplotlib
62
+ matplotlib.use("Agg")
63
+ import matplotlib.pyplot as plt
64
+ import matplotlib.ticker as mticker
65
+ import numpy as np
66
+ import torch
67
+
68
+ # ── Path setup ────────────────────────────────────────────────────────────────
69
+ MODELS_ROOT = Path(__file__).resolve().parent
70
+ CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
71
+ if str(CODE_6.resolve()) not in sys.path:
72
+ sys.path.insert(0, str(CODE_6))
73
+
74
+ import evaluate_conditional as ec # noqa: E402
75
+ import eval_model as em # noqa: E402
76
+
77
+
78
+ # ═════════════════════════════════════════════════════════════════════════════
79
+ # § 1 GRID CONSTRUCTION
80
+ # ═════════════════════════════════════════════════════════════════════════════
81
+
82
+ def build_cosmo_grid(
83
+ grid: int,
84
+ om_lo: float, om_hi: float,
85
+ s8_lo: float, s8_hi: float,
86
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
87
+ """
88
+ Build a regular (grid × grid) mesh over (Omega_m, sigma_8).
89
+
90
+ Returns
91
+ -------
92
+ om_ax : 1-D array, shape (grid,)
93
+ s8_ax : 1-D array, shape (grid,)
94
+ grid2 : 2-D array, shape (grid^2, 2), C-order ravel of an "ij" mesh (sigma_8 varies fastest)
95
+ """
96
+ om_ax = np.linspace(om_lo, om_hi, grid, dtype=np.float32)
97
+ s8_ax = np.linspace(s8_lo, s8_hi, grid, dtype=np.float32)
98
+ OG, SG = np.meshgrid(om_ax, s8_ax, indexing="ij")
99
+ grid2 = np.stack([OG.ravel(), SG.ravel()], axis=1).astype(np.float32)
100
+ return om_ax, s8_ax, grid2
101
+
102
+
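The flattening order of the `(Omega_m, sigma_8)` mesh matters because per-point results are later reshaped back to `(grid, grid)`. A small sketch with made-up axis values (the numbers are illustrative assumptions, not CAMELS ranges) shows which axis varies fastest after `ravel()` on an `indexing="ij"` mesh:

```python
import numpy as np

# Hypothetical small axes, chosen only for illustration.
om_ax = np.array([0.1, 0.3, 0.5], dtype=np.float32)
s8_ax = np.array([0.6, 1.0], dtype=np.float32)

OG, SG = np.meshgrid(om_ax, s8_ax, indexing="ij")   # both have shape (3, 2)
grid2 = np.stack([OG.ravel(), SG.ravel()], axis=1)  # shape (6, 2)

# C-order ravel: the last mesh index (sigma_8) changes fastest.
print(grid2[:3])

# Any per-point quantity computed on grid2 reshapes back onto the mesh.
values = grid2[:, 0] + 10.0 * grid2[:, 1]
value_map = values.reshape(len(om_ax), len(s8_ax))
```

Because the same `indexing="ij"` convention is used when rebuilding `OM` and `S8` for plotting, `value_map[i, j]` lines up with `(om_ax[i], s8_ax[j])`.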
103
+ def build_full_grid(
104
+ labels_ref: np.ndarray,
105
+ grid: int,
106
+ tail: Optional[np.ndarray],
107
+ lab_dim: int,
108
+ pad_frac: float = 0.02,
109
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
110
+ """
111
+ Build the full label matrix for the posterior grid.
112
+
113
+ Parameters
114
+ ----------
115
+ labels_ref : reference labels from which (Om, s8) range is inferred
116
+ grid : grid points per axis
117
+ tail : fixed values for dims 2–5 (None for DDPM-2)
118
+ lab_dim : total label dimension (2 or 6)
119
+ pad_frac : fractional padding beyond data range
120
+
121
+ Returns
122
+ -------
123
+ full : (grid^2, lab_dim) float32
124
+ om_ax : (grid,) float32
125
+ s8_ax : (grid,) float32
126
+ """
127
+ lo0, hi0 = float(labels_ref[:, 0].min()), float(labels_ref[:, 0].max())
128
+ lo1, hi1 = float(labels_ref[:, 1].min()), float(labels_ref[:, 1].max())
129
+ p0 = pad_frac * (hi0 - lo0 + 1e-12)
130
+ p1 = pad_frac * (hi1 - lo1 + 1e-12)
131
+
132
+ om_ax, s8_ax, grid2 = build_cosmo_grid(grid, lo0 - p0, hi0 + p0,
133
+ lo1 - p1, hi1 + p1)
134
+ ngrid = grid2.shape[0]
135
+
136
+ full = np.zeros((ngrid, lab_dim), dtype=np.float32)
137
+ full[:, 0] = grid2[:, 0]
138
+ full[:, 1] = grid2[:, 1]
139
+ if tail is not None:
140
+ assert tail.shape == (4,), f"tail must be shape (4,), got {tail.shape}"
141
+ full[:, 2:6] = tail[np.newaxis, :]
142
+
143
+ return full, om_ax, s8_ax
144
+
145
+
146
+ # ═════════════════════════════════════════════════════════════════════════════
147
+ # § 2 LHS BOUNDS
148
+ # ═════════════════════════════════════════════════════════════════════════════
149
+
150
+ def _train_label_path(data_dir: Path) -> Path:
151
+ for name in ("train_labels_LH.npy", "train_labels_LH_2.npy"):
152
+ p = data_dir / name
153
+ if p.is_file():
154
+ return p
155
+ raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}")
156
+
157
+
158
+ def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
159
+ """Min/max of LHS training labels for dims 2–5."""
160
+ L = np.load(_train_label_path(data_dir))
161
+ if L.shape[1] < 6:
162
+ raise ValueError(f"Expected ≥6 label columns, got shape {L.shape}")
163
+ lo = L[:, 2:6].min(axis=0).astype(np.float32)
164
+ hi = L[:, 2:6].max(axis=0).astype(np.float32)
165
+ return lo, hi
166
+
167
+
168
+ # ═════════════════════════════════════════════════════════════════════════════
169
+ # § 3 OBSERVED LOG P(k)
170
+ # ═════════════════════════════════════════════════════════════════════════════
171
+
172
+ def log_pk_observed(
173
+ obs_image: np.ndarray,
174
+ box_size: float = 25.0,
175
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
176
+ """
177
+ Compute the natural-log P(k) of the *observed* HI map, after converting
178
+ from [0,1] pixel scale to log10(N_HI).
179
+
180
+ Returns
181
+ -------
182
+ dk : k-mode array (n_bins,)
183
+ log_pd : log power spectrum of observed map (n_bins,), valid-modes only
184
+ valid : boolean mask selecting non-zero k-modes
185
+ """
186
+ # images01_to_log_nhi expects shape (..., H, W) or (H, W)
187
+ log_nhi = em.images01_to_log_nhi(obs_image[np.newaxis]) # (1, H, W)
188
+ npix = obs_image.shape[-1]
189
+ dl = box_size / npix
190
+ dk, pk = ec.PowerSpectrum(log_nhi[0], N=npix, dl=dl)
191
+ valid = dk > 0
192
+ log_pd = np.log(pk[valid] + 1e-30)
193
+ return dk, log_pd, valid
194
+
195
+
196
+ # ═════════════════════════════════════════════════════════════════════════════
197
+ # § 4 SIGMA_PK CALIBRATION (Correction #2)
198
+ # ═════════════════════════════════════════════════════════════════════════════
199
+
200
+ def calibrate_sigma_pk(
201
+ model: torch.nn.Module,
202
+ images_val: np.ndarray,
203
+ labels_val: np.ndarray,
204
+ lab_mean: np.ndarray,
205
+ lab_std: np.ndarray,
206
+ normalize: bool,
207
+ device: torch.device,
208
+ box_size: float = 25.0,
209
+ ddim_steps: int = 50,
210
+ n_pairs: int = 30,
211
+ seed: int = 0,
212
+ ) -> float:
213
+ """
214
+ Estimate the log-P(k) noise scale from the *aleatoric* variance of the
215
+ DDPM emulator at fixed labels.
216
+
217
+ For n_pairs validation images we draw two independent DDPM samples and
218
+ compute std(log Pk_a - log Pk_b) / sqrt(2), then take the median.
219
+
220
+ This gives a physically motivated sigma_pk that replaces the hard-coded 0.25.
221
+ """
222
+ rng = np.random.default_rng(seed)
223
+ n_val = min(n_pairs, len(labels_val))
224
+ idx = rng.choice(len(labels_val), size=n_val, replace=False)
225
+ labs = labels_val[idx].astype(np.float32) # (n_val, lab_dim)
226
+
227
+ H, W = int(images_val.shape[-2]), int(images_val.shape[-1])
228
+
229
+ sigmas = []
230
+ for i in range(n_val):
231
+ lab_i = labs[i:i+1] # (1, lab_dim)
232
+ pair = np.concatenate([lab_i, lab_i], axis=0) # (2, lab_dim)
233
+
234
+ imgs = em.sample_batch(
235
+ model, pair, lab_mean, lab_std, normalize,
236
+ H, W, device, ddim_steps, False,
237
+ ) # (2, H, W) in [0, 1]
238
+
239
+ dk, log_pk_a, valid = log_pk_observed(imgs[0], box_size)
240
+ _, log_pk_b, _ = log_pk_observed(imgs[1], box_size)
241
+ diff = log_pk_a - log_pk_b
242
+ # sigma of a single draw = std(diff) / sqrt(2)
243
+ sigmas.append(float(np.std(diff) / np.sqrt(2.0)))
244
+
245
+ sigma_cal = float(np.median(sigmas))
246
+ print(
247
+ f" [calibrate_sigma_pk] n_pairs={n_val} "
248
+ f"median σ_pk={sigma_cal:.4f} "
249
+ f"(was hard-coded 0.25)"
250
+ )
251
+ return max(sigma_cal, 0.01) # safety floor
252
+
253
+
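The `std(diff) / sqrt(2)` step in `calibrate_sigma_pk` relies on the variance of a difference of two independent draws being twice the single-draw variance. A minimal synthetic check, assuming Gaussian scatter around a fixed mean spectrum (the mean curve and `true_sigma` are invented for illustration and do not come from the emulator):

```python
import numpy as np

rng = np.random.default_rng(0)
true_sigma = 0.2   # hypothetical single-draw scatter in log P(k)
n_bins = 5000      # many bins so the estimate is tight

mean_log_pk = np.linspace(2.0, -1.0, n_bins)
draw_a = mean_log_pk + rng.normal(0.0, true_sigma, n_bins)
draw_b = mean_log_pk + rng.normal(0.0, true_sigma, n_bins)

# Var(a - b) = 2 * sigma^2 for independent draws, so divide the std by sqrt(2).
sigma_hat = np.std(draw_a - draw_b) / np.sqrt(2.0)
print(f"recovered sigma = {sigma_hat:.3f}")
```

The shared mean cancels in the difference, which is why the estimate depends only on the draw-to-draw scatter and not on the shape of the spectrum.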
254
+ # ═════════════════════════════════════════════════════════════════════════════
255
+ # § 5 AVERAGED LOG-LIKELIHOOD (Correction #1)
256
+ # ═════════════════════════════════════════════════════════════════════════════
257
+
258
+ def averaged_log_likelihood(
259
+ obs_image: np.ndarray,
260
+ full: np.ndarray,
261
+ lab_mean: np.ndarray,
262
+ lab_std: np.ndarray,
263
+ normalize: bool,
264
+ model: torch.nn.Module,
265
+ device: torch.device,
266
+ H: int,
267
+ W: int,
268
+ box_size: float,
269
+ ddim_steps: int,
270
+ batch_sz: int,
271
+ n_pk_samples: int,
272
+ sigma_pk: float,
273
+ ) -> np.ndarray:
274
+ """
275
+ Compute the Gaussian log-likelihood for every grid point in `full`,
276
+ averaging over `n_pk_samples` independent DDPM draws to suppress
277
+ emulator stochasticity.
278
+
279
+ Parameters
280
+ ----------
281
+ full : (ngrid, lab_dim) array of grid labels
282
+ n_pk_samples : number of DDPM draws to average (≥8 recommended)
283
+ sigma_pk : calibrated log-P(k) noise scale
284
+
285
+ Returns
286
+ -------
287
+ log_w : (ngrid,) unnormalised log-likelihood (equals the log-posterior up to a constant under a flat prior)
288
+ """
289
+ _, log_pd, valid = log_pk_observed(obs_image, box_size)
290
+ ngrid = full.shape[0]
291
+
292
+ # Accumulate sum of log P(k) over n_pk_samples draws
293
+ sum_log_pg = np.zeros((ngrid, int(valid.sum())), dtype=np.float64)
294
+
295
+ for s in range(n_pk_samples):
296
+ all_pk = []
297
+ for j0 in range(0, ngrid, batch_sz):
298
+ chunk = full[j0: j0 + batch_sz]
299
+ imgs = em.sample_batch(
300
+ model, chunk, lab_mean, lab_std, normalize,
301
+ H, W, device, ddim_steps, False,
302
+ ) # (chunk_sz, H, W)
303
+ _, pks = em.per_map_power_spectra_log(imgs, box_size)
304
+ # pks shape: (chunk_sz, n_bins); select valid bins
305
+ all_pk.append(pks[:, valid])
306
+
307
+ pk_all = np.concatenate(all_pk, axis=0) # (ngrid, n_valid)
308
+ sum_log_pg += np.log(pk_all + 1e-30)
309
+
310
+ mean_log_pg = sum_log_pg / n_pk_samples # (ngrid, n_valid)
311
+
312
+ # Gaussian log-likelihood with the squared residual averaged (not summed) over k-bins: -mean_k[(log Pd - log Pg)^2] / (2 sigma^2)
313
+ mse = np.mean((log_pd[np.newaxis, :] - mean_log_pg) ** 2, axis=1)
314
+ log_w = -mse / (2.0 * sigma_pk ** 2)
315
+
316
+ return log_w.astype(np.float64)
317
+
318
+
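Two ideas from `averaged_log_likelihood` can be checked on synthetic numbers: averaging `n_s` independent draws shrinks the scatter of the mean by roughly `1/sqrt(n_s)`, and the log-weight is the mean squared residual over bins divided by `2 * sigma_pk**2`. The sketch below uses Gaussian noise as a stand-in for the DDPM draws; all values are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(1)
sigma_draw, n_s, n_bins, n_trials = 0.3, 8, 64, 4000

# Each trial: n_s noisy draws of log P(k) around a zero mean, averaged over draws.
draws = rng.normal(0.0, sigma_draw, size=(n_trials, n_s, n_bins))
mean_log_pg = draws.mean(axis=1)          # shape (n_trials, n_bins)

scatter_single = draws[:, 0, :].std()
scatter_mean = mean_log_pg.std()
ratio = scatter_mean / scatter_single     # expected to be near 1/sqrt(n_s)
print(f"ratio = {ratio:.3f}, 1/sqrt(n_s) = {1.0 / np.sqrt(n_s):.3f}")

# Log-weight in the same form as the function above: mean squared residual over bins.
log_pd = np.zeros(n_bins)                 # toy "observed" spectrum
sigma_pk = 0.25                           # hypothetical noise scale
mse = np.mean((log_pd[np.newaxis, :] - mean_log_pg) ** 2, axis=1)
log_w = -mse / (2.0 * sigma_pk ** 2)
```

Since the residual is averaged rather than summed over bins, the resulting weights are broader than a strict per-bin Gaussian likelihood with the same `sigma_pk` would give.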
319
+ # ═════════════════════════════════════════════════════════════════════════════
320
+ # § 6 POSTERIOR WEIGHT COMPUTATION
321
+ # ═════════════════════════════════════════════════════════════════════════════
322
+
323
+ def posterior_weights_ddpm2(
324
+ obs_image: np.ndarray,
325
+ labels_ref: np.ndarray,
326
+ lab_mean: np.ndarray,
327
+ lab_std: np.ndarray,
328
+ normalize: bool,
329
+ model: torch.nn.Module,
330
+ device: torch.device,
331
+ grid: int,
332
+ batch_sz: int,
333
+ ddim_steps: int,
334
+ n_pk_samples: int,
335
+ sigma_pk: float,
336
+ box_size: float = 25.0,
337
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
338
+ """
339
+ Compute the DDPM-2 surrogate posterior on (Omega_m, sigma_8).
340
+ Returns (Wmap, OM, S8) with Wmap shaped (grid, grid).
341
+ """
342
+ H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1])
343
+ full, om_ax, s8_ax = build_full_grid(labels_ref, grid, tail=None, lab_dim=2)
344
+
345
+ log_w = averaged_log_likelihood(
346
+ obs_image, full, lab_mean, lab_std, normalize, model, device,
347
+ H, W, box_size, ddim_steps, batch_sz, n_pk_samples, sigma_pk,
348
+ )
349
+
350
+ log_w -= log_w.max() # numerical stability
351
+ w = np.exp(log_w).reshape(grid, grid)
352
+ w /= w.sum()
353
+
354
+ OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
355
+ return w, OM, S8
356
+
357
+
358
+ def posterior_weights_ddpm6_marginalised(
359
+ obs_image: np.ndarray,
360
+ labels_ref: np.ndarray,
361
+ lab_mean: np.ndarray,
362
+ lab_std: np.ndarray,
363
+ normalize: bool,
364
+ model: torch.nn.Module,
365
+ device: torch.device,
366
+ lo_tail: np.ndarray,
367
+ hi_tail: np.ndarray,
368
+ grid: int,
369
+ batch_sz: int,
370
+ ddim_steps: int,
371
+ n_pk_samples: int,
372
+ n_marg_samples: int,
373
+ sigma_pk: float,
374
+ box_size: float = 25.0,
375
+ seed: int = 1,
376
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
377
+ """
378
+ Compute the DDPM-6 *marginal* posterior on (Omega_m, sigma_8) by
379
+ Monte Carlo integration over the astrophysical nuisance parameters:
380
+
381
+ p(Om, s8 | d) ∝ ∫ L(d | Om, s8, θ_extra) π(θ_extra) dθ_extra
382
+ ≈ (1/N) Σ_i L(d | Om, s8, θ_extra^i)
383
+
384
+ where θ_extra^i ~ Uniform(LHS range for dims 2-5).
385
+
386
+ This replaces the incorrect approach of fixing dims 2-5 to their
387
+ LHS extrema, which computes a *conditional* likelihood, not a marginal.
388
+
389
+ Parameters
390
+ ----------
391
+ n_marg_samples : number of MC draws for astrophysical parameter integration
392
+ (≥20 recommended; more = smoother but slower)
393
+ """
394
+ rng = np.random.default_rng(seed)
395
+ H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1])
396
+
397
+ # Draw astrophysical parameter samples from their uniform prior over LHS
398
+ theta_extra_draws = rng.uniform(
399
+ lo_tail, hi_tail,
400
+ size=(n_marg_samples, 4),
401
+ ).astype(np.float32)
402
+
403
+ full_cosmo, om_ax, s8_ax = build_full_grid(labels_ref, grid, tail=None, lab_dim=2)
405
+ ngrid = full_cosmo.shape[0]
406
+
407
+ # log-sum-exp accumulator over marginalisation samples
408
+ log_w_accum = np.full(ngrid, -np.inf, dtype=np.float64)
409
+
410
+ for m_idx, theta_extra in enumerate(theta_extra_draws):
411
+ # Assemble 6D label grid with this draw of astrophysical params
412
+ full_6d = np.zeros((ngrid, 6), dtype=np.float32)
413
+ full_6d[:, :2] = full_cosmo[:, :2]
414
+ full_6d[:, 2:6] = theta_extra[np.newaxis, :]
415
+
416
+ log_w_m = averaged_log_likelihood(
417
+ obs_image, full_6d, lab_mean, lab_std, normalize, model, device,
418
+ H, W, box_size, ddim_steps, batch_sz, n_pk_samples, sigma_pk,
419
+ )
420
+
421
+ # log-sum-exp: accumulate log Σ L_i → after loop divide by N_marg
422
+ log_w_accum = np.logaddexp(log_w_accum, log_w_m)
423
+
424
+ if (m_idx + 1) % 5 == 0 or (m_idx + 1) == n_marg_samples:
425
+ print(f" marginalisation sample {m_idx+1}/{n_marg_samples} done")
426
+
427
+ # Subtract log(N_marg) to convert sum → mean, then normalise
428
+ log_w_accum -= np.log(n_marg_samples)
429
+ log_w_accum -= log_w_accum.max()
430
+ w = np.exp(log_w_accum).reshape(grid, grid)
431
+ w /= w.sum()
432
+
433
+ OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
434
+ return w, OM, S8
435
+
436
+
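The marginalisation loop accumulates `log Σ_i L_i` with `np.logaddexp` and subtracts `log(N)` afterwards. A small sketch with invented log-likelihood values (chosen so negative that a direct `exp` underflows) compares this against a shifted reference computation:

```python
import numpy as np

rng = np.random.default_rng(2)
n_marg, n_grid = 20, 9

# Hypothetical per-draw log-likelihoods; exp() of these underflows to 0.0.
log_like = rng.uniform(-1100.0, -1000.0, size=(n_marg, n_grid))

# Running log-sum-exp over marginalisation draws, then sum -> mean.
log_accum = np.full(n_grid, -np.inf)
for row in log_like:
    log_accum = np.logaddexp(log_accum, row)
log_marginal = log_accum - np.log(n_marg)

# Reference: shift by the per-point maximum before exponentiating.
shift = log_like.max(axis=0)
reference = shift + np.log(np.exp(log_like - shift).mean(axis=0))

# Normalised grid weights, as in the function above.
w = np.exp(log_marginal - log_marginal.max())
w /= w.sum()
```

Averaging the likelihoods in log space keeps grid points with very small likelihood comparable, whereas exponentiating first would map all of them to zero.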
437
+ # ═════════════════════════════════════════════════════════════════════════════
438
+ # § 7 POSTERIOR DIAGNOSTICS
439
+ # ═════════════════════════════════════════════════════════════════════════════
440
+
441
+ def effective_sample_size(w: np.ndarray) -> float:
442
+ """n_eff = 1 / Σ w_i^2. Values < 30 indicate posterior collapse."""
443
+ w_flat = w.ravel() / w.sum()
444
+ return float(1.0 / (w_flat ** 2).sum())
445
+
446
+
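As a sanity check on the `n_eff = 1 / Σ w_i^2` diagnostic, the limiting cases are easy to reproduce. The sketch below re-implements the formula in a standalone helper named `n_eff` (a hypothetical name used only here) and evaluates it on three toy weight maps:

```python
import numpy as np


def n_eff(w: np.ndarray) -> float:
    """Kish effective sample size of a (possibly unnormalised) weight map."""
    w = np.asarray(w, dtype=np.float64).ravel()
    w = w / w.sum()
    return float(1.0 / np.sum(w ** 2))


uniform = np.ones((30, 30))          # flat posterior over 900 cells
spike = np.zeros((30, 30))
spike[4, 7] = 1.0                    # all mass in one cell
two_cells = np.zeros((30, 30))
two_cells[0, 0] = two_cells[1, 1] = 0.5

print(n_eff(uniform), n_eff(spike), n_eff(two_cells))
```

A flat map returns the number of cells, and a map with all mass in `m` equally weighted cells returns `m`, which is why small values signal a posterior resolved by only a few grid points.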
447
+ def credible_levels(
448
+ w: np.ndarray,
449
+ levels: Tuple[float, ...] = (0.68, 0.95),
450
+ ) -> List[float]:
451
+ """
452
+ Find the weight threshold c such that the region {w ≥ c} contains
453
+ at least `level` of the total probability mass (exact only in the continuum limit).
454
+
455
+ Returns a list of thresholds, one per level (descending).
456
+ """
457
+ w_flat = w.ravel()
458
+ sorted_w = np.sort(w_flat)[::-1]
459
+ cumsum = np.cumsum(sorted_w)
460
+ thresholds = []
461
+ for level in levels:
462
+ idx = np.searchsorted(cumsum, level * w_flat.sum())
463
+ idx = min(idx, len(sorted_w) - 1)
464
+ thresholds.append(float(sorted_w[idx]))
465
+ return thresholds
466
+
467
+
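The thresholding used for the 68 % and 95 % contours can be exercised on a toy Gaussian weight map. This is a sketch under assumptions: `mass_threshold` is a hypothetical single-level variant of the sorted-cumulative-sum idea above, and the grid and Gaussian are invented for the test.

```python
import numpy as np


def mass_threshold(w: np.ndarray, level: float) -> float:
    """Weight value at which the largest-first cumulative mass reaches `level`."""
    sorted_w = np.sort(w.ravel())[::-1]
    cumsum = np.cumsum(sorted_w)
    idx = min(int(np.searchsorted(cumsum, level * cumsum[-1])), sorted_w.size - 1)
    return float(sorted_w[idx])


x = np.linspace(-4.0, 4.0, 201)
X, Y = np.meshgrid(x, x, indexing="ij")
w = np.exp(-0.5 * (X ** 2 + Y ** 2))
w /= w.sum()

c68, c95 = mass_threshold(w, 0.68), mass_threshold(w, 0.95)
mass68 = w[w >= c68].sum()   # mass enclosed by the 68 % contour
mass95 = w[w >= c95].sum()   # mass enclosed by the 95 % contour
print(f"enclosed mass: {mass68:.3f}, {mass95:.3f}")
```

On a discrete grid the enclosed mass slightly overshoots the requested level, and the overshoot shrinks as the grid is refined.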
468
+ def posterior_summary(
469
+ w: np.ndarray,
470
+ OM: np.ndarray,
471
+ S8: np.ndarray,
472
+ ) -> Dict:
473
+ """
474
+ Return a dict with posterior mean, std, and S8 derived parameter.
475
+ """
476
+ w_norm = w / w.sum()
477
+ mom = float((w_norm * OM).sum())
478
+ ms8 = float((w_norm * S8).sum())
479
+ vom = float((w_norm * (OM - mom) ** 2).sum()) ** 0.5
480
+ vs8 = float((w_norm * (S8 - ms8) ** 2).sum()) ** 0.5
481
+ mS8 = ms8 * (mom / 0.3) ** 0.5
482
+ n_eff = effective_sample_size(w_norm)
483
+ return dict(om_mean=mom, om_std=vom, s8_mean=ms8, s8_std=vs8,
484
+ S8_mean=mS8, n_eff=n_eff)
485
+
486
+
487
+ # ═════════════════════════════════════════════════════════════════════════════
488
+ # § 8 PLOTTING
489
+ # ═════════════════════════════════════════════════════════════════════════════
490
+
491
+ def plot_posterior_panel(
492
+ ax: plt.Axes,
493
+ w: np.ndarray,
494
+ OM: np.ndarray,
495
+ S8: np.ndarray,
496
+ true_om: float,
497
+ true_s8: float,
498
+ title: str,
499
+ summary: Optional[Dict] = None,
500
+ ) -> None:
501
+ """
502
+ Plot one posterior panel with:
503
+ • filled colour map of posterior weights
504
+ • 68 % and 95 % credible contours
505
+ • true parameter location (red ×)
506
+ • posterior mean (black +)
507
+ • n_eff and posterior-mean S8 as text annotation
508
+ """
509
+ # ── colour map ────────────────────────────────────────────────────────────
510
+ cf = ax.contourf(OM, S8, w, levels=14, cmap="Blues")
511
+ plt.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)
512
+
513
+ # ── credible contours ─────────────────────────────────────────────────────
514
+ try:
515
+ thresh_68, thresh_95 = credible_levels(w, levels=(0.68, 0.95))
516
+ ax.contour(OM, S8, w, levels=[thresh_95, thresh_68],
517
+ colors=["#e07b39", "#c0392b"],
518
+ linewidths=[1.2, 1.8], linestyles=["--", "-"])
519
+ # Proxy artists for legend
520
+ from matplotlib.lines import Line2D
521
+ ax.legend(
522
+ handles=[
523
+ Line2D([], [], color="#c0392b", lw=1.8, label="68 % CR"),
524
+ Line2D([], [], color="#e07b39", lw=1.2, ls="--", label="95 % CR"),
525
+ Line2D([], [], marker="x", color="r", ls="", ms=8, label="true"),
526
+ Line2D([], [], marker="+", color="k", ls="", ms=8, label="post. mean"),
527
+ ],
528
+ fontsize=6.5, loc="upper right",
529
+ )
530
+ except Exception:
531
+ ax.legend(fontsize=6.5)
532
+
533
+ # ── markers ───────────────────────────────────────────────────────────────
534
+ if summary:
535
+ ax.scatter(summary["om_mean"], summary["s8_mean"],
536
+ s=60, c="k", marker="+", zorder=7)
537
+ ax.scatter(true_om, true_s8, s=60, c="r", marker="x", zorder=7)
538
+
539
+ # ── S8 degeneracy line (for visual reference) ─────────────────────────────
540
+ om_arr = np.linspace(float(OM.min()), float(OM.max()), 200)
541
+ if summary:
542
+ S8_val = summary["s8_mean"] * (summary["om_mean"] / 0.3) ** 0.5
543
+ s8_degen = S8_val / (om_arr / 0.3) ** 0.5
544
+ mask = (s8_degen >= float(S8.min())) & (s8_degen <= float(S8.max()))
545
+ if mask.any():
546
+ ax.plot(om_arr[mask], s8_degen[mask], "k:", lw=0.8, alpha=0.5,
547
+ label=f"$S_8$={S8_val:.3f}")
548
+
549
+ # ── labels and annotation ─────────────────────────────────────────────────
550
+ ax.set_xlabel(r"$\Omega_m$", fontsize=9)
551
+ ax.set_ylabel(r"$\sigma_8$", fontsize=9)
552
+ ax.set_title(title, fontsize=8)
553
+
554
+ if summary:
555
+ info = (
556
+ f"$n_\\mathrm{{eff}}$={summary['n_eff']:.0f}\n"
557
+ f"$S_8$={summary['S8_mean']:.3f}\n"
558
+ f"$\\Omega_m$={summary['om_mean']:.3f}±{summary['om_std']:.3f}\n"
559
+ f"$\\sigma_8$={summary['s8_mean']:.3f}±{summary['s8_std']:.3f}"
560
+ )
561
+ ax.text(0.02, 0.98, info, transform=ax.transAxes,
562
+ fontsize=6.5, va="top", color="#222",
563
+ bbox=dict(fc="white", ec="none", alpha=0.7, pad=1.5))
564
+
565
+
566
+ def make_posterior_figure(
567
+ panels: List[Dict],
568
+ suptitle: str,
569
+ out_path: Path,
570
+ ) -> None:
571
+ """
572
+ Create a 2×3 grid of posterior panels and save to `out_path`.
573
+
574
+ Each element of `panels` must be a dict with keys:
575
+ w, OM, S8, true_om, true_s8, title, summary
576
+ """
577
+ fig, axes = plt.subplots(2, 3, figsize=(15, 9.5), squeeze=False)
578
+ for k, p in enumerate(panels):
579
+ r, c = divmod(k, 3)
580
+ plot_posterior_panel(
581
+ axes[r, c],
582
+ p["w"], p["OM"], p["S8"],
583
+ p["true_om"], p["true_s8"],
584
+ p["title"], p.get("summary"),
585
+ )
586
+ plt.suptitle(suptitle, fontsize=11, y=0.998)
587
+ plt.tight_layout(rect=(0, 0, 1, 0.97))
588
+ fig.savefig(out_path, dpi=170, bbox_inches="tight")
589
+ plt.close(fig)
590
+ print(f" Saved → {out_path}")
591
+
592
+
593
+ # ═════════════════════════════════════════════════════════════════════════════
594
+ # § 9 POSTERIOR PREDICTIVE CHECK (Correction #8)
595
+ # ═════════════════════════════════════════════════════════════════════════════
596
+
597
+ def posterior_predictive_check(
598
+ obs_image: np.ndarray,
599
+ w: np.ndarray,
600
+ OM: np.ndarray,
601
+ S8: np.ndarray,
602
+ model: torch.nn.Module,
603
+ lab_mean: np.ndarray,
604
+ lab_std: np.ndarray,
605
+ normalize: bool,
606
+ device: torch.device,
607
+ ddim_steps: int,
608
+ box_size: float = 25.0,
609
+ n_draws: int = 40,
610
+ seed: int = 42,
611
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
612
+ """
613
+ Draw `n_draws` parameter samples from the posterior and generate DDPM
614
+ images; return the valid k-modes, the observed log P(k), and the stacked sampled log P(k) for envelope plotting.
615
+ """
616
+ rng = np.random.default_rng(seed)
617
+ w_flat = w.ravel() / w.sum()
618
+ idx = rng.choice(len(w_flat), size=n_draws, replace=True, p=w_flat)
619
+
620
+ om_flat = OM.ravel()
621
+ s8_flat = S8.ravel()
622
+ labs = np.stack([om_flat[idx], s8_flat[idx]], axis=1).astype(np.float32)
623
+
624
+ H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1])
625
+ imgs = em.sample_batch(
626
+ model, labs, lab_mean, lab_std, normalize,
627
+ H, W, device, ddim_steps, False,
628
+ ) # (n_draws, H, W)
629
+
630
+ _, pks = em.per_map_power_spectra_log(imgs, box_size) # (n_draws, n_bins)
631
+ log_pks = np.log(pks + 1e-30)
632
+
633
+ # Observed
634
+ dk, log_pd, valid = log_pk_observed(obs_image, box_size)
635
+ return dk[valid], log_pd, log_pks[:, valid]
636
+
637
+
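`posterior_predictive_check` resamples grid cells in proportion to their posterior weight before handing the labels to the sampler. A minimal NumPy-only sketch of that resampling step follows; the helper name `draw_from_grid_posterior` and the toy 5×5 grid are illustrative assumptions, not part of the repository.

```python
import numpy as np


def draw_from_grid_posterior(w, OM, S8, n_draws, seed=42):
    """Resample (Omega_m, sigma_8) pairs from a gridded posterior with weights w."""
    rng = np.random.default_rng(seed)
    # Normalise the weights so they form a probability vector over grid cells.
    p = w.ravel() / w.sum()
    # Draw cell indices with replacement, proportional to posterior mass.
    idx = rng.choice(p.size, size=n_draws, replace=True, p=p)
    return np.stack([OM.ravel()[idx], S8.ravel()[idx]], axis=1).astype(np.float32)


# Toy grid (hypothetical ranges) with all posterior mass on a single cell.
om = np.linspace(0.1, 0.5, 5)
s8 = np.linspace(0.6, 1.0, 5)
OM, S8 = np.meshgrid(om, s8, indexing="ij")
w = np.zeros_like(OM)
w[2, 3] = 1.0
draws = draw_from_grid_posterior(w, OM, S8, n_draws=10)
```

With all weight on one cell, every draw returns that cell's (Ωm, σ8) pair, which makes a cheap sanity check before spending GPU time on sampling.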
638
+ def plot_ppc_panel(
639
+ ax: plt.Axes,
640
+ dk_valid: np.ndarray,
641
+ log_pd: np.ndarray,
642
+ log_pks: np.ndarray,
643
+ title: str,
644
+ ) -> None:
645
+ lo95 = np.percentile(log_pks, 2.5, axis=0)
646
+ hi95 = np.percentile(log_pks, 97.5, axis=0)
647
+ lo68 = np.percentile(log_pks, 16.0, axis=0)
648
+ hi68 = np.percentile(log_pks, 84.0, axis=0)
649
+ med = np.median(log_pks, axis=0)
650
+
651
+ ax.fill_between(dk_valid, lo95, hi95,
652
+ alpha=0.20, color="steelblue", label="95 % PPC")
653
+ ax.fill_between(dk_valid, lo68, hi68,
654
+ alpha=0.40, color="steelblue", label="68 % PPC")
655
+ ax.plot(dk_valid, med, "b-", lw=1.4, label="PPC median")
656
+ ax.plot(dk_valid, log_pd, "r-", lw=1.6, label="Observed")
657
+
658
+ ax.set_xlabel(r"$k$ [h/Mpc]", fontsize=8)
659
+ ax.set_ylabel(r"$\log\,P_\mathrm{HI}(k)$", fontsize=8)
660
+ ax.set_title(title, fontsize=8)
661
+ ax.legend(fontsize=6.5)
662
+ ax.grid(alpha=0.3, lw=0.5)
663
+
664
+
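The bands drawn by `plot_ppc_panel` are per-bin percentiles taken across the stacked draws. The sketch below computes the same five curves in one call; `ppc_envelope` and the random toy input are assumptions made for illustration only.

```python
import numpy as np


def ppc_envelope(log_pks):
    """Per-k percentile bands of stacked log P(k) draws, shape (n_draws, n_bins)."""
    lo95, lo68, med, hi68, hi95 = np.percentile(
        log_pks, [2.5, 16.0, 50.0, 84.0, 97.5], axis=0
    )
    return {"lo95": lo95, "lo68": lo68, "med": med, "hi68": hi68, "hi95": hi95}


# Toy stand-in for 40 sampled spectra over 12 k-bins.
rng = np.random.default_rng(0)
toy_log_pks = rng.normal(size=(40, 12))
env = ppc_envelope(toy_log_pks)
```

The 68 % band is always nested inside the 95 % band, and the 50th percentile coincides with `np.median`, so the plotted curves cannot cross.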
665
+ def make_ppc_figure(
666
+ ppc_data: List[Dict],
667
+ suptitle: str,
668
+ out_path: Path,
669
+ ) -> None:
670
+ fig, axes = plt.subplots(2, 3, figsize=(15, 8), squeeze=False)
671
+ for k, d in enumerate(ppc_data):
672
+ r, c = divmod(k, 3)
673
+ plot_ppc_panel(axes[r, c], d["dk"], d["log_pd"],
674
+ d["log_pks"], d["title"])
675
+ plt.suptitle(suptitle, fontsize=11, y=0.998)
676
+ plt.tight_layout(rect=(0, 0, 1, 0.97))
677
+ fig.savefig(out_path, dpi=150, bbox_inches="tight")
678
+ plt.close(fig)
679
+ print(f" Saved → {out_path}")
680
+
681
+
682
+ # ═════════════════════════════════════════════════════════════════════════════
683
+ # § 10 MODEL LOADING
684
+ # ═════════════════════════════════════════════════════════════════════════════
685
+
686
+ def load_model(
687
+ args_json: Path,
688
+ ckpt: Path,
689
+ device: torch.device,
690
+ ) -> Tuple[torch.nn.Module, Dict]:
691
+ cfg = ec.load_training_config(str(args_json))
692
+ model = ec.build_model(cfg, device)
693
+ ec.load_checkpoint(model, str(ckpt), device)
694
+ model.eval()
695
+ return model, cfg
696
+
697
+
698
+ # ═════════════════════════════════════════════════════════════════════════════
699
+ # § 11 HIGH-LEVEL RUNNERS
700
+ # ═════════════════════════════════════════════════════════════════════════════
701
+
702
+ def run_ddpm2(
703
+ out_dir: Path,
704
+ imgs: np.ndarray,
705
+ labs: np.ndarray,
706
+ lab_mean: np.ndarray,
707
+ lab_std: np.ndarray,
708
+ cfg: Dict,
709
+ model: torch.nn.Module,
710
+ device: torch.device,
711
+ anchor_ix: np.ndarray,
712
+ grid: int,
713
+ ddim_steps: int,
714
+ batch_sz: int,
715
+ n_pk_samples: int,
716
+ sigma_pk: float,
717
+ do_ppc: bool = True,
718
+ ) -> None:
719
+ normalize = bool(cfg.get("normalize_labels", True))
720
+ panels = []
721
+ ppc_data = []
722
+
723
+ for k, ix in enumerate(anchor_ix.ravel()):
724
+ ix = int(ix)
725
+ obs = imgs[ix]
726
+ lab_t = labs[ix].astype(np.float32)
727
+ tom, ts8 = float(lab_t[0]), float(lab_t[1])
728
+
729
+ print(f" [DDPM-2] anchor {k+1}/6 ix={ix} "
730
+ f"Ωm={tom:.3f} σ8={ts8:.3f}")
731
+
732
+ w, OM, S8 = posterior_weights_ddpm2(
733
+ obs, labs, lab_mean, lab_std, normalize, model, device,
734
+ grid, batch_sz, ddim_steps, n_pk_samples, sigma_pk,
735
+ )
736
+ summ = posterior_summary(w, OM, S8)
737
+ print(f" n_eff={summ['n_eff']:.0f} "
738
+ f"Ωm_post={summ['om_mean']:.3f}±{summ['om_std']:.3f} "
739
+ f"σ8_post={summ['s8_mean']:.3f}±{summ['s8_std']:.3f} "
740
+ f"S8={summ['S8_mean']:.3f}")
741
+
742
+ panels.append(dict(
743
+ w=w, OM=OM, S8=S8,
744
+ true_om=tom, true_s8=ts8, summary=summ,
745
+ title=(
746
+ f"test ix={ix} | "
747
+ r"$\Omega_m$" + f"={tom:.3f}, "
748
+ r"$\sigma_8$" + f"={ts8:.3f}"
749
+ ),
750
+ ))
751
+
752
+ if do_ppc:
753
+ dk_v, log_pd, log_pks = posterior_predictive_check(
754
+ obs, w, OM, S8, model, lab_mean, lab_std, normalize,
755
+ device, ddim_steps,
756
+ )
757
+ ppc_data.append(dict(
758
+ dk=dk_v, log_pd=log_pd, log_pks=log_pks,
759
+ title=f"PPC test ix={ix}",
760
+ ))
761
+
762
+ # ── posterior figure ──────────────────────────────────────────────────────
763
+ make_posterior_figure(
764
+ panels,
765
+ suptitle=(
766
+ r"DDPM-2 surrogate posterior on $(\Omega_m,\,\sigma_8)$ — "
767
+ r"six CAMELS anchors "
768
+ f"[{n_pk_samples} DDPM draws/point, σ_pk={sigma_pk:.3f}]"
769
+ ),
770
+ out_path=out_dir / "posterior_six_anchors_ddpm2_corrected.png",
771
+ )
772
+
773
+ # ── PPC figure ────────────────────────────────────────────────────────────
774
+ if do_ppc and ppc_data:
775
+ make_ppc_figure(
776
+ ppc_data,
777
+ suptitle="DDPM-2 Posterior Predictive Check — P(k) envelope vs. observed",
778
+ out_path=out_dir / "ppc_six_anchors_ddpm2.png",
779
+ )
780
+
781
+
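`run_ddpm2` prints `n_eff` and weighted means and standard deviations from `posterior_summary`, whose body lies outside this hunk. The sketch below shows one plausible way such a summary could be computed on a grid, assuming `n_eff` is the Kish effective sample size 1/Σp²; the function name and the returned keys are hypothetical and only mirror the keys used in the print statements.

```python
import numpy as np


def grid_posterior_summary(w, OM, SIG8):
    """Weighted moments and Kish effective sample size of a gridded posterior."""
    p = w.ravel() / w.sum()
    om, s8 = OM.ravel(), SIG8.ravel()
    om_mean = float(np.sum(p * om))
    s8_mean = float(np.sum(p * s8))
    return {
        # Assumption: n_eff = 1 / sum(p^2); equals the cell count for flat weights.
        "n_eff": float(1.0 / np.sum(p ** 2)),
        "om_mean": om_mean,
        "om_std": float(np.sqrt(np.sum(p * (om - om_mean) ** 2))),
        "s8_mean": s8_mean,
        "s8_std": float(np.sqrt(np.sum(p * (s8 - s8_mean) ** 2))),
    }


OM, SIG8 = np.meshgrid(
    np.linspace(0.1, 0.5, 4), np.linspace(0.6, 1.0, 4), indexing="ij"
)
summ = grid_posterior_summary(np.ones_like(OM), OM, SIG8)
```

Under this definition a flat posterior on a 4×4 grid gives `n_eff` equal to 16, while a posterior concentrated on one cell gives 1, so a very small `n_eff` would suggest the grid is too coarse for the chosen `sigma_pk`.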
782
+ def run_ddpm6(
783
+ out_dir: Path,
784
+ imgs: np.ndarray,
785
+ labs: np.ndarray,
786
+ lab_mean: np.ndarray,
787
+ lab_std: np.ndarray,
788
+ cfg: Dict,
789
+ model: torch.nn.Module,
790
+ device: torch.device,
791
+ lo_tail: np.ndarray,
792
+ hi_tail: np.ndarray,
793
+ anchor_ix: np.ndarray,
794
+ grid: int,
795
+ ddim_steps: int,
796
+ batch_sz: int,
797
+ n_pk_samples: int,
798
+ n_marg_samples: int,
799
+ sigma_pk: float,
800
+ do_ppc: bool = True,
801
+ ) -> None:
802
+ normalize = bool(cfg.get("normalize_labels", True))
803
+ panels = []
804
+ ppc_data = []
805
+
806
+ for k, ix in enumerate(anchor_ix.ravel()):
807
+ ix = int(ix)
808
+ obs = imgs[ix]
809
+ lab_t = labs[ix].astype(np.float32)
810
+ tom, ts8 = float(lab_t[0]), float(lab_t[1])
811
+
812
+ print(f" [DDPM-6] anchor {k+1}/6 ix={ix} "
813
+ f"Ωm={tom:.3f} σ8={ts8:.3f}")
814
+
815
+ w, OM, S8 = posterior_weights_ddpm6_marginalised(
816
+ obs, labs, lab_mean, lab_std, normalize, model, device,
817
+ lo_tail, hi_tail,
818
+ grid, batch_sz, ddim_steps,
819
+ n_pk_samples, n_marg_samples, sigma_pk,
820
+ )
821
+ summ = posterior_summary(w, OM, S8)
822
+ print(f" n_eff={summ['n_eff']:.0f} "
823
+ f"Ωm_post={summ['om_mean']:.3f}±{summ['om_std']:.3f} "
824
+ f"σ8_post={summ['s8_mean']:.3f}±{summ['s8_std']:.3f} "
825
+ f"S8={summ['S8_mean']:.3f}")
826
+
827
+ panels.append(dict(
828
+ w=w, OM=OM, S8=S8,
829
+ true_om=tom, true_s8=ts8, summary=summ,
830
+ title=(
831
+ f"test ix={ix} | "
832
+ r"$\Omega_m$" + f"={tom:.3f}, "
833
+ r"$\sigma_8$" + f"={ts8:.3f}"
834
+ f"\n[MC marg., N_marg={n_marg_samples}]"
835
+ ),
836
+ ))
837
+
838
+ if do_ppc:
839
+ # For the PPC, sample (Ωm, σ8) from the marginal posterior, as for DDPM-2.
840
+ # NOTE: the astrophysical draw `te` below is not passed to the PPC call.
841
+ rng = np.random.default_rng(ix)
842
+ te = rng.uniform(lo_tail, hi_tail).astype(np.float32)
843
+ # Reuse the 2D marginal posterior weights and grid for the PPC draw
844
+ w2, OM2, S82 = w, OM, S8 # same posterior geometry
845
+ dk_v, log_pd, log_pks = posterior_predictive_check(
846
+ obs, w2, OM2, S82, model, lab_mean, lab_std, normalize,
847
+ device, ddim_steps,
848
+ )
849
+ ppc_data.append(dict(
850
+ dk=dk_v, log_pd=log_pd, log_pks=log_pks,
851
+ title=f"PPC test ix={ix}",
852
+ ))
853
+
854
+ # ── posterior figure ──────────────────────────────────────────────────────
855
+ make_posterior_figure(
856
+ panels,
857
+ suptitle=(
858
+ r"DDPM-6 marginal posterior on $(\Omega_m,\,\sigma_8)$ — "
859
+ r"six CAMELS anchors "
860
+ f"[MC marginalisation, N_marg={n_marg_samples}, "
861
+ f"{n_pk_samples} DDPM draws/point, σ_pk={sigma_pk:.3f}]"
862
+ ),
863
+ out_path=out_dir / "posterior_six_anchors_ddpm6_marginalised_corrected.png",
864
+ )
865
+
866
+ if do_ppc and ppc_data:
867
+ make_ppc_figure(
868
+ ppc_data,
869
+ suptitle="DDPM-6 Posterior Predictive Check — P(k) envelope vs. observed",
870
+ out_path=out_dir / "ppc_six_anchors_ddpm6.png",
871
+ )
872
+
873
+
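`run_ddpm6` delegates to `posterior_weights_ddpm6_marginalised`, which is defined outside this hunk. As a rough sketch of what Monte Carlo marginalisation over uniformly distributed nuisance parameters can look like, the code below averages the likelihood over `n_marg` prior draws in log space. The likelihood callback, helper names, and bounds are all hypothetical, and this is not claimed to be the repository's implementation.

```python
import numpy as np


def log_mean_exp(a, axis):
    """Numerically stable log(mean(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.mean(np.exp(a - m), axis=axis))


def marginal_log_like(log_like_fn, theta_grid, lo_tail, hi_tail, n_marg, seed=0):
    """Estimate log of the nuisance-averaged likelihood at each grid point."""
    rng = np.random.default_rng(seed)
    lo_tail = np.asarray(lo_tail, dtype=float)
    hi_tail = np.asarray(hi_tail, dtype=float)
    # One row per nuisance draw from the uniform prior box.
    tails = rng.uniform(lo_tail, hi_tail, size=(n_marg, lo_tail.shape[0]))
    ll = np.stack([log_like_fn(theta_grid, t) for t in tails], axis=0)
    return log_mean_exp(ll, axis=0)


# Toy check: a likelihood that ignores the nuisance draw is left unchanged.
theta = np.linspace(-1.0, 1.0, 5)
toy = marginal_log_like(lambda th, t: -0.5 * th ** 2, theta, [0.0] * 4, [1.0] * 4, n_marg=20)
```

Averaging in log space avoids underflow when individual log-likelihoods are very negative, which is typical when comparing power spectra over many k bins.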
874
+ # ═════════════════════════════════════════════════════════════════════════════
875
+ # § 12 CLI
876
+ # ═════════════════════════════════════════════════════════════════════════════
877
+
878
+ def parse_args() -> argparse.Namespace:
879
+ p = argparse.ArgumentParser(
880
+ description=(
881
+ "Corrected six-anchor surrogate posteriors: DDPM-2 and DDPM-6.\n"
882
+ "See module docstring for a full list of corrections applied."
883
+ ),
884
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
885
+ )
886
+ p.add_argument(
887
+ "--output-dir", type=Path,
888
+ default=MODELS_ROOT / "ddpm_posterior_corrected_out",
889
+ )
890
+ p.add_argument(
891
+ "--data-2param", type=Path,
892
+ default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"),
893
+ )
894
+ p.add_argument(
895
+ "--data-6param", type=Path,
896
+ default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"),
897
+ )
898
+ p.add_argument(
899
+ "--bundle-2param", type=Path,
900
+ default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
901
+ )
902
+ p.add_argument(
903
+ "--bundle-6param", type=Path,
904
+ default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
905
+ )
906
+ p.add_argument(
907
+ "--split", default="test", choices=["train", "val", "test"],
908
+ )
909
+ # ── grid ──────────────────────────────────────────────────────────────────
910
+ p.add_argument(
911
+ "--grid", type=int, default=30,
912
+ help="Grid points per Ωm–σ8 axis (30×30=900 default, was 14×14=196).",
913
+ )
914
+ # ── sampling ──────────────────────────────────────────────────────────────
915
+ p.add_argument(
916
+ "--ddim-steps", type=int, default=50,
917
+ help="DDIM denoising steps per sample.",
918
+ )
919
+ p.add_argument(
920
+ "--batch-size", type=int, default=8,
921
+ help="Grid-point batch size for DDPM forward passes.",
922
+ )
923
+ p.add_argument(
924
+ "--n-pk-samples", type=int, default=8,
925
+ help=(
926
+ "DDPM draws to average per grid point. "
927
+ "Variance ∝ 1/n_pk_samples. "
928
+ "≥8 recommended; use 4 for a fast debug run."
929
+ ),
930
+ )
931
+ p.add_argument(
932
+ "--n-marg-samples", type=int, default=20,
933
+ help=(
934
+ "MC draws for DDPM-6 astrophysical marginalisation. "
935
+ "≥20 recommended; use 5 for a fast debug run."
936
+ ),
937
+ )
938
+ # ── sigma calibration ─────────────────────────────────────────────────────
939
+ p.add_argument(
940
+ "--n-calib-pairs", type=int, default=30,
941
+ help="Number of image pairs used to calibrate sigma_pk.",
942
+ )
943
+ p.add_argument(
944
+ "--sigma-pk", type=float, default=None,
945
+ help=(
946
+ "Override calibrated sigma_pk with a fixed value. "
947
+ "Leave unset to use automatic calibration (recommended)."
948
+ ),
949
+ )
950
+ # ── scope ─────────────────────────────────────────────────────────────────
951
+ p.add_argument(
952
+ "--ddpm2-only", action="store_true",
953
+ help="Only run DDPM-2 (skip loading DDPM-6).",
954
+ )
955
+ p.add_argument(
956
+ "--ddpm6-only", action="store_true",
957
+ help="Only run DDPM-6 (skip loading DDPM-2).",
958
+ )
959
+ p.add_argument(
960
+ "--no-ppc", action="store_true",
961
+ help="Skip posterior predictive check figures.",
962
+ )
963
+ return p.parse_args()
964
+
965
+
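`parse_args` registers `--ddpm2-only` and `--ddpm6-only` as independent flags. An alternative, sketched here with only the standard library, is to let `argparse` reject the combination itself via a mutually exclusive group; `build_scope_parser` is a hypothetical helper name, and the sketch covers just these two flags.

```python
import argparse


def build_scope_parser():
    """Parser in which the two scope flags cannot be combined."""
    p = argparse.ArgumentParser()
    scope = p.add_mutually_exclusive_group()
    scope.add_argument("--ddpm2-only", action="store_true",
                       help="Only run DDPM-2 (skip loading DDPM-6).")
    scope.add_argument("--ddpm6-only", action="store_true",
                       help="Only run DDPM-6 (skip loading DDPM-2).")
    return p


args = build_scope_parser().parse_args(["--ddpm2-only"])
```

Passing both flags then makes `argparse` print a usage error and exit, so no separate consistency check is needed after parsing.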
966
+ # ═════════════════════════════════════════════════════════════════════════════
967
+ # § 13 MAIN
968
+ # ═════════════════════════════════════════════════════════════════════════════
969
+
970
+ def main() -> None:
971
+ args = parse_args()
972
+
973
+ if args.ddpm2_only and args.ddpm6_only:
974
+ raise SystemExit("Specify at most one of --ddpm2-only / --ddpm6-only.")
975
+
976
+ out_dir = Path(args.output_dir).resolve()
977
+ out_dir.mkdir(parents=True, exist_ok=True)
978
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
979
+ print(f"Device : {device}")
980
+ print(f"Output : {out_dir}")
981
+ print()
982
+
983
+ # ── load data ─────────────────────────────────────────────────────────────
984
+ data2 = Path(args.data_2param)
985
+ data6 = Path(args.data_6param)
986
+
987
+ if not args.ddpm6_only:
988
+ imgs2, labs2 = ec.load_split(data2, args.split)
989
+ mean2, std2 = ec.load_label_stats(data2)
990
+ print(f"DDPM-2 {args.split} set : {len(labs2)} maps "
991
+ f"label_dim={labs2.shape[1]}")
992
+
993
+ if not args.ddpm2_only:
994
+ imgs6, labs6 = ec.load_split(data6, args.split)
995
+ mean6, std6 = ec.load_label_stats(data6)
996
+ lo_tail, hi_tail = tail_lhs_bounds(data6)
997
+ print(f"DDPM-6 {args.split} set : {len(labs6)} maps "
998
+ f"label_dim={labs6.shape[1]}")
999
+ print(f" LHS tails (dims 2-5): min={lo_tail} max={hi_tail}")
1000
+
1001
+ # ── six anchors ───────────────────────────────────────────────────────────
1002
+ if not args.ddpm6_only:
1003
+ n_ref = len(labs2)
1004
+ else:
1005
+ n_ref = len(labs6)
1006
+ anchor_ix = np.linspace(0, n_ref - 1, num=6, dtype=int)
1007
+ print(f"\nAnchor indices: {anchor_ix.tolist()}\n")
1008
+
1009
+ # ── checkpoints ───────────────────────────────────────────────────────────
1010
+ ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
1011
+ args_j2 = args.bundle_2param / "args.json"
1012
+ ck6 = args.bundle_6param / "best_model.pt"
1013
+ args_j6 = args.bundle_6param / "args.json"
1014
+
1015
+ # ══════════════════════════════════════════════════════════════════════════
1016
+ # DDPM-2 BLOCK
1017
+ # ══════════════════════════════════════════════════════════════════════════
1018
+ if not args.ddpm6_only:
1019
+ print("=" * 70)
1020
+ print(">>> DDPM-2 (six anchors)")
1021
+ print("=" * 70)
1022
+
1023
+ model2, cfg2 = load_model(args_j2, ck2, device)
1024
+
1025
+ # ── sigma_pk calibration ──────────────────────────────────────────────
1026
+ if args.sigma_pk is not None:
1027
+ sigma2 = args.sigma_pk
1028
+ print(f" sigma_pk overridden to {sigma2:.4f}")
1029
+ else:
1030
+ print(" Calibrating sigma_pk from validation set …")
1031
+ imgs2_val, labs2_val = ec.load_split(data2, "val")
1032
+ sigma2 = calibrate_sigma_pk(
1033
+ model2, imgs2_val, labs2_val,
1034
+ mean2, std2,
1035
+ normalize=bool(cfg2.get("normalize_labels", True)),
1036
+ device=device,
1037
+ ddim_steps=args.ddim_steps,
1038
+ n_pairs=args.n_calib_pairs,
1039
+ )
1040
+
1041
+ run_ddpm2(
1042
+ out_dir=out_dir,
1043
+ imgs=imgs2, labs=labs2,
1044
+ lab_mean=mean2, lab_std=std2,
1045
+ cfg=cfg2, model=model2, device=device,
1046
+ anchor_ix=anchor_ix,
1047
+ grid=args.grid,
1048
+ ddim_steps=args.ddim_steps,
1049
+ batch_sz=args.batch_size,
1050
+ n_pk_samples=args.n_pk_samples,
1051
+ sigma_pk=sigma2,
1052
+ do_ppc=not args.no_ppc,
1053
+ )
1054
+
1055
+ del model2
1056
+ gc.collect()
1057
+ if torch.cuda.is_available():
1058
+ torch.cuda.empty_cache()
1059
+ print()
1060
+
1061
+ # ══════════════════════════════════════════════════════════════════════════
1062
+ # DDPM-6 BLOCK
1063
+ # ══════════════════════════════════════════════════════════════════════════
1064
+ if not args.ddpm2_only:
1065
+ print("=" * 70)
1066
+ print(">>> DDPM-6 (six anchors, MC marginalisation over dims 2-5)")
1067
+ print("=" * 70)
1068
+
1069
+ model6, cfg6 = load_model(args_j6, ck6, device)
1070
+
1071
+ # ── sigma_pk calibration ──────────────────────────────────────────────
1072
+ if args.sigma_pk is not None:
1073
+ sigma6 = args.sigma_pk
1074
+ print(f" sigma_pk overridden to {sigma6:.4f}")
1075
+ else:
1076
+ print(" Calibrating sigma_pk from validation set …")
1077
+ imgs6_val, labs6_val = ec.load_split(data6, "val")
1078
+ sigma6 = calibrate_sigma_pk(
1079
+ model6, imgs6_val, labs6_val,
1080
+ mean6, std6,
1081
+ normalize=bool(cfg6.get("normalize_labels", True)),
1082
+ device=device,
1083
+ ddim_steps=args.ddim_steps,
1084
+ n_pairs=args.n_calib_pairs,
1085
+ )
1086
+
1087
+ run_ddpm6(
1088
+ out_dir=out_dir,
1089
+ imgs=imgs6, labs=labs6,
1090
+ lab_mean=mean6, lab_std=std6,
1091
+ cfg=cfg6, model=model6, device=device,
1092
+ lo_tail=lo_tail, hi_tail=hi_tail,
1093
+ anchor_ix=anchor_ix,
1094
+ grid=args.grid,
1095
+ ddim_steps=args.ddim_steps,
1096
+ batch_sz=args.batch_size,
1097
+ n_pk_samples=args.n_pk_samples,
1098
+ n_marg_samples=args.n_marg_samples,
1099
+ sigma_pk=sigma6,
1100
+ do_ppc=not args.no_ppc,
1101
+ )
1102
+
1103
+ del model6
1104
+ gc.collect()
1105
+ if torch.cuda.is_available():
1106
+ torch.cuda.empty_cache()
1107
+
1108
+ print(f"\nAll done. Results in {out_dir}")
1109
+
1110
+
1111
+ if __name__ == "__main__":
1112
+ main()
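`main` picks its six anchors as evenly spaced integer indices into the chosen split. A small sketch of that selection is shown below; `anchor_indices` and its size guard are illustrative additions, not code from the script.

```python
import numpy as np


def anchor_indices(n_maps, n_anchors=6):
    """Evenly spaced integer indices from the first to the last map."""
    if n_maps < n_anchors:
        raise ValueError("need at least as many maps as anchors")
    return np.linspace(0, n_maps - 1, num=n_anchors, dtype=int)


ix = anchor_indices(100)
```

When both models are run, `main` derives the indices from the 2-parameter split and reuses them for the 6-parameter split, so the two splits need to have the same length for the anchors to refer to comparable maps.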
cross_model/run_compare_posterior.sh ADDED
@@ -0,0 +1,52 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=06:00:00
8
+ #SBATCH --job-name=cmp_post
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-cmp-post-%j.out
11
+ #SBATCH --error=slurm-cmp-post-%j.err
12
+
13
+ # DDPM-2 vs DDPM-6 corrected posteriors with JOINT P(k) + PDF likelihood.
14
+ #
15
+ # Defaults: 30x30 grid, 4 anchors, 8 DDPM draws / pt, 20 marg draws.
16
+ # Submit:
17
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/run_compare_posterior.sh
18
+ #
19
+ # Smoke test (much faster):
20
+ # sbatch run_compare_posterior.sh --grid 16 --n-pk-samples 4 \
21
+ # --n-marg-samples 5 --n-anchors 2
22
+
23
+ set -euo pipefail
24
+
25
+ ROOT="/scratch/mrpcol001/Diffusion_job/Models"
26
+ cd "$ROOT"
27
+
28
+ module load python/miniconda3-py3.12-usr
29
+
30
+ PY="${ROOT}/compare_posterior_inference.py"
31
+ OUT="${OUTPUT_DIR:-${ROOT}/ddpm_posterior_compare_pk_pdf_out}"
32
+
33
+ mkdir -p "${OUT}"
34
+ RUN_LOG="${CUSTOM_LOG:-${OUT}/run_log.txt}"
35
+
36
+ echo "==============================================="
37
+ echo "Job ID: ${SLURM_JOB_ID:-local}"
38
+ echo "Node: ${SLURM_NODELIST:-$(hostname)}"
39
+ echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
40
+ echo "Started: $(date)"
41
+ echo "Script: ${PY}"
42
+ echo "Output: ${OUT}"
43
+ echo "Log: ${RUN_LOG}"
44
+ echo "==============================================="
45
+
46
+ set -o pipefail
47
+ python -u "${PY}" --output-dir "${OUT}" "$@" 2>&1 | tee -a "${RUN_LOG}"
48
+
49
+ echo "==============================================="
50
+ echo "Finished: $(date)"
51
+ echo "Artifacts: ${OUT}"
52
+ echo "==============================================="
cross_model/run_vlb_inference_1000grid.sh ADDED
@@ -0,0 +1,81 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=48:00:00
8
+ #SBATCH --job-name=vlb_infer_1000
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-vlb-infer-1000-%j.out
11
+ #SBATCH --error=slurm-vlb-infer-1000-%j.err
12
+
13
+ # VLB / Mudur-style posterior_inference.py — 1000×1000 grid (high-resolution posteriors).
14
+ #
15
+ # High-resolution parameter grids with 9 fields, generates posterior_L0_mosaic_3x3.png
16
+ # plus field0X_combined.png (contours + L_0 posterior side-by-side).
17
+ #
18
+ # Submit:
19
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_vlb_inference_1000grid.sh
20
+ #
21
+ # Default: grid_size=1000, n_fields=9, span=0.10, generates mosaic + combined figures.
22
+ # Est. runtime: ~18-24 hours on L40S.
23
+ #
24
+ # Override via --export or command-line args:
25
+ # sbatch --export=OUTPUT_DIR=/path/custom_output scripts/run_vlb_inference_1000grid.sh
26
+ # sbatch scripts/run_vlb_inference_1000grid.sh --grid_size 500 --n_fields 4 --batch_size 16
27
+ #
28
+ # For grid_size < 300, omit --allow_huge_grid in args below.
29
+ #
30
+ # Logs: Slurm .out/.err plus OUTPUT_DIR/run_log.txt.
31
+
32
+ set -euo pipefail
33
+
34
+ ROOT="/scratch/mrpcol001/Diffusion_job/Models"
35
+ cd "$ROOT"
36
+
37
+ module load python/miniconda3-py3.12-usr
38
+
39
+ PY="${ROOT}/6param_ddpm_hi_lh6/posterior_inference.py"
40
+ OUT="${OUTPUT_DIR:-${ROOT}/vlb_inference_outputs_1000grid}"
41
+
42
+ CHK="${CHECKPOINT:-${ROOT}/notebook_model_weights/6param_best/best_model.pt}"
43
+ ARGS="${TRAINING_ARGS:-${ROOT}/notebook_model_weights/6param_best/args.json}"
44
+ DATA="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"
45
+
46
+ mkdir -p "${OUT}"
47
+ RUN_LOG="${CUSTOM_LOG:-${OUT}/run_log.txt}"
48
+
49
+ echo "==============================================="
50
+ echo "Job ID: ${SLURM_JOB_ID:-local}"
51
+ echo "Node: ${SLURM_NODELIST:-$(hostname)}"
52
+ echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
53
+ echo "Started: $(date)"
54
+ echo "Python: ${PY}"
55
+ echo "checkpoint: ${CHK}"
56
+ echo "training_args: ${ARGS}"
57
+ echo "data_dir: ${DATA}"
58
+ echo "output_dir: ${OUT}"
59
+ echo "Progress log: ${RUN_LOG}"
60
+ echo "==============================================="
61
+
62
+ set -o pipefail
63
+ python -u "${PY}" \
64
+ --checkpoint "${CHK}" \
65
+ --training_args "${ARGS}" \
66
+ --data_dir "${DATA}" \
67
+ --output_dir "${OUT}" \
68
+ --grid_size 1000 \
69
+ --allow_huge_grid \
70
+ --n_fields 9 \
71
+ --span 0.10 \
72
+ --t_subset 0 1 2 5 8 10 15 20 \
73
+ --n_seeds 4 \
74
+ --batch_size 32 \
75
+ --seed 42 \
76
+ "$@" 2>&1 | tee -a "${RUN_LOG}"
77
+
78
+ echo "==============================================="
79
+ echo "Finished: $(date)"
80
+ echo "Artifacts → ${OUT}"
81
+ echo "==============================================="
cross_model/run_vlb_inference_200grid.sh ADDED
@@ -0,0 +1,78 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=12:00:00
8
+ #SBATCH --job-name=vlb_infer_200
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-vlb-infer-200-%j.out
11
+ #SBATCH --error=slurm-vlb-infer-200-%j.err
12
+
13
+ # VLB / Mudur-style posterior_inference.py — 200×200 grid (balanced speed/quality).
14
+ #
15
+ # Medium-resolution parameter grids with 9 fields, generates posterior_L0_mosaic_3x3.png
16
+ # plus field0X_combined.png (contours + L_0 posterior side-by-side).
17
+ #
18
+ # Grid: 200×200 = 40K points per timestep (vs 1000×1000 = 1M points)
19
+ # Est. runtime: ~2-3 hours on L40S.
20
+ #
21
+ # Submit:
22
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/run_vlb_inference_200grid.sh
23
+ #
24
+ # Override via --export or command-line args:
25
+ # sbatch --export=OUTPUT_DIR=/path/custom_output run_vlb_inference_200grid.sh
26
+ # sbatch run_vlb_inference_200grid.sh --grid_size 150 --n_fields 4
27
+ #
28
+ # Logs: Slurm .out/.err plus OUTPUT_DIR/run_log.txt.
29
+
30
+ set -euo pipefail
31
+
32
+ ROOT="/scratch/mrpcol001/Diffusion_job/Models"
33
+ cd "$ROOT"
34
+
35
+ module load python/miniconda3-py3.12-usr
36
+
37
+ PY="${ROOT}/6param_ddpm_hi_lh6/posterior_inference.py"
38
+ OUT="${OUTPUT_DIR:-${ROOT}/vlb_inference_outputs_200grid}"
39
+
40
+ CHK="${CHECKPOINT:-${ROOT}/notebook_model_weights/6param_best/best_model.pt}"
41
+ ARGS="${TRAINING_ARGS:-${ROOT}/notebook_model_weights/6param_best/args.json}"
42
+ DATA="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"
43
+
44
+ mkdir -p "${OUT}"
45
+ RUN_LOG="${CUSTOM_LOG:-${OUT}/run_log.txt}"
46
+
47
+ echo "==============================================="
48
+ echo "Job ID: ${SLURM_JOB_ID:-local}"
49
+ echo "Node: ${SLURM_NODELIST:-$(hostname)}"
50
+ echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
51
+ echo "Started: $(date)"
52
+ echo "Python: ${PY}"
53
+ echo "checkpoint: ${CHK}"
54
+ echo "training_args: ${ARGS}"
55
+ echo "data_dir: ${DATA}"
56
+ echo "output_dir: ${OUT}"
57
+ echo "Progress log: ${RUN_LOG}"
58
+ echo "==============================================="
59
+
60
+ set -o pipefail
61
+ python -u "${PY}" \
62
+ --checkpoint "${CHK}" \
63
+ --training_args "${ARGS}" \
64
+ --data_dir "${DATA}" \
65
+ --output_dir "${OUT}" \
66
+ --grid_size 200 \
67
+ --n_fields 9 \
68
+ --span 0.10 \
69
+ --t_subset 0 1 2 5 8 10 15 20 \
70
+ --n_seeds 4 \
71
+ --batch_size 32 \
72
+ --seed 42 \
73
+ "$@" 2>&1 | tee -a "${RUN_LOG}"
74
+
75
+ echo "==============================================="
76
+ echo "Finished: $(date)"
77
+ echo "Artifacts → ${OUT}"
78
+ echo "==============================================="
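The header comment of this script compares 200×200 = 40K grid points per timestep with 1000×1000 = 1M. The arithmetic can be sketched as follows, assuming (this is not confirmed by the hunk) that `posterior_inference.py` evaluates the grid in chunks of `--batch_size` points; `grid_cost` is a hypothetical helper.

```python
import math


def grid_cost(grid_size, batch_size):
    """Grid points per timestep and the number of batches needed to cover them."""
    points = grid_size * grid_size
    return points, math.ceil(points / batch_size)


points_200, batches_200 = grid_cost(200, 32)
points_1000, batches_1000 = grid_cost(1000, 32)
```

Under that assumption the 1000-point grid needs 25 times as many batches per timestep as the 200-point grid, which is the main reason the two scripts request different wall times.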
cross_model/scripts/compare_ddpm_models.py ADDED
@@ -0,0 +1,855 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Compare 2-parameter and 6-parameter conditional DDPMs (CAMELS LH) side-by-side:
4
+ • Random-draw vs test-conditioned triplets (CAMELS | DDPM-2 | DDPM-6)
5
+ • Six anchor cosmologies: P(k) and PDF diagnostics (triple curves per panel where applicable)
6
+ • LHS R² cosmology plots (LHS-50 × 15 maps — expensive)
7
+ • MLP P(k) → label recovery (sklearn MLP, two models + shared CAMELS calibration)
8
+ • Surrogate posterior on (Ωm, σ8) for a fixed test index
9
+ • Training / validation loss on one axis (Slurm .out for DDPM-6; DDPM-2 defaults to bundled JSON)
10
+
11
+ Outputs under --output-dir (default: Models/ddpm_comparison_out/).
12
+
13
+ GPU: both models are resident while generating comparison panels; use a single GPU with
14
+ sufficient memory, or run heavier steps separately with refactors.
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import gc
21
+ import sys
22
+ from pathlib import Path
23
+ from typing import Dict, Sequence, Tuple
24
+
25
+ import matplotlib
26
+
27
+ matplotlib.use("Agg")
28
+ import matplotlib.pyplot as plt
29
+ import numpy as np
30
+ import torch
31
+
32
+ # --- Repo paths ---
33
+ MODELS_ROOT = Path(__file__).resolve().parents[1]
34
+ CODE_6 = (MODELS_ROOT / "6param_ddpm_hi_lh6").resolve()
35
+ if str(CODE_6) not in sys.path:
36
+ sys.path.insert(0, str(CODE_6))
37
+
38
+ import evaluate_conditional as ec # noqa: E402
39
+ import eval_model as em # noqa: E402
40
+ from figure9_posterior import build_cosmo_grid, log_pk_observed # noqa: E402
41
+
42
+ from plot_r2_cosmology_lhs import compute_lhs_r2, plot_r2_cosmology_figure # noqa: E402
43
+
44
+ from compare_ddpm_training_curves import ( # noqa: E402
45
+ load_train_val_series,
46
+ parse_slurm_training_log,
47
+ )
48
+
49
+ DEFAULT_SLURM_6 = Path(
50
+ "/scratch/mrpcol001/Diffusion_job/april_26/ddpm_hi_lh6/scripts/shell/slurm-698243.out"
51
+ )
52
+ # Bundled train/val (no 2-param Slurm log in-repo); see ``ddpm_2param_training_loss.json``.
53
+ DEFAULT_DDPM2_TRAINING = (Path(__file__).resolve().parent / "ddpm_2param_training_loss.json")
54
+
55
+
56
+ def _fmt_title(lab: np.ndarray) -> str:
57
+ t = np.asarray(lab, dtype=float).ravel()
58
+ if t.size <= 2:
59
+ return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f}"
60
+ tail = ", ".join(f"{float(v):.3g}" for v in t[2:])
61
+ return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f} | " + tail
62
+
63
+
64
+ def _latin_hypercube(n: int, lo: np.ndarray, hi: np.ndarray, rng: np.random.Generator) -> np.ndarray:
65
+ """Classic LHS (same as notebook)."""
66
+ d = int(lo.shape[0])
67
+ u = rng.random((n, d))
68
+ cut = np.linspace(0.0, 1.0, n + 1)
69
+ a, b = cut[:-1], cut[1:]
70
+ width = (b - a)[:, np.newaxis]
71
+ rd = a[:, np.newaxis] + u * width
72
+ for j in range(d):
73
+ rng.shuffle(rd[:, j])
74
+ span = (hi - lo).astype(np.float64)
75
+ return (lo + rd * span).astype(np.float32)
76
+
77
+
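`_latin_hypercube` places exactly one sample in each of the `n` equal-width strata along every axis and then shuffles the columns independently. The self-contained sketch below restates the idea and checks the stratification property; it uses a permutation index instead of an in-place shuffle, and the name `latin_hypercube` and the unit-cube bounds are illustrative choices.

```python
import numpy as np


def latin_hypercube(n, lo, hi, rng):
    """One sample per equal-width stratum along every axis, scaled to [lo, hi]."""
    lo = np.asarray(lo, dtype=np.float64)
    hi = np.asarray(hi, dtype=np.float64)
    d = lo.shape[0]
    edges = np.linspace(0.0, 1.0, n + 1)
    # Jitter each point inside its own stratum [i/n, (i+1)/n).
    unit = edges[:-1, None] + rng.random((n, d)) / n
    # Decouple the axes by permuting the strata independently per column.
    for j in range(d):
        unit[:, j] = unit[rng.permutation(n), j]
    return lo + unit * (hi - lo)


rng = np.random.default_rng(0)
pts = latin_hypercube(8, lo=[0.0, 0.0, 0.0], hi=[1.0, 1.0, 1.0], rng=rng)
# Stratum index of every sample, sorted per column.
strata = np.sort(np.floor(pts * 8).astype(int), axis=0)
```

Each column of `strata` lists every stratum index once, which is the property that distinguishes Latin hypercube sampling from plain uniform sampling.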
78
+ @torch.no_grad()
79
+ def generate_maps(
80
+ model: torch.nn.Module,
81
+ labels_np: np.ndarray,
82
+ label_mean: np.ndarray,
83
+ label_std: np.ndarray,
84
+ H: int,
85
+ W: int,
86
+ device: torch.device,
87
+ ddim_steps: int,
88
+ batch_size: int,
89
+ ) -> np.ndarray:
90
+ out = []
91
+ n = labels_np.shape[0]
92
+ for j0 in range(0, n, batch_size):
93
+ chunk = labels_np[j0 : j0 + batch_size]
94
+ bt = ec.prepare_labels_for_model(chunk.astype(np.float32), label_mean, label_std).to(device)
95
+ g = model.sample(
96
+ labels=bt,
97
+ channels=1,
98
+ height=H,
99
+ width=W,
100
+ device=device,
101
+ progress=False,
102
+ use_ddim=True,
103
+ ddim_steps=ddim_steps,
104
+ )
105
+ out.append(ec.from_model_output(g))
106
+ return np.concatenate(out, axis=0)
107
+
108
+
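`generate_maps` walks through the label array in chunks of `batch_size`, prepares each chunk for the model, samples, and concatenates the results. The sketch below reproduces only that control flow with a stand-in sampler, so it runs without PyTorch. It assumes `ec.prepare_labels_for_model` standardises labels as `(label - mean) / std`, which this hunk does not confirm; `generate_in_chunks` and `toy_sampler` are hypothetical names.

```python
import numpy as np


def generate_in_chunks(sample_fn, labels, label_mean, label_std, batch_size):
    """Call sample_fn on standardised label chunks and stack the outputs."""
    out = []
    for j0 in range(0, labels.shape[0], batch_size):
        chunk = labels[j0 : j0 + batch_size].astype(np.float32)
        # Assumed standardisation; the real helper may differ.
        z = (chunk - label_mean) / label_std
        out.append(sample_fn(z))
    return np.concatenate(out, axis=0)


calls = []


def toy_sampler(z, H=4, W=4):
    """Stand-in for model.sample: records the batch size, returns blank maps."""
    calls.append(z.shape[0])
    return np.zeros((z.shape[0], H, W), dtype=np.float32)


labels = np.random.default_rng(0).uniform(size=(10, 2))
maps = generate_in_chunks(
    toy_sampler, labels, labels.mean(0), labels.std(0), batch_size=4
)
```

The final chunk is simply shorter than `batch_size`, so no padding is needed and the output always has one map per input label.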
109
+ def load_model(bundle_args: Path, ckpt: Path, device: torch.device):
110
+ cfg = ec.load_training_config(str(bundle_args))
111
+ model = ec.build_model(cfg, device)
112
+ ec.load_checkpoint(model, str(ckpt), device)
113
+ model.eval()
114
+ return model, cfg
115
+
116
+
117
+ def free_torch():
118
+ gc.collect()
119
+ if torch.cuda.is_available():
120
+ torch.cuda.empty_cache()
121
+
122
+
123
+ def plot_training_overlay(
124
+ out_dir: Path,
125
+ slurm6: Path | None,
126
+ slurm2: Path | None,
127
+ ) -> None:
128
+ """Train + val curves for DDPM-6 and optionally DDPM-2 on one logarithmic-loss axis."""
129
+ fig, ax = plt.subplots(figsize=(9, 5))
130
+ plotted = False
131
+ if slurm6 and Path(slurm6).is_file():
132
+ ep, tr, va = parse_slurm_training_log(slurm6)
133
+ ax.plot(ep, tr, lw=1.4, ls="-", label="DDPM-6 train", color="#1f77b4", alpha=0.85)
134
+ ax.plot(ep, va, lw=1.8, ls="--", label="DDPM-6 val", color="#174a75", alpha=0.95)
135
+ plotted = True
136
+ else:
137
+ print("Warning: 6-param Slurm log not found; skipped overlay for DDPM-6.")
138
+
139
+ if slurm2 and Path(slurm2).is_file():
140
+ ep, tr, va = load_train_val_series(slurm2)
141
+ ax.plot(ep, tr, lw=1.4, ls="-", label="DDPM-2 train", color="#ff7f0e", alpha=0.85)
142
+ ax.plot(ep, va, lw=1.8, ls="--", label="DDPM-2 val", color="#994d00", alpha=0.95)
143
+ plotted = True
144
+ elif slurm2 is not None:
145
+ print(f"Warning: 2-param training series not found ({slurm2}); use --slurm-2param or restore bundled JSON.")
146
+
147
+ if not plotted:
148
+ print("No Slurm logs parsed — writing placeholder note instead of curves.")
149
+ ax.text(
150
+ 0.5,
151
+ 0.5,
152
+ "Pass --slurm-6param; DDPM-2 uses bundled JSON by default (--slurm-2param).",
153
+ ha="center",
154
+ va="center",
155
+ transform=ax.transAxes,
156
+ )
157
+ else:
158
+ ax.set_yscale("log")
159
+ ax.grid(True, alpha=0.3)
160
+ ax.set_xlabel("Epoch")
161
+ ax.set_ylabel("MSE diffusion loss")
162
+ ax.legend(loc="upper right", fontsize=8)
163
+ ax.set_title("Training / validation curves (combined)")
164
+ outp = out_dir / "comparison_training_train_val_overlay.png"
165
+ fig.savefig(outp, dpi=170, bbox_inches="tight")
166
+ plt.close(fig)
167
+ print("Saved", outp)
168
+
169
+
170
+ def run_random_theta_triplets(
171
+ out_dir: Path,
172
+ imgs6: np.ndarray,
173
+ lab6: np.ndarray,
174
+ mean6: np.ndarray,
175
+ std6: np.ndarray,
176
+ mean2: np.ndarray,
177
+ std2: np.ndarray,
178
+ model2,
179
+ model6,
180
+ device: torch.device,
181
+ ddim_steps: int,
182
+ seed: int,
183
+ n_pairs: int,
184
+ batch_size: int,
185
+ ):
+ """Random Latin-hypercube targets inside the CAMELS label bounding box; the CAMELS column shows the nearest-neighbour real map."""
187
+ rng = np.random.default_rng(seed)
188
+ lo, hi = lab6.min(0), lab6.max(0)
189
+ n_pairs = min(n_pairs, 32)  # cap here: the loop below indexes targets[i] for every i < n_pairs
+ targets = _latin_hypercube(n_pairs, lo, hi, rng)[:n_pairs]
190
+
191
+ H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
192
+ tg2 = targets[:, :2].astype(np.float32)
193
+
194
+ fig = plt.figure(figsize=(3.8 * max(3, n_pairs * 3), 4.1))
195
+ for i in range(n_pairs):
196
+ theta6 = targets[i].astype(np.float32)
197
+ theta2 = tg2[i]
198
+ dist = np.linalg.norm(lab6 - theta6[None, :], axis=1).astype(np.float64)
199
+ nn = int(np.argmin(dist))
200
+ nn_img = imgs6[nn]
201
+
202
+ gen2 = generate_maps(
203
+ model2, theta2[np.newaxis, :], mean2, std2, H, W, device, ddim_steps, batch_size
204
+ )
205
+ gen6 = generate_maps(
206
+ model6, theta6[np.newaxis, :], mean6, std6, H, W, device, ddim_steps, batch_size
207
+ )
208
+
209
+ titles = ("CAMELS (NN)", "DDPM-2", "DDPM-6")
210
+ for j, img in enumerate((nn_img, gen2[0], gen6[0])):
211
+ ax = fig.add_subplot(1, n_pairs * 3, i * 3 + j + 1)
212
+ ax.imshow(img, vmin=0, vmax=1, origin="lower", cmap="inferno")
213
+ ax.axis("off")
214
+ ax.set_title(titles[j], fontsize=8)
215
+
216
+ plt.suptitle(
217
+ "Random LHS cosmologies — CAMELS = nearest-neighbour truth | gens conditioned on LHS labels",
218
+ fontsize=10,
219
+ y=1.02,
220
+ )
221
+ p = out_dir / "comparison_random_lhs_triplets_camels_ddpm2_ddpm6.png"
222
+ plt.tight_layout()
223
+ plt.savefig(p, dpi=160, bbox_inches="tight")
224
+ plt.close(fig)
225
+ print("Saved", p)
226
+
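The triplet figure above draws its targets with `_latin_hypercube` and pairs each target with the closest real label by Euclidean distance. `_latin_hypercube` is not shown in this hunk, so the sketch below is only one plausible implementation; `latin_hypercube_sketch`, `nearest_label_index`, and the toy label table are hypothetical names introduced for illustration.

```python
import numpy as np


def latin_hypercube_sketch(n, lo, hi, rng):
    """Hypothetical stand-in for `_latin_hypercube`: one sample per stratum in each dimension."""
    lo = np.asarray(lo, dtype=np.float64)
    hi = np.asarray(hi, dtype=np.float64)
    d = lo.shape[0]
    # Jitter inside each of the n equal-width strata, then shuffle the strata per dimension.
    u = (np.arange(n)[:, None] + rng.random((n, d))) / n
    for j in range(d):
        u[:, j] = u[rng.permutation(n), j]
    return lo + u * (hi - lo)


def nearest_label_index(labels, theta):
    """Index of the row of `labels` closest to `theta` in Euclidean distance."""
    dist = np.linalg.norm(labels - theta[None, :], axis=1)
    return int(np.argmin(dist))


rng = np.random.default_rng(0)
labels = rng.uniform([0.1, 0.6], [0.5, 1.0], size=(100, 2))  # toy label table, not CAMELS data
targets = latin_hypercube_sketch(4, labels.min(0), labels.max(0), rng)
nn = [nearest_label_index(labels, t) for t in targets]
```

As in the script, distances here are measured in raw label units, so parameters with a wider numeric range dominate the neighbour choice; dividing by the per-column standard deviation first would be one way to balance them.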
227
+
228
+ def run_conditioned_test_triplets(
229
+ out_dir: Path,
230
+ imgs6: np.ndarray,
231
+ lab6: np.ndarray,
232
+ mean6: np.ndarray,
233
+ std6: np.ndarray,
234
+ mean2: np.ndarray,
235
+ std2: np.ndarray,
236
+ model2,
237
+ model6,
238
+ device: torch.device,
239
+ ddim_steps: int,
240
+ seed: int,
241
+ n_pairs: int,
242
+ batch_size: int,
243
+ ):
244
+ """Same rows from test split: conditioned on truth labels."""
245
+ rng = np.random.default_rng(seed + 1)
246
+ idx = rng.choice(len(imgs6), size=min(n_pairs, len(imgs6)), replace=False)
247
+ H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
248
+ fig, axes = plt.subplots(1, n_pairs * 3, figsize=(2.9 * n_pairs * 3, 3.8), squeeze=False)
249
+ for ii, ix in enumerate(idx):
250
+ tg6 = lab6[ix].astype(np.float32)
251
+ tg2 = tg6[:2]
252
+ rm = imgs6[ix]
253
+ g2 = generate_maps(model2, tg2[np.newaxis, :], mean2, std2, H, W, device, ddim_steps, batch_size)[0]
254
+ g6 = generate_maps(model6, tg6[np.newaxis, :], mean6, std6, H, W, device, ddim_steps, batch_size)[0]
255
+ for j, img in enumerate((rm, g2, g6)):
256
+ ax = axes[0, ii * 3 + j]
257
+ ax.imshow(img, vmin=0, vmax=1, origin="lower", cmap="inferno")
258
+ ax.axis("off")
259
+ if ii == 0:
260
+ ax.set_title(("CAMELS", "DDPM-2", "DDPM-6")[j], fontsize=8)
261
+ axes[0, ii * 3].set_xlabel(_fmt_title(tg6), fontsize=7)
262
+ plt.suptitle(f"Random test ix (conditioned on truth labels), n={len(idx)}", fontsize=10, y=1.06)
263
+ p = out_dir / "comparison_test_conditioned_camels_ddpm2_ddpm6.png"
264
+ plt.savefig(p, dpi=160, bbox_inches="tight")
265
+ plt.close(fig)
266
+ print("Saved", p)
267
+
268
+
269
+ def pk_pdf_six_sets(
270
+ out_dir: Path,
271
+ name: str,
272
+ images_split: np.ndarray,
273
+ labels_split: np.ndarray,
274
+ label_mean: np.ndarray,
275
+ label_std: np.ndarray,
276
+ model,
277
+ device: torch.device,
278
+ ddim_steps: int,
279
+ batch_size: int,
280
+ n_per_set: int,
281
+ ):
+ """Six anchor rows (evenly spaced indices in the split); compares n_per_set DDIM samples per anchor with its n_per_set nearest real maps."""
283
+ H, W = int(images_split.shape[-2]), int(images_split.shape[-1])
284
+ ldim = int(labels_split.shape[1])
285
+
286
+ idx = np.linspace(0, len(labels_split) - 1, num=6, dtype=int)
287
+ targets = labels_split[idx].copy()
288
+
289
+ box = 25.0
290
+ dk_ref = None
291
+ panels_pk = []
292
+
293
+ rng_pdf_bins = np.linspace(14.0, 22.0, 101)
294
+ bin_pdf = 0.5 * (rng_pdf_bins[:-1] + rng_pdf_bins[1:])
295
+
296
+ fig_pk, axes_pk = plt.subplots(2, 3, figsize=(14, 9), sharex=True, sharey=True)
297
+ axes_pk = axes_pk.ravel()
298
+
299
+ fig_pdf, axes_pdf = plt.subplots(6, 2, figsize=(12, 4.8 * 2), squeeze=False)
300
+
301
+ for si, target_l in enumerate(targets):
302
+ dist = np.linalg.norm(labels_split - target_l[None, :], axis=1).astype(np.float64)
303
+ ex = idx[si]
304
+ dist[ex] = np.inf  # exclude the anchor itself from its nearest neighbours
305
+ nn_idx = np.argsort(dist)[:n_per_set]
306
+ real_batch = images_split[nn_idx]
307
+ rep = np.tile(target_l[None, :], (n_per_set, 1))
308
+ gen = generate_maps(model, rep, label_mean, label_std, H, W, device, ddim_steps, batch_size)
309
+
310
+ dk_r, mr, sr = ec.calculate_power_spectrum_batch(real_batch, box_size=box)
311
+ dk_g, mg, sg = ec.calculate_power_spectrum_batch(gen, box_size=box)
312
+ dk_ref = dk_r
313
+ x = dk_ref[1:]
314
+ axpk = axes_pk[si]
315
+ axpk.plot(x, mr[1:], lw=2, label="CAMELS NN", color="#333")
316
+ axpk.fill_between(x, mr[1:] - sr[1:], mr[1:] + sr[1:], alpha=0.08, color="#333")
317
+ axpk.plot(x, mg[1:], lw=2, label=f"Generated (ldim={ldim})", color="#d95f02")
318
+ axpk.fill_between(x, mg[1:] - sg[1:], mg[1:] + sg[1:], alpha=0.08, color="#d95f02")
319
+ axpk.set_yscale("log")
320
+ axpk.grid(alpha=0.25)
321
+ axpk.set_title(_fmt_title(target_l), fontsize=8)
322
+ panels_pk.append((si, dk_r, mr, sr, mg, sg))
323
+
324
+ # PDF µ/σ
325
+ tb = []; rb = []
326
+ for i in range(n_per_set):
327
+ for arr, store in zip((real_batch, gen), (tb, rb)):
328
+ ims = np.clip(arr[i].ravel(), 0.0, 1.0)
329
+ logn = 14.0 + (22.0 - 14.0) * ims
330
+ hst, _ = np.histogram(logn, bins=rng_pdf_bins, density=True)
331
+ store.append(hst)
332
+ tb = np.asarray(tb); rb = np.asarray(rb)
333
+
334
+ axes_pdf[si, 0].plot(bin_pdf, tb.mean(axis=0), lw=2, label="CAMELS NN", color="#333")
335
+ axes_pdf[si, 0].plot(bin_pdf, rb.mean(axis=0), lw=2, label="Generated", color="#d95f02")
336
+ axes_pdf[si, 1].plot(bin_pdf, tb.std(axis=0), lw=2, ls="-", label="CAMELS σ", color="#333")
337
+ axes_pdf[si, 1].plot(bin_pdf, rb.std(axis=0), lw=2, ls="--", label="Gen σ", color="#d95f02")
338
+
339
+ axes_pk[0].legend(fontsize=7, loc="lower left")
340
+ fig_pk.suptitle(f"$P(k)$ mean±std — six anchors — {name}", fontsize=10)
341
+ fig_pk.tight_layout()
342
+ p_pk = out_dir / f"six_anchor_pk_{name}.png"
343
+ fig_pk.savefig(p_pk, dpi=160)
344
+ plt.close(fig_pk)
345
+
346
+ axes_pdf[-1, 0].set_xlabel(r"$\log N_{\mathrm{HI}}$")
347
+ axes_pdf[-1, 1].set_xlabel(r"$\log N_{\mathrm{HI}}$")
348
+ fig_pdf.suptitle(f"PDF mean & σ — six anchors × {n_per_set} — {name}")
349
+ fig_pdf.tight_layout()
350
+ p_pdf = out_dir / f"six_anchor_pdf_mu_sigma_{name}.png"
351
+ fig_pdf.savefig(p_pdf, dpi=160)
352
+ plt.close(fig_pdf)
353
+ print("Saved", p_pk)
354
+ print("Saved", p_pdf)
355
+
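The PDF panels above map each pixel from the [0, 1] range to log N_HI in [14, 22], histogram every map separately, and then take the mean and standard deviation across maps per bin. A minimal self-contained sketch of that reduction follows; the function name `mean_std_pdf` and the random toy maps are assumptions made for illustration.

```python
import numpy as np

LOG_NHI_MIN, LOG_NHI_MAX = 14.0, 22.0  # range hard-coded in the script above


def mean_std_pdf(maps01, n_bins=100):
    """Per-map density histograms of log N_HI, reduced to a mean and std per bin."""
    edges = np.linspace(LOG_NHI_MIN, LOG_NHI_MAX, n_bins + 1)
    centers = 0.5 * (edges[:-1] + edges[1:])
    rows = []
    for m in maps01:
        px = np.clip(np.asarray(m, dtype=np.float64).ravel(), 0.0, 1.0)
        log_nhi = LOG_NHI_MIN + (LOG_NHI_MAX - LOG_NHI_MIN) * px
        hist, _ = np.histogram(log_nhi, bins=edges, density=True)
        rows.append(hist)
    rows = np.asarray(rows)
    return centers, rows.mean(axis=0), rows.std(axis=0)


rng = np.random.default_rng(0)
toy_maps = rng.random((5, 16, 16))  # stand-in for normalised maps, not real data
centers, pdf_mean, pdf_std = mean_std_pdf(toy_maps)
```

Because every pixel is clipped into the histogram range, each per-map histogram integrates to one, and so does their mean.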
356
+
357
+ def pk_six_triplet_combined(
358
+ out_dir: Path,
359
+ imgs6: np.ndarray,
360
+ lab6: np.ndarray,
361
+ mean6: np.ndarray,
362
+ std6: np.ndarray,
363
+ mean2: np.ndarray,
364
+ std2: np.ndarray,
365
+ model2: torch.nn.Module,
366
+ model6: torch.nn.Module,
367
+ device: torch.device,
368
+ ddim_steps: int,
369
+ batch_size: int,
370
+ n_per_set: int,
371
+ ) -> None:
372
+ """Six anchors — mean P(k) for CAMELS vs DDPM-2 vs DDPM-6; analogous PDF overlays."""
373
+ H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
374
+ idx = np.linspace(0, len(lab6) - 1, num=6, dtype=int)
375
+ targets = lab6[idx].copy()
376
+ box = 25.0
377
+
378
+ fig_pk, axes_pk = plt.subplots(2, 3, figsize=(14, 9), sharex=True, sharey=True)
379
+ axes_pk = axes_pk.ravel()
380
+ rng_pdf_bins = np.linspace(14.0, 22.0, 101)
381
+ bin_pdf = 0.5 * (rng_pdf_bins[:-1] + rng_pdf_bins[1:])
382
+ fig_pdf, axes_pdf = plt.subplots(6, 2, figsize=(12, 10.5))
383
+
384
+ for si, target_l in enumerate(targets):
385
+ dist = np.linalg.norm(lab6 - target_l[None, :], axis=1).astype(np.float64)
386
+ ex = int(idx[si])
387
+ if ex < len(dist):
388
+ dist = dist.copy()
389
+ dist[ex] = np.inf
390
+ nn_idx = np.argsort(dist)[:n_per_set]
391
+ real_batch = imgs6[nn_idx]
392
+
393
+ tg2 = np.tile(target_l[:2][None, :], (n_per_set, 1)).astype(np.float32)
394
+ tg6 = np.tile(target_l[None, :], (n_per_set, 1)).astype(np.float32)
395
+
396
+ gen2 = generate_maps(model2, tg2, mean2, std2, H, W, device, ddim_steps, batch_size)
397
+ gen6 = generate_maps(model6, tg6, mean6, std6, H, W, device, ddim_steps, batch_size)
398
+
399
+ dk_r, mr, sr = ec.calculate_power_spectrum_batch(real_batch, box_size=box)
400
+ _, m2, s2 = ec.calculate_power_spectrum_batch(gen2, box_size=box)
401
+ _, mG, sG = ec.calculate_power_spectrum_batch(gen6, box_size=box)
402
+ x = dk_r[1:]
403
+
404
+ axpk = axes_pk[si]
405
+ axpk.plot(x, mr[1:], lw=2, label="CAMELS NN", color="#222")
406
+ axpk.fill_between(x, mr[1:] - sr[1:], mr[1:] + sr[1:], alpha=0.06, color="#222")
407
+ axpk.plot(x, m2[1:], lw=2, label="DDPM-2 μ", color="#ff7f0e")
408
+ axpk.fill_between(x, m2[1:] - s2[1:], m2[1:] + s2[1:], alpha=0.06, color="#ff7f0e")
409
+ axpk.plot(x, mG[1:], lw=2, label="DDPM-6 μ", color="#1f77b4")
410
+ axpk.fill_between(x, mG[1:] - sG[1:], mG[1:] + sG[1:], alpha=0.06, color="#1f77b4")
411
+ axpk.set_yscale("log")
412
+ axpk.grid(alpha=0.25)
413
+ axpk.set_title(_fmt_title(target_l), fontsize=8)
414
+ if si == 0:
415
+ axpk.legend(fontsize=6.2, loc="lower left")
416
+
417
+ pdf_rows_lists = []
418
+ for imgs in (real_batch, gen2, gen6):
419
+ hb = []
420
+ for i in range(min(n_per_set, len(imgs))):
421
+ px = np.clip(imgs[i].ravel(), 0.0, 1.0)
422
+ ln = 14.0 + (22.0 - 14.0) * px
423
+ hb.append(np.histogram(ln, bins=rng_pdf_bins, density=True)[0])
424
+ pdf_rows_lists.append(np.asarray(hb))
425
+
426
+ cam_pdf, d2_pdf, d6_pdf = pdf_rows_lists
427
+ axes_pdf[si, 0].plot(bin_pdf, cam_pdf.mean(axis=0), lw=2, color="#222", label="CAMELS μ")
428
+ axes_pdf[si, 0].plot(bin_pdf, d2_pdf.mean(axis=0), lw=2, color="#ff7f0e", label="DDPM-2 μ")
429
+ axes_pdf[si, 0].plot(bin_pdf, d6_pdf.mean(axis=0), lw=2, color="#1f77b4", label="DDPM-6 μ")
430
+ axes_pdf[si, 1].plot(bin_pdf, cam_pdf.std(axis=0), lw=2, color="#222")
431
+ axes_pdf[si, 1].plot(bin_pdf, d2_pdf.std(axis=0), lw=2, ls="--", color="#ff7f0e")
432
+ axes_pdf[si, 1].plot(bin_pdf, d6_pdf.std(axis=0), lw=2, ls="--", color="#1f77b4")
433
+
434
+ fig_pk.suptitle("$P(k)$ CAMELS vs DDPM-2 vs DDPM-6 — six Ωm–σ8 anchors", fontsize=11)
435
+ fig_pk.tight_layout()
436
+ p_pk = out_dir / "six_anchor_pk_overlay_camels_ddpm2_ddpm6.png"
437
+ fig_pk.savefig(p_pk, dpi=160)
438
+ plt.close(fig_pk)
439
+
440
+ axes_pdf[-1, 0].set_xlabel(r"$\log N_{\mathrm{HI}}$")
441
+ axes_pdf[-1, 1].set_xlabel(r"$\log N_{\mathrm{HI}}$")
442
+ fig_pdf.suptitle(r"PDF mean ($\mu$) and std ($\sigma$) overlays", fontsize=10)
443
+ fig_pdf.tight_layout()
444
+ p_pdf = out_dir / "six_anchor_pdf_overlay_camels_ddpm2_ddpm6.png"
445
+ fig_pdf.savefig(p_pdf, dpi=160)
446
+ plt.close(fig_pdf)
447
+ print("Saved", p_pk)
448
+ print("Saved", p_pdf)
449
+
450
+
451
+ def mlp_recovery_dual(
452
+ out_dir: Path,
453
+ data_train: Path,
454
+ imgs_te: np.ndarray,
455
+ lab_te: np.ndarray,
456
+ mean: np.ndarray,
457
+ std: np.ndarray,
458
+ model_ddpm: torch.nn.Module,
459
+ tag: str,
460
+ device: torch.device,
461
+ ddim_steps: int,
462
+ seed: int,
463
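+ ) -> None:
+ """Fit an MLP on CAMELS training P(k), then compare label recovery from real and DDPM-generated test spectra."""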
464
+ from sklearn.metrics import mean_squared_error
465
+ from sklearn.neural_network import MLPRegressor
466
+
467
+ ldim = lab_te.shape[1]
468
+ Npix = imgs_te.shape[-1]
469
+ dl = 25.0 / Npix
470
+
471
+ def pk_row(im):
472
+ _dk, pk = ec.PowerSpectrum(np.asarray(im, dtype=np.float64), N=Npix, dl=dl)
473
+ return pk[1:].astype(np.float32)
474
+
475
+ img_tr_np, lab_tr_np = ec.load_split(data_train, "train")
476
+ if len(img_tr_np) > 2000:
477
+ rng = np.random.default_rng(seed)
478
+ jj = rng.choice(len(img_tr_np), 2000, replace=False)
479
+ img_tr_np, lab_tr_np = img_tr_np[jj], lab_tr_np[jj]
480
+ X_train = np.stack([pk_row(img_tr_np[i]) for i in range(len(img_tr_np))], axis=0)
481
+ y_train = lab_tr_np.astype(np.float32)
482
+ mlp = MLPRegressor(
483
+ hidden_layer_sizes=(64, 64),
484
+ alpha=1e-4,
485
+ random_state=seed,
486
+ max_iter=250,
487
+ early_stopping=True,
488
+ validation_fraction=0.1,
489
+ )
490
+ mlp.fit(X_train, y_train)
491
+
492
+ n_ev = min(40, len(imgs_te))
493
+ eval_idx = np.arange(n_ev)
494
+ X_real = np.stack([pk_row(imgs_te[i]) for i in eval_idx], axis=0)
495
+ y_true = lab_te[eval_idx]
496
+
497
+ preds_real = mlp.predict(X_real)
498
+
499
+ gens = []
500
+ H, W = int(imgs_te.shape[-2]), int(imgs_te.shape[-1])
501
+ for i0 in range(0, n_ev, 8):
502
+ bs_chunk = min(8, n_ev - i0)
503
+ lbl = y_true[i0 : i0 + bs_chunk]
504
+ g = generate_maps(model_ddpm, lbl, mean, std, H, W, device, ddim_steps, bs_chunk)
505
+ gens.extend([pk_row(g[j]) for j in range(len(g))])
506
+ X_gen = np.stack(gens, axis=0)
507
+ preds_gen = mlp.predict(X_gen)
508
+
509
+ rmse_real = np.sqrt(mean_squared_error(y_true, preds_real, multioutput="raw_values"))
510
+ rmse_gen = np.sqrt(mean_squared_error(y_true, preds_gen, multioutput="raw_values"))
511
+
512
+ fig, axes = plt.subplots(2, ldim, figsize=(max(9.0, 2.8 * max(ldim, 2)), 4.9), squeeze=False)
513
+ if ldim == 1:
514
+ axes = np.reshape(axes, (2, 1))
515
+ for k in range(ldim):
516
+ for row, preds, rmv, ylab in (
517
+ (0, preds_real, rmse_real, "CAMELS P(k) predictions"),
518
+ (1, preds_gen, rmse_gen, f"{tag}: generated P(k)"),
519
+ ):
520
+ ax = axes[row, k]
521
+ lo = float(y_true[:, k].min()); hi = float(y_true[:, k].max())
522
+ pad = 0.03 * (hi - lo + 1e-12)
523
+ ax.scatter(y_true[:, k], preds[:, k], s=14, alpha=0.72, edgecolors="none", c="#333")
524
+ ax.plot([lo - pad, hi + pad], [lo - pad, hi + pad], color="crimson", lw=1.0)
525
+ ax.grid(True, alpha=0.28)
526
+ ax.set_title(f"dim {k} RMSE={float(rmv[k]):.4f}", fontsize=8)
527
+ if k == 0:
528
+ ax.set_ylabel(ylab, fontsize=8)
529
+
530
+ plt.suptitle(
531
+ "MLP: train on CAMELS train P(k), test on CAMELS vs DDPM-drawn spectra",
532
+ fontsize=10,
533
+ y=1.02,
534
+ )
535
+ plt.tight_layout()
536
+ p = out_dir / f"mlp_pk_parameter_recovery_{tag}.png"
537
+ plt.savefig(p, dpi=165, bbox_inches="tight")
538
+ plt.close(fig)
539
+ print("Saved", p)
540
+
541
+
542
+ def posterior_one_index(
543
+ out_dir: Path,
544
+ images_split: np.ndarray,
545
+ labels_split: np.ndarray,
546
+ lab_mean: np.ndarray,
547
+ lab_std: np.ndarray,
548
+ model,
549
+ cfg: Dict,
550
+ device,
551
+ ix: int,
552
+ tag: str,
553
+ ddim_steps: int,
554
+ grid: int,
555
+ batch_sz: int,
556
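+ ):
+ """Grid surrogate posterior over the first two labels for one test map, scoring DDIM samples by log P(k) mismatch; the remaining labels stay at the anchor's values."""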
557
+ normalize = bool(cfg.get("normalize_labels", True))
558
+ lab_dim = labels_split.shape[1]
559
+ H, W = int(images_split.shape[-2]), int(images_split.shape[-1])
560
+ obs = images_split[ix]
561
+ label_anchor_full = labels_split[ix].astype(np.float32)
562
+
563
+ lo0 = float(labels_split[:, 0].min())
564
+ hi0 = float(labels_split[:, 0].max())
565
+ lo1 = float(labels_split[:, 1].min())
566
+ hi1 = float(labels_split[:, 1].max())
567
+ pad0 = 0.02 * (hi0 - lo0 + 1e-12)
568
+ pad1 = 0.02 * (hi1 - lo1 + 1e-12)
569
+ om_ax, s8_ax, OG, SG, grid2 = build_cosmo_grid(
570
+ grid, lo0 - pad0, hi0 + pad0, lo1 - pad1, hi1 + pad1
571
+ )
572
+ g = grid
573
+ ngrid = grid2.shape[0]
574
+ npix = int(obs.shape[-1])
575
+ dl = 25.0 / npix
576
+ dk, _ = ec.PowerSpectrum(em.images01_to_log_nhi(obs), N=npix, dl=dl)
577
+ valid = dk > 0
578
+ log_pd = log_pk_observed(obs, 25.0, dk)
579
+
580
+ OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
581
+ full = np.tile(label_anchor_full[np.newaxis, :], (ngrid, 1))
582
+ full[:, 0] = grid2[:, 0].astype(np.float32)
583
+ full[:, 1] = grid2[:, 1].astype(np.float32)
584
+
585
+ def weights_full() -> np.ndarray:
586
+ scores = []
587
+ for j0 in range(0, ngrid, batch_sz):
588
+ chunk = full[j0 : j0 + batch_sz]
589
+ imgs = em.sample_batch(
590
+ model,
591
+ chunk,
592
+ lab_mean,
593
+ lab_std,
594
+ normalize,
595
+ H,
596
+ W,
597
+ device,
598
+ ddim_steps,
599
+ False,
600
+ )
601
+ _, pkc = em.per_map_power_spectra_log(imgs, 25.0)
602
+ log_pg = np.log(pkc[:, valid] + 1e-30)
603
+ mse = np.mean((log_pd[np.newaxis, :] - log_pg) ** 2, axis=1)
604
+ scores.append(-mse / (2.0 * 0.25**2))
605
+ sc = np.concatenate(scores)
606
+ sc -= sc.max()
607
+ w = np.exp(sc).reshape(g, g)
608
+ w /= w.sum()
609
+ return w
610
+
611
+ Wmap = weights_full()
612
+ tom, ts8 = float(label_anchor_full[0]), float(label_anchor_full[1])
613
+ mom = float((Wmap * OM).sum())
614
+ ms8 = float((Wmap * S8).sum())
615
+
616
+ fig, ax = plt.subplots(figsize=(5.2, 4.6))
617
+ cf = ax.contourf(OM, S8, Wmap, levels=12, cmap="Blues")
618
+ plt.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)
619
+ ax.scatter(tom, ts8, s=55, c="r", marker="x", zorder=6, label="true")
620
+ ax.scatter(mom, ms8, s=60, c="k", marker="+", zorder=6, label="post. mean")
621
+ ax.set_xlabel(r"$\Omega_m$")
622
+ ax.set_ylabel(r"$\sigma_8$")
623
+ ax.legend(fontsize=8)
624
+ ax.set_title(f"Surrogate posterior (test ix={ix}, ldim={lab_dim})", fontsize=10)
625
+ p = out_dir / f"posterior_surrogate_test_ix_{ix}_{tag}.png"
626
+ fig.savefig(p, dpi=160, bbox_inches="tight")
627
+ plt.close(fig)
628
+ print("Saved", p)
629
+
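The weighting inside `weights_full` can be isolated from the sampling: each grid point gets a Gaussian score from the mean squared difference in log P(k), and the scores are shifted by their maximum before exponentiating. The sketch below reproduces only that step on toy spectra; `grid_posterior` and the linear toy model are hypothetical, and `sigma=0.25` simply mirrors the constant hard-coded above.

```python
import numpy as np


def grid_posterior(log_pk_obs, log_pk_grid, sigma=0.25):
    """Normalised weights over grid points from a Gaussian score on log P(k).

    log_pk_obs:  shape (n_k,), observed log power spectrum
    log_pk_grid: shape (n_grid, n_k), one emulated spectrum per grid point
    sigma:       assumed scatter in log P(k)
    """
    mse = np.mean((log_pk_obs[None, :] - log_pk_grid) ** 2, axis=1)
    score = -mse / (2.0 * sigma**2)
    score -= score.max()  # shift so the largest term is exp(0); avoids underflow to all zeros
    w = np.exp(score)
    return w / w.sum()


# Toy check: spectra depend linearly on a 1D parameter and the truth is theta = 0.3.
theta = np.linspace(0.0, 1.0, 11)
k = np.linspace(0.1, 1.0, 8)
log_pk_grid = theta[:, None] * k[None, :]
log_pk_obs = 0.3 * k
w = grid_posterior(log_pk_obs, log_pk_grid)
post_mean = float((w * theta).sum())
```

The score averages over k bins instead of summing them, as the script does, so `sigma` acts more like a tuning width than a calibrated per-bin error; that is why the result is described as a surrogate posterior.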
630
+
631
+ def main(argv: Sequence[str] | None = None) -> None:
632
+ p = argparse.ArgumentParser(description="DDPM-2 vs DDPM-6 comparison suite.")
633
+ p.add_argument(
634
+ "--output-dir",
635
+ type=Path,
636
+ default=MODELS_ROOT / "ddpm_comparison_out",
637
+ )
638
+ p.add_argument("--data-2param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"))
639
+ p.add_argument("--data-6param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"))
640
+ p.add_argument(
641
+ "--bundle-2param",
642
+ type=Path,
643
+ default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
644
+ )
645
+ p.add_argument(
646
+ "--bundle-6param",
647
+ type=Path,
648
+ default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
649
+ )
650
+ p.add_argument("--posterior-index", type=int, default=56)
651
+ p.add_argument("--lhs-n", type=int, default=50)
652
+ p.add_argument("--six-n-per-anchor", type=int, default=15)
653
+ p.add_argument("--ddim-steps", type=int, default=50)
654
+ p.add_argument("--seed", type=int, default=42)
655
+ p.add_argument("--batch-size", type=int, default=8)
656
+ p.add_argument("--slurm-6param", type=Path, default=DEFAULT_SLURM_6)
657
+ p.add_argument(
658
+ "--slurm-2param",
659
+ type=Path,
660
+ default=DEFAULT_DDPM2_TRAINING,
661
+ help="DDPM-2 train/val series: Slurm .out (parsed) or bundled ddpm_2param_training_loss.json.",
662
+ )
663
+ p.add_argument("--skip-lhs-r2", action="store_true", help="LHS R² plots are expensive; skip if set.")
664
+ p.add_argument("--n-random-triplets", type=int, default=4)
665
+ args = p.parse_args(list(argv) if argv is not None else None)
666
+
667
+ out_dir = Path(args.output_dir).resolve()
668
+ out_dir.mkdir(parents=True, exist_ok=True)
669
+
670
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
671
+ print("device:", device)
672
+
673
+ ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
674
+ args2 = args.bundle_2param / "args.json"
675
+ ck6 = args.bundle_6param / "best_model.pt"
676
+ args6 = args.bundle_6param / "args.json"
677
+
678
+ data2 = Path(args.data_2param)
679
+ data6 = Path(args.data_6param)
680
+
681
+ imgs6, lab6 = ec.load_split(data6, "test")
682
+ mean6, std6 = ec.load_label_stats(data6)
683
+ mean2, std2 = ec.load_label_stats(data2)
684
+
685
+ plot_training_overlay(out_dir, args.slurm_6param, args.slurm_2param)
686
+
687
+ imgs2, lab2 = ec.load_split(data2, "test")
688
+
689
+ print(">>> Loading DDPM-2...")
690
+ model2, cfg2 = load_model(args2, ck2, device)
691
+ print(">>> Loading DDPM-6...")
692
+ model6, cfg6 = load_model(args6, ck6, device)
693
+
694
+ print(">>> Random LHS + conditioned triplets...")
695
+ try:
696
+ run_random_theta_triplets(
697
+ out_dir,
698
+ imgs6,
699
+ lab6,
700
+ mean6,
701
+ std6,
702
+ mean2,
703
+ std2,
704
+ model2,
705
+ model6,
706
+ device=device,
707
+ ddim_steps=args.ddim_steps,
708
+ seed=args.seed,
709
+ n_pairs=args.n_random_triplets,
710
+ batch_size=args.batch_size,
711
+ )
712
+ run_conditioned_test_triplets(
713
+ out_dir,
714
+ imgs6,
715
+ lab6,
716
+ mean6,
717
+ std6,
718
+ mean2,
719
+ std2,
720
+ model2,
721
+ model6,
722
+ device=device,
723
+ ddim_steps=args.ddim_steps,
724
+ seed=args.seed,
725
+ n_pairs=args.n_random_triplets,
726
+ batch_size=args.batch_size,
727
+ )
728
+ except Exception as exc:
729
+ print("Triplet grids failed:", exc)
730
+
731
+ print(">>> Six-anchor overlays (combined + per-model)...")
732
+ try:
733
+ pk_six_triplet_combined(
734
+ out_dir,
735
+ imgs6,
736
+ lab6,
737
+ mean6,
738
+ std6,
739
+ mean2,
740
+ std2,
741
+ model2,
742
+ model6,
743
+ device=device,
744
+ ddim_steps=args.ddim_steps,
745
+ batch_size=args.batch_size,
746
+ n_per_set=args.six_n_per_anchor,
747
+ )
748
+ pk_pdf_six_sets(
749
+ out_dir,
750
+ "ddpm6_only",
751
+ imgs6,
752
+ lab6,
753
+ mean6,
754
+ std6,
755
+ model6,
756
+ device,
757
+ args.ddim_steps,
758
+ args.batch_size,
759
+ args.six_n_per_anchor,
760
+ )
761
+ pk_pdf_six_sets(
762
+ out_dir,
763
+ "ddpm2_only",
764
+ imgs2,
765
+ lab2,
766
+ mean2,
767
+ std2,
768
+ model2,
769
+ device,
770
+ args.ddim_steps,
771
+ args.batch_size,
772
+ args.six_n_per_anchor,
773
+ )
774
+ except Exception as exc:
775
+ print("P(k)/PDF six-anchor plots failed:", exc)
776
+
777
+ if not args.skip_lhs_r2:
778
+ print(f">>> LHS R² (LHS-{args.lhs_n} × 15 DDIM each, long)...")
779
+ try:
780
+ for label, imgs, labs, mn, sd, mdl in (
781
+ ("ddpm2_lhs50", imgs2, lab2, mean2, std2, model2),
782
+ ("ddpm6_lhs50", imgs6, lab6, mean6, std6, model6),
783
+ ):
784
+ lhs_pts, r2_mu, r2_sig, lo_b, hi_b = compute_lhs_r2(
785
+ mdl,
786
+ imgs,
787
+ labs,
788
+ mn,
789
+ sd,
790
+ device,
791
+ args.lhs_n,
792
+ 15,
793
+ args.batch_size,
794
+ 25.0,
795
+ args.ddim_steps,
796
+ args.seed,
797
+ )
798
+ outp = out_dir / f"r2_cosmology_lhs{args.lhs_n}_{label}.png"
799
+ plot_r2_cosmology_figure(lhs_pts, r2_mu, r2_sig, lo_b, hi_b, outp, dpi=160)
800
+ print("Saved", outp)
801
+ np.savez(
802
+ out_dir / f"r2_lhs_data_{label}.npz",
803
+ lhs_pts=lhs_pts,
804
+ r2_mu_arr=r2_mu,
805
+ r2_sig_arr=r2_sig,
806
+ lo_b=lo_b,
807
+ hi_b=hi_b,
808
+ )
809
+ except Exception as exc:
810
+ print("LHS R² failed:", exc)
811
+ else:
812
+ print("(Skipping LHS R².)")
813
+
814
+ print(">>> MLP P(k) parameter recovery...")
815
+ try:
816
+ mlp_recovery_dual(
817
+ out_dir, data2, imgs2[:40], lab2[:40], mean2, std2, model2, "ddpm2param", device, args.ddim_steps, args.seed
818
+ )
819
+ mlp_recovery_dual(
820
+ out_dir, data6, imgs6[:40], lab6[:40], mean6, std6, model6, "ddpm6param", device, args.ddim_steps, args.seed
821
+ )
822
+ except Exception as exc:
823
+ print("MLP recovery failed:", exc)
824
+
825
+ print(f">>> Surrogate posteriors (test index {args.posterior_index})...")
826
+ try:
827
+ ix = int(args.posterior_index)
828
+ posterior_one_index(
829
+ out_dir, imgs6, lab6, mean6, std6, model6, cfg6, device, ix, "ddpm6", args.ddim_steps, 14, args.batch_size
830
+ )
831
+ posterior_one_index(
832
+ out_dir,
833
+ imgs2,
834
+ lab2,
835
+ mean2,
836
+ std2,
837
+ model2,
838
+ cfg2,
839
+ device,
840
+ ix,
841
+ "ddpm2",
842
+ args.ddim_steps,
843
+ 14,
844
+ args.batch_size,
845
+ )
846
+ except Exception as exc:
847
+ print("Posterior panels failed:", exc)
848
+
849
+ del model2, model6
850
+ free_torch()
851
+ print(f"Done. Outputs in {out_dir}")
852
+
853
+
854
+ if __name__ == "__main__":
855
+ main()
cross_model/scripts/compare_ddpm_training_curves.py ADDED
@@ -0,0 +1,45 @@
+ #!/usr/bin/env python3
+ """Parse DDPM Slurm stdout or bundled JSON for Train/Val loss series."""
+
+ from __future__ import annotations
+
+ import json
+ import re
+ from pathlib import Path
+ from typing import Tuple
+
+ _ROW = re.compile(
+     r"Epoch\s+(?P<ep>\d+)/\d+\s+\|\s+Train:\s+(?P<tr>[\d.eE+-]+)\s+\|\s+Val:\s+(?P<va>[\d.eE+-]+)",
+ )
+
+
+ def parse_slurm_training_log(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
+     """Return (epochs, train_losses, val_losses) parsed from Slurm *.out stdout."""
+     p = Path(path)
+     text = p.read_text(encoding="utf-8", errors="replace")
+     epochs, trains, vals = [], [], []
+     for m in _ROW.finditer(text):
+         epochs.append(int(m.group("ep")))
+         trains.append(float(m.group("tr")))
+         vals.append(float(m.group("va")))
+     return epochs, trains, vals
+
+
+ def load_training_loss_json(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
+     """Return (epochs, train_losses, val_losses) from a JSON export (keys: epochs, train, val)."""
+     p = Path(path)
+     raw = json.loads(p.read_text(encoding="utf-8"))
+     epochs = [int(e) for e in raw["epochs"]]
+     trains = [float(x) for x in raw["train"]]
+     vals = [float(x) for x in raw["val"]]
+     if not (len(epochs) == len(trains) == len(vals)):
+         raise ValueError(f"{p}: mismatched lengths in epochs/train/val")
+     return epochs, trains, vals
+
+
+ def load_train_val_series(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
+     """Slurm *.out or *.json with the same semantic output as ``parse_slurm_training_log``."""
+     p = Path(path)
+     if p.suffix.lower() == ".json":
+         return load_training_loss_json(p)
+     return parse_slurm_training_log(p)
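To see what the `_ROW` pattern accepts, it helps to run it on a few lines. The log excerpt below is hypothetical and was written to match the pattern; the real Slurm output of the training job is not part of this diff and may be formatted differently.

```python
import re

ROW = re.compile(
    r"Epoch\s+(?P<ep>\d+)/\d+\s+\|\s+Train:\s+(?P<tr>[\d.eE+-]+)\s+\|\s+Val:\s+(?P<va>[\d.eE+-]+)"
)

# Hypothetical log excerpt; only lines shaped like "Epoch i/N | Train: x | Val: y" match.
sample = """\
Epoch 1/200 | Train: 0.1234 | Val: 0.1100
some unrelated line
Epoch 2/200 | Train: 9.8e-02 | Val: 1.05e-01
"""

rows = [(int(m["ep"]), float(m["tr"]), float(m["va"])) for m in ROW.finditer(sample)]
print(rows)
```

The character class `[\d.eE+-]+` covers plain and scientific notation but not `nan` or `inf`, so an epoch that logged a non-finite loss would be dropped from the series without any warning.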
cross_model/scripts/ddpm_figure6_integration.py ADDED
@@ -0,0 +1,271 @@
1
+ """
2
+ Figure 6 style (arXiv:2409.09101) helpers for DDPM surrogate posteriors — use with ddpm_posterior_six_anchors / run_ddpm_figure6_suite.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from pathlib import Path
8
+
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ from matplotlib import gridspec
12
+
13
+ from figure6_2409_style import (
14
+ create_comparison_marginal_vs_profile,
15
+ create_figure6_style_plot,
16
+ )
17
+ from sigma_contour_utils import compute_sigma_levels
18
+
19
+
20
+ def integrate_figure6_with_ddpm2(
21
+ Wmap: np.ndarray,
22
+ om_grid: np.ndarray,
23
+ s8_grid: np.ndarray,
24
+ true_om: float,
25
+ true_s8: float,
26
+ test_index: int,
27
+ output_dir: Path,
28
+ model_name: str = "DDPM-2",
29
+ ) -> None:
+ """Single-map Figure 6 style: saves marginal, profile, and marginal-vs-profile comparison PNGs."""
31
+ output_dir = Path(output_dir)
32
+ output_dir.mkdir(parents=True, exist_ok=True)
33
+ nm = model_name.replace(" ", "-").lower()
34
+
35
+ fig_marginal = create_figure6_style_plot(
36
+ Wmap,
37
+ om_grid,
38
+ s8_grid,
39
+ true_param1=true_om,
40
+ true_param2=true_s8,
41
+ param1_label=r"$\Omega_m$",
42
+ param2_label=r"$\sigma_8$",
43
+ title=f"{model_name} — Test ix={test_index} (Marginal)",
44
+ show_profile=False,
45
+ figsize=(10, 10),
46
+ )
47
+
48
+ save_path_marginal = output_dir / f"fig6_style_{nm}_ix{test_index}_marginal.png"
49
+ fig_marginal.savefig(save_path_marginal, dpi=200, bbox_inches="tight")
50
+ plt.close(fig_marginal)
51
+ print(f"Saved: {save_path_marginal}")
52
+
53
+ fig_profile = create_figure6_style_plot(
54
+ Wmap,
55
+ om_grid,
56
+ s8_grid,
57
+ true_param1=true_om,
58
+ true_param2=true_s8,
59
+ param1_label=r"$\Omega_m$",
60
+ param2_label=r"$\sigma_8$",
61
+ title=f"{model_name} — Test ix={test_index} (Profile)",
62
+ show_profile=True,
63
+ figsize=(10, 10),
64
+ )
65
+
66
+ save_path_profile = output_dir / f"fig6_style_{nm}_ix{test_index}_profile.png"
67
+ fig_profile.savefig(save_path_profile, dpi=200, bbox_inches="tight")
68
+ plt.close(fig_profile)
69
+ print(f"Saved: {save_path_profile}")
70
+
71
+ fig_cmp = create_comparison_marginal_vs_profile(
72
+ Wmap,
73
+ om_grid,
74
+ s8_grid,
75
+ true_param1=true_om,
76
+ true_param2=true_s8,
77
+ title=f"{model_name} marginal vs profile — ix={test_index}",
78
+ figsize=(11, 4.2),
79
+ )
80
+ cmp_path = output_dir / f"fig6_marg_vs_prof_{nm}_ix{test_index}.png"
81
+ fig_cmp.savefig(cmp_path, dpi=185, bbox_inches="tight")
82
+ plt.close(fig_cmp)
83
+ print(f"Saved: {cmp_path}")
84
+
85
+
86
+ def integrate_figure6_with_multi_anchor(
87
+ posteriors_list: list[np.ndarray],
88
+ om_grid: np.ndarray,
89
+ s8_grid: np.ndarray,
90
+ true_values_list: list[tuple[float, float]],
91
+ test_indices: list[int],
92
+ output_dir: Path,
93
+ model_name: str = "DDPM-2",
94
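+ ) -> None:
+ """2×3 grid of Figure 6 style panels: 2D posterior with top and side marginals for each anchor."""
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)  # savefig below fails if the directory is missing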
97
+ nm = model_name.replace(" ", "-").lower()
98
+ fig = plt.figure(figsize=(20, 14))
99
+ gs_o = gridspec.GridSpec(2, 3, figure=fig, hspace=0.33, wspace=0.32)
100
+
101
+ for idx, (posterior, true_vals, test_ix) in enumerate(
102
+ zip(posteriors_list, true_values_list, test_indices)
103
+ ):
104
+ true_om, true_s8 = true_vals
105
+ row, col = divmod(idx, 3)
106
+
107
+ posterior_norm = posterior / posterior.sum()
108
+ sigma_levels = compute_sigma_levels(posterior_norm, [0.683, 0.954])
109
+ P1, P2 = np.meshgrid(om_grid, s8_grid, indexing="ij")
110
+
111
+ gs_sub = gridspec.GridSpecFromSubplotSpec(
112
+ 2,
113
+ 2,
114
+ subplot_spec=gs_o[row, col],
115
+ width_ratios=[4, 1],
116
+ height_ratios=[1, 4],
117
+ hspace=0.06,
118
+ wspace=0.06,
119
+ )
120
+ ax_main = fig.add_subplot(gs_sub[1, 0])
121
+ ax_top = fig.add_subplot(gs_sub[0, 0], sharex=ax_main)
122
+
123
+ ax_main.contourf(P1, P2, posterior_norm, levels=20, cmap="Blues", alpha=0.85)
124
+ if len(set(sigma_levels)) >= 1:
125
+ ax_main.contour(
126
+ P1,
127
+ P2,
128
+ posterior_norm,
129
+ levels=sigma_levels,
130
+ colors=["darkblue", "steelblue"],
131
+ linewidths=[2.0, 1.5],
132
+ )
133
+
134
+ ax_main.scatter(true_om, true_s8, s=100, c="red", marker="x", linewidths=2.5, zorder=10)
135
+ ax_main.set_xlim(om_grid[0], om_grid[-1])
136
+ ax_main.set_ylim(s8_grid[0], s8_grid[-1])
137
+ ax_main.set_xlabel(r"$\Omega_m$" if row == 1 else "", fontsize=11)
138
+ ax_main.set_ylabel(r"$\sigma_8$" if col == 0 else "", fontsize=11)
139
+ ax_main.set_title(f"Test ix={test_ix}", fontsize=11, pad=5)
140
+ ax_main.grid(True, alpha=0.2)
141
+
142
+ marginal_om = posterior_norm.sum(axis=1)
143
+ marginal_om /= marginal_om.sum() + 1e-30
144
+ ax_top.fill_between(
145
+ om_grid,
146
+ 0.0,
147
+ marginal_om,
148
+ alpha=0.6,
149
+ color="steelblue",
150
+ edgecolor="steelblue",
151
+ )
152
+ ax_top.axvline(true_om, color="red", linestyle="--", linewidth=2)
153
+ ax_top.set_xlim(om_grid[0], om_grid[-1])
154
+ ax_top.set_ylim(0, marginal_om.max() * 1.1)
155
+ ax_top.tick_params(labelbottom=False, labelsize=9)
156
+ ax_top.set_ylabel("$P(\\Omega_m)$", fontsize=9)
157
+ ax_top.grid(True, alpha=0.2)
158
+
159
+ ax_side = fig.add_subplot(gs_sub[1, 1], sharey=ax_main)
160
+ marginal_s8 = posterior_norm.sum(axis=0)
161
+ marginal_s8 /= marginal_s8.sum() + 1e-30
162
+ ax_side.fill_betweenx(s8_grid, 0.0, marginal_s8, alpha=0.6, color="steelblue", edgecolor="steelblue")
163
+ ax_side.axhline(true_s8, color="red", linestyle="--", linewidth=2)
164
+ ax_side.set_ylim(s8_grid[0], s8_grid[-1])
165
+ ax_side.set_xlim(0, marginal_s8.max() * 1.15)
166
+ ax_side.tick_params(labelleft=False)
167
+
168
+ fig.suptitle(
169
+ f"{model_name} — Figure 6 Style: Six Test Anchors",
170
+ fontsize=15,
171
+ y=0.995,
172
+ fontweight="bold",
173
+ )
174
+
175
+ save_path = output_dir / f"fig6_style_{nm}_all_anchors.png"
176
+ fig.savefig(save_path, dpi=200, bbox_inches="tight")
177
+ plt.close(fig)
178
+ print(f"Saved multi-anchor grid: {save_path}")
179
+
180
+
181
+ def integrate_figure6_model_comparison(
182
+ posteriors_dict: dict[str, np.ndarray],
183
+ om_grid: np.ndarray,
184
+ s8_grid: np.ndarray,
185
+ true_om: float,
186
+ true_s8: float,
187
+ test_index: int,
188
+ output_dir: Path,
189
+ ) -> None:
190
+ """Side-by-side model comparison panels."""
191
+ output_dir = Path(output_dir)
192
+ n_models = len(posteriors_dict)
193
+ fig = plt.figure(figsize=(8 * max(1, min(n_models, 4)), 8))
194
+ gs_outer = gridspec.GridSpec(1, n_models, figure=fig, wspace=0.32)
195
+
196
+ for idx, (model_name, posterior) in enumerate(posteriors_dict.items()):
197
+ gs_sub = gridspec.GridSpecFromSubplotSpec(
198
+ 2,
199
+ 2,
200
+ subplot_spec=gs_outer[0, idx],
201
+ width_ratios=[4, 1],
202
+ height_ratios=[1, 4],
203
+ hspace=0.06,
204
+ wspace=0.06,
205
+ )
206
+ posterior_norm = posterior / posterior.sum()
207
+ sigma_levels = compute_sigma_levels(posterior_norm, [0.683, 0.954])
208
+ P1, P2 = np.meshgrid(om_grid, s8_grid, indexing="ij")
209
+
210
+ ax_main = fig.add_subplot(gs_sub[1, 0])
211
+ ax_top = fig.add_subplot(gs_sub[0, 0], sharex=ax_main)
212
+
213
+ ax_main.contourf(P1, P2, posterior_norm, levels=20, cmap="Blues", alpha=0.85)
214
+ if len(set(sigma_levels)) >= 2:
215
+ ax_main.contour(
216
+ P1,
217
+ P2,
218
+ posterior_norm,
219
+ levels=sigma_levels,
220
+ colors=["darkblue", "steelblue"],
221
+ linewidths=[2.5, 2.0],
222
+ )
223
+
224
+ ax_main.scatter(true_om, true_s8, s=120, c="red", marker="x", linewidths=3, zorder=10)
225
+ ax_main.set_xlabel(r"$\Omega_m$", fontsize=13)
226
+ ax_main.set_ylabel(r"$\sigma_8$", fontsize=13)
227
+ ax_main.set_title(model_name, fontsize=13, pad=10, fontweight="bold")
228
+ ax_main.grid(True, alpha=0.3)
229
+ ax_main.set_xlim(om_grid[0], om_grid[-1])
230
+ ax_main.set_ylim(s8_grid[0], s8_grid[-1])
231
+
232
+ marginal = posterior_norm.sum(axis=1)
233
+ marginal /= marginal.sum() + 1e-30
234
+ ax_top.fill_between(om_grid, 0.0, marginal, alpha=0.6, color="steelblue")
235
+ ax_top.axvline(true_om, color="red", linestyle="--", linewidth=2.5)
236
+ ax_top.set_xlim(om_grid[0], om_grid[-1])
237
+ ax_top.set_ylim(0, marginal.max() * 1.12)
238
+ ax_top.tick_params(labelbottom=False)
239
+ ax_top.grid(True, alpha=0.25)
240
+
241
+ ax_sb = fig.add_subplot(gs_sub[1, 1], sharey=ax_main)
242
+ marginal_s = posterior_norm.sum(axis=0)
243
+ marginal_s /= marginal_s.sum() + 1e-30
244
+ ax_sb.fill_betweenx(s8_grid, 0.0, marginal_s, alpha=0.6, color="steelblue")
245
+ ax_sb.axhline(true_s8, color="red", linestyle="--", linewidth=2.5)
246
+ ax_sb.set_ylim(s8_grid[0], s8_grid[-1])
247
+
248
+ fig.suptitle(
249
+ f"Model Comparison (Figure 6 Style) — Test ix={test_index}",
250
+ fontsize=15,
251
+ y=0.995,
252
+ fontweight="bold",
253
+ )
254
+
255
+ save_path = output_dir / f"fig6_style_model_comparison_ix{test_index}.png"
256
+ fig.savefig(save_path, dpi=200, bbox_inches="tight")
257
+ plt.close(fig)
258
+ print(f"Saved model comparison: {save_path}")
259
+
260
+
261
+ def print_integration_guide() -> None:
262
+ example_integration = """
263
+ # Add imports next to posterior code:
264
+ from figure6_2409_style import create_figure6_style_plot
265
+ from ddpm_figure6_integration import (
266
+ integrate_figure6_with_ddpm2,
267
+ integrate_figure6_with_multi_anchor,
268
+ integrate_figure6_model_comparison,
269
+ )
270
+ """
271
+ print(example_integration.strip())
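Note on the contour levels: the plotting helpers in this file call `compute_sigma_levels` from `sigma_contour_utils`, which is not shown in this part of the diff. The block below is only a sketch of one common way to obtain such levels (highest-posterior-density thresholds on a normalized grid). The name `hpd_levels` is hypothetical and the repository's own implementation may differ.

```python
import numpy as np


def hpd_levels(posterior, fractions=(0.683, 0.954)):
    """Density thresholds whose super-level sets enclose the given mass fractions.

    Hypothetical helper for illustration only. Returns strictly increasing
    levels, which is what matplotlib's ``contour(levels=...)`` expects.
    """
    p = np.asarray(posterior, dtype=np.float64)
    p = p / p.sum()
    # Sort cell masses from highest to lowest and accumulate them.
    flat = np.sort(p.ravel())[::-1]
    cum = np.cumsum(flat)
    levels = []
    for frac in fractions:
        # First cell at which the accumulated mass reaches ``frac``.
        k = min(int(np.searchsorted(cum, frac)), flat.size - 1)
        levels.append(float(flat[k]))
    # Deduplicate and sort ascending so contour() accepts the result.
    return sorted(set(levels))


# Toy check on a separable Gaussian grid (synthetic, not CAMELS data).
x = np.linspace(-4.0, 4.0, 81)
g = np.exp(-0.5 * x**2)
toy = np.outer(g, g)
toy /= toy.sum()
levels = hpd_levels(toy)
mass_68 = toy[toy >= levels[-1]].sum()  # mass inside the 68.3% contour
```

When the posterior collapses onto a single grid cell, both fractions map to the same threshold and this sketch returns one level, which is why a caller passing two colours to `contour` would want to check the number of distinct levels first.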
cross_model/scripts/ddpm_posterior_six_anchors.py ADDED
@@ -0,0 +1,451 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Surrogate P(k) likelihood posteriors on ($\\Omega_m$, $\\sigma_8$) for six test anchors.
4
+
5
+ For each model:
6
+ • DDPM-2: standard 2D sweep over ($\\Omega_m$, $\\sigma_8$), the only two labels this model is conditioned on.
7
+ • DDPM-6: same 2D sweep, with the extra (astrophysical) label dimensions 2–5 held fixed in two cases:
8
+ - **extra_lower**: each of dims 2–5 fixed to the LHS **minimum** (from training labels)
9
+ - **extra_upper**: each fixed to the LHS **maximum**
10
+
11
+ The observed HI map is always the CAMELS test MAP at that anchor index.
12
+
13
+ This does not import ``compare_ddpm_models.py``; it only shares the same conventions and paths.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import argparse
19
+ import gc
20
+ import sys
21
+ from pathlib import Path
22
+ from typing import Dict, Tuple
23
+
24
+ import matplotlib
25
+
26
+ matplotlib.use("Agg")
27
+ import matplotlib.pyplot as plt
28
+ import numpy as np
29
+ import torch
30
+
31
+ MODELS_ROOT = Path(__file__).resolve().parents[1]
32
+ CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
33
+ if str(CODE_6.resolve()) not in sys.path:
34
+ sys.path.insert(0, str(CODE_6.resolve()))
35
+
36
+ import evaluate_conditional as ec # noqa: E402
37
+ import eval_model as em # noqa: E402
38
+ from figure9_posterior import build_cosmo_grid, log_pk_observed # noqa: E402
39
+
40
+
41
+ def _fmt_title(lab: np.ndarray) -> str:
42
+ t = np.asarray(lab, dtype=float).ravel()
43
+ if t.size <= 2:
44
+ return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f}"
45
+ tail = ", ".join(f"{float(v):.3g}" for v in t[2:])
46
+ return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f} | " + tail
47
+
48
+
49
+ def _train_label_path(data_dir: Path) -> Path:
50
+ for name in ("train_labels_LH.npy", "train_labels_LH_2.npy"):
51
+ p = data_dir / name
52
+ if p.is_file():
53
+ return p
54
+ raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}")
55
+
56
+
57
+ def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
58
+ """Per-column min and max of the training LHS labels over the four extra dimensions (column indices 2–5)."""
59
+ L = np.load(_train_label_path(data_dir))
60
+ if L.shape[1] < 6:
61
+ raise ValueError(f"Expected ≥6 label columns, got {L.shape}")
62
+ lo = L[:, 2:6].min(axis=0).astype(np.float32)
63
+ hi = L[:, 2:6].max(axis=0).astype(np.float32)
64
+ return lo, hi
65
+
66
+
67
+ def posterior_weights(
68
+ obs: np.ndarray,
69
+ full: np.ndarray,
70
+ om_ax: np.ndarray,
71
+ s8_ax: np.ndarray,
72
+ lab_mean: np.ndarray,
73
+ lab_std: np.ndarray,
74
+ normalize: bool,
75
+ model: torch.nn.Module,
76
+ *,
77
+ H: int,
78
+ W: int,
79
+ device: torch.device,
80
+ grid: int,
81
+ batch_sz: int,
82
+ ddim_steps: int,
83
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
84
+ """
85
+ Returns (Wmap, OM, S8) with Wmap shaped (grid, grid); OM, S8 meshgrids (indexing='ij').
86
+ full: (ngrid, label_dim) rows on the (Ωm, σ8) grid plus any fixed tail dims.
87
+ """
88
+ ngrid = full.shape[0]
89
+ g = int(round(np.sqrt(ngrid)))
90
+ if g * g != ngrid:
91
+ raise ValueError(f"Expected square grid, got ngrid={ngrid}")
92
+
93
+ npix = int(obs.shape[-1])
94
+ dl = 25.0 / npix
95
+ dk, _ = ec.PowerSpectrum(em.images01_to_log_nhi(obs), N=npix, dl=dl)
96
+ valid = dk > 0
97
+ log_pd = log_pk_observed(obs, 25.0, dk)
98
+
99
+ def weights_full() -> np.ndarray:
100
+ scores = []
101
+ for j0 in range(0, ngrid, batch_sz):
102
+ chunk = full[j0 : j0 + batch_sz]
103
+ imgs = em.sample_batch(
104
+ model,
105
+ chunk,
106
+ lab_mean,
107
+ lab_std,
108
+ normalize,
109
+ H,
110
+ W,
111
+ device,
112
+ ddim_steps,
113
+ False,
114
+ )
115
+ _, pkc = em.per_map_power_spectra_log(imgs, 25.0)
116
+ log_pg = np.log(pkc[:, valid] + 1e-30)
117
+ mse = np.mean((log_pd[np.newaxis, :] - log_pg) ** 2, axis=1)
118
+ scores.append(-mse / (2.0 * 0.25**2))
119
+ sc = np.concatenate(scores)
120
+ sc -= sc.max()
121
+ w = np.exp(sc).reshape(g, g)
122
+ w /= w.sum()
123
+ return w
124
+
125
+ Wmap = weights_full()
126
+ OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
127
+ return Wmap, OM, S8
128
+
129
+
130
+ def build_full_grid_2d(
131
+ labels_split: np.ndarray,
132
+ grid: int,
133
+ tail: np.ndarray | None,
134
+ lab_dim: int,
135
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
136
+ """
137
+ If tail is None (2-param model path): full has shape (grid^2, 2).
138
+ Else tail shape (4,): dims 2–5 filled with constants; dims 0,1 swept.
139
+ """
140
+ lo0 = float(labels_split[:, 0].min())
141
+ hi0 = float(labels_split[:, 0].max())
142
+ lo1 = float(labels_split[:, 1].min())
143
+ hi1 = float(labels_split[:, 1].max())
144
+ pad0 = 0.02 * (hi0 - lo0 + 1e-12)
145
+ pad1 = 0.02 * (hi1 - lo1 + 1e-12)
146
+ om_ax, s8_ax, OG, SG, grid2 = build_cosmo_grid(
147
+ grid, lo0 - pad0, hi0 + pad0, lo1 - pad1, hi1 + pad1
148
+ )
149
+ ngrid = grid2.shape[0]
150
+ if lab_dim == 2 and tail is not None:
151
+ raise ValueError("lab_dim==2 implies no extra tail.")
152
+ out = np.zeros((ngrid, lab_dim), dtype=np.float32)
153
+ out[:, 0] = grid2[:, 0].astype(np.float32)
154
+ out[:, 1] = grid2[:, 1].astype(np.float32)
155
+ if tail is not None:
156
+ assert tail.shape == (4,)
157
+ out[:, 2:6] = tail[np.newaxis, :]
158
+ return out.astype(np.float32), om_ax, s8_ax
159
+
160
+
161
+ def plot_posterior_panel(
162
+ ax,
163
+ Wmap: np.ndarray,
164
+ OM: np.ndarray,
165
+ S8: np.ndarray,
166
+ tom: float,
167
+ ts8: float,
168
+ title: str,
169
+ *,
170
+ suptext: str | None = None,
171
+ ) -> None:
172
+ mom = float((Wmap * OM).sum())
173
+ ms8 = float((Wmap * S8).sum())
174
+ cf = ax.contourf(OM, S8, Wmap, levels=12, cmap="Blues")
175
+ plt.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)
176
+ ax.scatter(tom, ts8, s=55, c="r", marker="x", zorder=6, label="true")
177
+ ax.scatter(mom, ms8, s=60, c="k", marker="+", zorder=6, label="post. mean")
178
+ ax.set_xlabel(r"$\Omega_m$")
179
+ ax.set_ylabel(r"$\sigma_8$")
180
+ ax.legend(fontsize=7)
181
+ ax.set_title(title, fontsize=8)
182
+ if suptext:
183
+ ax.text(0.02, 0.98, suptext, transform=ax.transAxes, fontsize=7, va="top", color="#333")
184
+
185
+
186
+ def run_ddpm2_panels(
187
+ out_dir: Path,
188
+ images: np.ndarray,
189
+ labels: np.ndarray,
190
+ mean: np.ndarray,
191
+ std: np.ndarray,
192
+ cfg: Dict,
193
+ model: torch.nn.Module,
194
+ device: torch.device,
195
+ anchor_ix: np.ndarray,
196
+ grid: int,
197
+ ddim_steps: int,
198
+ batch_sz: int,
199
+ ) -> None:
200
+ normalize = bool(cfg.get("normalize_labels", True))
201
+ H, W = int(images.shape[-2]), int(images.shape[-1])
202
+ fig, axes = plt.subplots(2, 3, figsize=(14, 9), squeeze=False)
203
+ for k, ix in enumerate(anchor_ix.ravel()):
204
+ r, c = divmod(k, 3)
205
+ ax = axes[r, c]
206
+ obs = images[ix]
207
+ lab_t = labels[ix].astype(np.float32)
208
+ full, om_ax, s8_ax = build_full_grid_2d(labels, grid, tail=None, lab_dim=2)
209
+ Wmap, OM, S8 = posterior_weights(
210
+ obs,
211
+ full,
212
+ om_ax,
213
+ s8_ax,
214
+ mean,
215
+ std,
216
+ normalize,
217
+ model,
218
+ H=H,
219
+ W=W,
220
+ device=device,
221
+ grid=grid,
222
+ batch_sz=batch_sz,
223
+ ddim_steps=ddim_steps,
224
+ )
225
+ tom, ts8 = float(lab_t[0]), float(lab_t[1])
226
+ plot_posterior_panel(
227
+ ax,
228
+ Wmap,
229
+ OM,
230
+ S8,
231
+ tom,
232
+ ts8,
233
+ f"test ix={ix}\n{_fmt_title(lab_t)}",
234
+ )
235
+ plt.suptitle(
236
+ r"DDPM-2 surrogate posterior on $(\Omega_m,\,\sigma_8)$ — six CAMELS anchors",
237
+ fontsize=11,
238
+ y=0.995,
239
+ )
240
+ plt.tight_layout(rect=(0, 0, 1, 0.97))
241
+ p = out_dir / "posterior_six_anchors_ddpm2.png"
242
+ fig.savefig(p, dpi=170, bbox_inches="tight")
243
+ plt.close(fig)
244
+ print("Saved", p)
245
+
246
+
247
+ def run_ddpm6_case(
248
+ out_dir: Path,
249
+ *,
250
+ suffix: str,
251
+ tail_fixed: np.ndarray,
252
+ tail_name: str,
253
+ images: np.ndarray,
254
+ labels: np.ndarray,
255
+ mean: np.ndarray,
256
+ std: np.ndarray,
257
+ cfg: Dict,
258
+ model: torch.nn.Module,
259
+ device: torch.device,
260
+ anchor_ix: np.ndarray,
261
+ grid: int,
262
+ ddim_steps: int,
263
+ batch_sz: int,
264
+ ) -> None:
265
+ normalize = bool(cfg.get("normalize_labels", True))
266
+ H, W = int(images.shape[-2]), int(images.shape[-1])
267
+ fig, axes = plt.subplots(2, 3, figsize=(14, 9), squeeze=False)
268
+ for k, ix in enumerate(anchor_ix.ravel()):
269
+ r, c = divmod(k, 3)
270
+ ax = axes[r, c]
271
+ obs = images[ix]
272
+ lab_t = labels[ix].astype(np.float32)
273
+ full, om_ax, s8_ax = build_full_grid_2d(labels, grid, tail=tail_fixed, lab_dim=6)
274
+ Wmap, OM, S8 = posterior_weights(
275
+ obs,
276
+ full,
277
+ om_ax,
278
+ s8_ax,
279
+ mean,
280
+ std,
281
+ normalize,
282
+ model,
283
+ H=H,
284
+ W=W,
285
+ device=device,
286
+ grid=grid,
287
+ batch_sz=batch_sz,
288
+ ddim_steps=ddim_steps,
289
+ )
290
+ tom, ts8 = float(lab_t[0]), float(lab_t[1])
291
+ plot_posterior_panel(
292
+ ax,
293
+ Wmap,
294
+ OM,
295
+ S8,
296
+ tom,
297
+ ts8,
298
+ f"test ix={ix}",
299
+ suptext=tail_name,
300
+ )
301
+ plt.suptitle(
302
+ r"DDPM-6 — $(\Omega_m,\,\sigma_8)$ sweep; dims 2–5 fixed (" + tail_name + ")",
303
+ fontsize=11,
304
+ y=0.995,
305
+ )
306
+ plt.tight_layout(rect=(0, 0, 1, 0.96))
307
+ p = out_dir / f"posterior_six_anchors_ddpm6_{suffix}.png"
308
+ fig.savefig(p, dpi=170, bbox_inches="tight")
309
+ plt.close(fig)
310
+ print("Saved", p)
311
+
312
+
313
+ def load_model(bundle_args: Path, ckpt: Path, device: torch.device):
314
+ cfg = ec.load_training_config(str(bundle_args))
315
+ model = ec.build_model(cfg, device)
316
+ ec.load_checkpoint(model, str(ckpt), device)
317
+ model.eval()
318
+ return model, cfg
319
+
320
+
321
+ def main() -> None:
322
+ p = argparse.ArgumentParser(
323
+ description="Six-anchor surrogate posteriors: DDPM-2 and DDPM-6 (extra dims min vs max)."
324
+ )
325
+ p.add_argument("--output-dir", type=Path, default=MODELS_ROOT / "ddpm_posterior_six_anchors_out")
326
+ p.add_argument("--data-2param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"))
327
+ p.add_argument("--data-6param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"))
328
+ p.add_argument(
329
+ "--bundle-2param",
330
+ type=Path,
331
+ default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
332
+ )
333
+ p.add_argument(
334
+ "--bundle-6param",
335
+ type=Path,
336
+ default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
337
+ )
338
+ p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
339
+ p.add_argument("--grid", type=int, default=14, help="Grid points per Ωm–σ8 axis.")
340
+ p.add_argument("--ddim-steps", type=int, default=50)
341
+ p.add_argument("--batch-size", type=int, default=8)
342
+ p.add_argument(
343
+ "--ddpm2-only",
344
+ action="store_true",
345
+ help="Only compute DDPM-2 figure (skip loading DDPM-6).",
346
+ )
347
+ p.add_argument(
348
+ "--ddpm6-only",
349
+ action="store_true",
350
+ help="Only compute DDPM-6 figures (skip loading DDPM-2).",
351
+ )
352
+ args = p.parse_args()
353
+
354
+ out_dir = Path(args.output_dir).resolve()
355
+ out_dir.mkdir(parents=True, exist_ok=True)
356
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
357
+ print("device:", device)
358
+
359
+ data2 = Path(args.data_2param)
360
+ data6 = Path(args.data_6param)
361
+
362
+ imgs2, lab2 = ec.load_split(data2, args.split)
363
+ imgs6, lab6 = ec.load_split(data6, args.split)
364
+
365
+ n = min(len(lab2), len(lab6))
366
+ anchor_ix = np.linspace(0, n - 1, num=6, dtype=int)
367
+
368
+ low_tail, hi_tail = tail_lhs_bounds(data6)
369
+ print("LHS tails (dims 2–5): min", low_tail, "max", hi_tail)
370
+
371
+ ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
372
+ args_json_2 = args.bundle_2param / "args.json"
373
+ ck6 = args.bundle_6param / "best_model.pt"
374
+ args_json_6 = args.bundle_6param / "args.json"
375
+
376
+ mean2, std2 = ec.load_label_stats(data2)
377
+ mean6, std6 = ec.load_label_stats(data6)
378
+
379
+ if args.ddpm6_only and args.ddpm2_only:
380
+ raise SystemExit("Use at most one of --ddpm2-only / --ddpm6-only.")
381
+
382
+ if not args.ddpm6_only:
383
+ print(">>> DDPM-2 (six anchors)...")
384
+ model2, cfg2 = load_model(args_json_2, ck2, device)
385
+ run_ddpm2_panels(
386
+ out_dir,
387
+ imgs2,
388
+ lab2,
389
+ mean2,
390
+ std2,
391
+ cfg2,
392
+ model2,
393
+ device,
394
+ anchor_ix,
395
+ args.grid,
396
+ args.ddim_steps,
397
+ args.batch_size,
398
+ )
399
+ del model2
400
+ gc.collect()
401
+ if torch.cuda.is_available():
402
+ torch.cuda.empty_cache()
403
+
404
+ if not args.ddpm2_only:
405
+ print(">>> DDPM-6 — extra dims at LHS minima (six anchors)...")
406
+ model6, cfg6 = load_model(args_json_6, ck6, device)
407
+ run_ddpm6_case(
408
+ out_dir,
409
+ suffix="extra_lower",
410
+ tail_fixed=low_tail,
411
+ tail_name="min",
412
+ images=imgs6,
413
+ labels=lab6,
414
+ mean=mean6,
415
+ std=std6,
416
+ cfg=cfg6,
417
+ model=model6,
418
+ device=device,
419
+ anchor_ix=anchor_ix,
420
+ grid=args.grid,
421
+ ddim_steps=args.ddim_steps,
422
+ batch_sz=args.batch_size,
423
+ )
424
+ print(">>> DDPM-6 — extra dims at LHS maxima (six anchors)...")
425
+ run_ddpm6_case(
426
+ out_dir,
427
+ suffix="extra_upper",
428
+ tail_fixed=hi_tail,
429
+ tail_name="max",
430
+ images=imgs6,
431
+ labels=lab6,
432
+ mean=mean6,
433
+ std=std6,
434
+ cfg=cfg6,
435
+ model=model6,
436
+ device=device,
437
+ anchor_ix=anchor_ix,
438
+ grid=args.grid,
439
+ ddim_steps=args.ddim_steps,
440
+ batch_sz=args.batch_size,
441
+ )
442
+ del model6
443
+ gc.collect()
444
+ if torch.cuda.is_available():
445
+ torch.cuda.empty_cache()
446
+
447
+ print(f"Done. Outputs in {out_dir}")
448
+
449
+
450
+ if __name__ == "__main__":
451
+ main()
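Note on the weighting step: the core of `posterior_weights` (mean squared difference of log P(k), scaled by a fixed width, then exponentiated and normalized) can be isolated as below. This is a sketch on synthetic arrays, not on DDPM samples. The function name `weights_from_log_pk` and its `sigma` argument are introduced here for illustration; the script itself hard-codes the width 0.25.

```python
import numpy as np


def weights_from_log_pk(log_pk_obs, log_pk_gen, sigma=0.25):
    """Normalized weights from a Gaussian-style score on log P(k).

    log_pk_obs: (n_k,) observed log power spectrum.
    log_pk_gen: (n_grid, n_k) log power spectra of generated maps.
    """
    mse = np.mean((log_pk_obs[np.newaxis, :] - log_pk_gen) ** 2, axis=1)
    score = -mse / (2.0 * sigma**2)
    # Subtracting the maximum keeps exp() from underflowing for every point.
    score = score - score.max()
    w = np.exp(score)
    return w / w.sum()


# Synthetic example: grid point 2 reproduces the observation exactly.
rng = np.random.default_rng(0)
log_pk_obs = rng.normal(size=16)
offsets = np.array([1.0, 0.5, 0.0, 0.5, 1.0])
log_pk_gen = log_pk_obs[np.newaxis, :] + offsets[:, np.newaxis]
w = weights_from_log_pk(log_pk_obs, log_pk_gen)
```

Because the score averages squared differences over k bins instead of using a covariance-weighted chi-square, the resulting map is best read as a surrogate posterior, as the script's docstring already calls it.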
cross_model/scripts/ddpm_triangle_integration.py ADDED
@@ -0,0 +1,194 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Surrogate posterior on $(\\Omega_m, \\sigma_8)$ → triangle/MCMC-style chains for one test map.
4
+
5
+ Reuses the surrogate likelihood from ``ddpm_posterior_six_anchors``, draws ``--n-hist`` joint
6
+ $(\\Omega_m,\\sigma_8)$ pairs (with replacement) from the discrete posterior masses, and writes them to ``.npz``.
7
+
8
+ DDPM-2: sweeps $(\\Omega_m,\\sigma_8)$.
9
+ DDPM-6: dims 2–5 fixed per ``--six-tail-mode`` (``truth`` uses the test-map labels 2–5; ``min``/``max``
10
+ use LHS extrema from training labels).
11
+
12
+ If you replace this file with a local copy, keep its argparse interface compatible or wrap it.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ _SCRIPTS = Path(__file__).resolve().parent
25
+ if str(_SCRIPTS) not in sys.path:
26
+ sys.path.insert(0, str(_SCRIPTS))
27
+
28
+ import ddpm_posterior_six_anchors as dps # noqa: E402
29
+
30
+ MODELS_ROOT = Path(__file__).resolve().parents[1]
31
+ CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
32
+ if str(CODE_6.resolve()) not in sys.path:
33
+ sys.path.insert(0, str(CODE_6.resolve()))
34
+
35
+ import evaluate_conditional as ec # noqa: E402
36
+
37
+
38
+ def _tail_vec(
39
+ mode: str,
40
+ lab_full: np.ndarray,
41
+ data6: Path,
42
+ ) -> np.ndarray | None:
43
+ if lab_full.size <= 2:
44
+ return None
45
+ if mode == "truth":
46
+ return lab_full[2:6].astype(np.float32)
47
+ low, hi = dps.tail_lhs_bounds(data6)
48
+ if mode == "min":
49
+ return low
50
+ if mode == "max":
51
+ return hi
52
+ raise ValueError("six-tail-mode must be truth|min|max")
53
+
54
+
55
+ def main() -> None:
56
+ p = argparse.ArgumentParser(description="DDPM surrogate posterior → resampled Ωm σ8 chains (.npz).")
57
+ p.add_argument(
58
+ "--label-dim",
59
+ type=int,
60
+ choices=[2, 6],
61
+ required=True,
62
+ help="Which model to use.",
63
+ )
64
+ p.add_argument(
65
+ "--bundle",
66
+ type=Path,
67
+ default=None,
68
+ help="Checkpoint bundle dir with args.json (default: notebook_model_weights/2param_epoch200 or notebook_model_weights/6param_best, by --label-dim).",
69
+ )
70
+ p.add_argument(
71
+ "--checkpoint-name",
72
+ type=str,
73
+ default=None,
74
+ help="Checkpoint file under bundle (defaults: checkpoint_epoch_200.pt for DDPM-2, best_model.pt for DDPM-6).",
75
+ )
76
+ p.add_argument(
77
+ "--data-dir",
78
+ type=Path,
79
+ default=None,
80
+ help="LH data dir matching label_dim (default: params_2 vs params_6).",
81
+ )
82
+ p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
83
+ p.add_argument("--test-index", type=int, default=56, help="Index into split for CAMELS observation.")
84
+ p.add_argument("--grid", type=int, default=14)
85
+ p.add_argument("--ddim-steps", type=int, default=50)
86
+ p.add_argument("--batch-size", type=int, default=8)
87
+ p.add_argument(
88
+ "--n-hist",
89
+ type=int,
90
+ default=10_000,
91
+ help="Resampled posterior pairs (with replacement).",
92
+ )
93
+ p.add_argument(
94
+ "--six-tail-mode",
95
+ type=str,
96
+ default="truth",
97
+ choices=["truth", "min", "max"],
98
+ help="Applies only to label_dim==6 — how dims 2–5 are fixed.",
99
+ )
100
+ p.add_argument(
101
+ "--output",
102
+ "-o",
103
+ type=Path,
104
+ required=True,
105
+ help="Output .npz path.",
106
+ )
107
+ p.add_argument("--seed", type=int, default=42)
108
+ args = p.parse_args()
109
+
110
+ ld = args.label_dim
111
+ if ld == 2:
112
+ data_dir = args.data_dir or Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2")
113
+ bundle = args.bundle or MODELS_ROOT / "notebook_model_weights" / "2param_epoch200"
114
+ ck_name = args.checkpoint_name or "checkpoint_epoch_200.pt"
115
+ else:
116
+ data_dir = args.data_dir or Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6")
117
+ bundle = args.bundle or MODELS_ROOT / "notebook_model_weights" / "6param_best"
118
+ ck_name = args.checkpoint_name or "best_model.pt"
119
+
120
+ rng = np.random.default_rng(args.seed)
121
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
122
+ imgs, labs = ec.load_split(data_dir, args.split)
123
+ ix = int(args.test_index)
124
+ if not (0 <= ix < len(labs)):
125
+ raise SystemExit(f"test-index {ix} out of range for split ({len(labs)} rows)")
126
+ lab_t = labs[ix].astype(np.float64)
127
+ obs = imgs[ix]
128
+ ckpt = bundle / ck_name
129
+ args_json = bundle / "args.json"
130
+ mean, std = ec.load_label_stats(data_dir)
131
+
132
+ tail = None
133
+ if ld == 6:
134
+ lab6 = labs[ix].astype(np.float64)
135
+ if lab6.shape[0] != 6:
136
+ raise SystemExit("--label-dim 6 requires labels with 6 columns in data-dir")
137
+ tail = _tail_vec(args.six_tail_mode, lab6, Path(data_dir))
138
+ model, cfg = dps.load_model(args_json, ckpt, device)
139
+ normalize = bool(cfg.get("normalize_labels", True))
140
+ H = int(obs.shape[-2])
141
+ W = int(obs.shape[-1])
142
+ gsz = args.grid
143
+
144
+ full, om_ax, s8_ax = dps.build_full_grid_2d(labs, gsz, tail=tail, lab_dim=ld)
145
+ Wmap, OM, S8 = dps.posterior_weights(
146
+ obs,
147
+ full,
148
+ om_ax,
149
+ s8_ax,
150
+ mean,
151
+ std,
152
+ normalize,
153
+ model,
154
+ H=H,
155
+ W=W,
156
+ device=device,
157
+ grid=gsz,
158
+ batch_sz=args.batch_size,
159
+ ddim_steps=args.ddim_steps,
160
+ )
161
+
162
+ wflat = np.clip(Wmap.ravel().astype(np.float64), 0.0, None)
163
+ if wflat.sum() <= 0:
164
+ raise RuntimeError("Posterior masses collapsed to zero.")
165
+ wflat /= wflat.sum()
166
+ omapflat = OM.ravel()
167
+ s8flat = S8.ravel()
168
+ draws = rng.choice(np.arange(len(wflat)), size=args.n_hist, replace=True, p=wflat)
169
+ samp_om = omapflat[draws].astype(np.float64)
170
+ samp_s8 = s8flat[draws].astype(np.float64)
171
+
172
+ out = Path(args.output).resolve()
173
+ out.parent.mkdir(parents=True, exist_ok=True)
174
+ tag = f"ddpm{ld}_{args.six_tail_mode}" if ld == 6 else "ddpm2"
175
+ np.savez_compressed(
176
+ out,
177
+ omega_m=samp_om,
178
+ sigma_8=samp_s8,
179
+ samples=np.column_stack([samp_om, samp_s8]),
180
+ truth_Omega_m=float(lab_t[0]),
181
+ truth_sigma_8=float(lab_t[1]),
182
+ posterior_map=Wmap,
183
+ OM=OM,
184
+ S8=S8,
185
+ index=np.array(ix, dtype=np.int32),
186
+ label_dim=np.array(ld, dtype=np.int16),
187
+ meta_tag=np.array(tag, dtype="U128"),
188
+ six_tail_mode=np.array(args.six_tail_mode if ld == 6 else "", dtype="U16"),
189
+ )
190
+ print("Saved", out, "pairs:", args.n_hist, "device:", device)
191
+
192
+
193
+ if __name__ == "__main__":
194
+ main()
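Note on the resampling step: drawing pairs from the gridded posterior amounts to a weighted choice over flattened grid cells. The sketch below reproduces that idea on a toy 3x3 grid. `resample_grid` is a name introduced here for illustration, and the toy axis values are arbitrary.

```python
import numpy as np


def resample_grid(wmap, om_ax, s8_ax, n_draws, seed=42):
    """Draw (Omega_m, sigma_8) pairs from a gridded posterior, with replacement."""
    w = np.clip(np.asarray(wmap, dtype=np.float64).ravel(), 0.0, None)
    if w.sum() <= 0:
        raise ValueError("posterior masses sum to zero")
    w /= w.sum()
    # indexing="ij" makes axis 0 follow om_ax and axis 1 follow s8_ax.
    om, s8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
    rng = np.random.default_rng(seed)
    idx = rng.choice(w.size, size=n_draws, replace=True, p=w)
    return np.column_stack([om.ravel()[idx], s8.ravel()[idx]])


# Toy posterior: 75% of the mass on the centre cell, 25% on one neighbour.
om_ax = np.array([0.2, 0.3, 0.4])
s8_ax = np.array([0.7, 0.8, 0.9])
wmap = np.zeros((3, 3))
wmap[1, 1] = 0.75
wmap[2, 1] = 0.25
samples = resample_grid(wmap, om_ax, s8_ax, n_draws=4000)
```

The draws only ever take values on grid nodes, so any histogram or triangle plot built from them is discretized at the grid spacing set by `--grid`.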
cross_model/scripts/figure6_2409_style.py ADDED
@@ -0,0 +1,157 @@
1
+ """
2
+ Figure-6-inspired layout for 2-parameter posteriors (arXiv:2409.09101 style):
3
+ main 2D panel with 1D curves on the adjacent edges, either marginals (sums, the default) or profiles (max).
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Tuple
9
+
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ from matplotlib import gridspec as mgs
13
+
14
+ from sigma_contour_utils import compute_sigma_levels
15
+
16
+
17
+ def create_figure6_style_plot(
18
+ Wmap: np.ndarray,
19
+ om_ax: np.ndarray,
20
+ s8_ax: np.ndarray,
21
+ *,
22
+ true_param1: float,
23
+ true_param2: float,
24
+ param1_label: str = r"$\Omega_m$",
25
+ param2_label: str = r"$\sigma_8$",
26
+ title: str = "",
27
+ show_profile: bool = False,
28
+ figsize: Tuple[float, float] = (10, 10),
29
+ ):
30
+ """
31
+ Parameters
32
+ ----------
33
+ Wmap : (G, G) posterior masses on grid (same layout as DDPM OM meshgrid with indexing='ij').
34
+ om_ax, s8_ax : 1-D grids aligned with axes 0 and 1 of ``Wmap``.
35
+ show_profile :
36
+ False → 1D curves are sums (*marginal*) over the other parameter.
37
+ True → 1D curves are maxima (*profile*) over the other parameter. Both kinds are scaled to a peak of 1.
38
+ """
39
+ p = np.asarray(Wmap, dtype=np.float64)
40
+ p = p / (p.sum() + 1e-30)
41
+ P1, P2 = np.meshgrid(om_ax, s8_ax, indexing="ij")
42
+
43
+ if show_profile:
44
+ m1 = np.max(p, axis=1)
45
+ m2 = np.max(p, axis=0)
46
+ else:
47
+ m1 = p.sum(axis=1)
48
+ m2 = p.sum(axis=0)
49
+ m1 = np.asarray(m1, dtype=np.float64)
50
+ m2 = np.asarray(m2, dtype=np.float64)
51
+ m1 /= m1.max() + 1e-30
52
+ m2 /= m2.max() + 1e-30
53
+
54
+ fig = plt.figure(figsize=figsize)
55
+ gs = mgs.GridSpec(
56
+ nrows=2,
57
+ ncols=2,
58
+ figure=fig,
59
+ width_ratios=[4.0, 1.05],
60
+ height_ratios=[1.05, 4.0],
61
+ wspace=0.035,
62
+ hspace=0.035,
63
+ left=0.12,
64
+ right=0.98,
65
+ bottom=0.1,
66
+ top=0.92,
67
+ )
68
+ ax_main = fig.add_subplot(gs[1, 0])
69
+ ax_top = fig.add_subplot(gs[0, 0], sharex=ax_main)
70
+ ax_r = fig.add_subplot(gs[1, 1], sharey=ax_main)
71
+ ax_empty = fig.add_subplot(gs[0, 1])
72
+ ax_empty.axis("off")
73
+
74
+ lvl = compute_sigma_levels(p, [0.683, 0.954])
75
+ ax_main.contourf(P1, P2, p, levels=20, cmap="Blues", alpha=0.88)
76
+ if len(set(lvl)) >= 2:
77
+ ax_main.contour(P1, P2, p, levels=lvl, colors=["darkblue", "steelblue"], linewidths=[2.0, 1.5])
78
+ ax_main.scatter(
79
+ true_param1,
80
+ true_param2,
81
+ s=120,
82
+ c="red",
83
+ marker="x",
84
+ linewidths=2.8,
85
+ zorder=15,
86
+ label="true",
87
+ )
88
+ ax_main.set_xlabel(param1_label, fontsize=13)
89
+ ax_main.set_ylabel(param2_label, fontsize=13)
90
+ ax_main.grid(True, alpha=0.28)
91
+ ax_main.legend(fontsize=8, loc="upper right")
92
+
93
+ ax_top.fill_between(om_ax, 0.0, m1, alpha=0.62, color="steelblue")
94
+ ax_top.axvline(true_param1, color="red", ls="--", lw=2.0)
95
+ ax_top.set_ylim(0.0, float(np.max(m1) * 1.12))
96
+ ax_top.set_ylabel("$P(\\mathrm{prof.})$" if show_profile else "$P(\\mathrm{margin.})$", fontsize=10)
97
+ ax_top.tick_params(labelbottom=False)
98
+ ax_top.grid(True, alpha=0.25)
99
+
100
+ ax_r.fill_betweenx(s8_ax, 0.0, m2, alpha=0.62, color="steelblue")
101
+ ax_r.axhline(true_param2, color="red", ls="--", lw=2.0)
102
+ ax_r.set_xlim(0.0, float(np.max(m2) * 1.12))
103
+ ax_r.set_xlabel("$P$", fontsize=10)
104
+ ax_r.tick_params(labelleft=False)
105
+ ax_r.grid(True, alpha=0.25)
106
+
107
+ kind = "Profile" if show_profile else "Marginal"
108
+ fig.suptitle(f"{title} ({kind})", fontsize=14, fontweight="bold", y=0.98)
109
+
110
+ plt.setp(ax_top.get_xticklabels(), visible=False)
111
+
112
+ return fig
113
+
114
+
115
+ def create_comparison_marginal_vs_profile(
116
+ Wmap: np.ndarray,
117
+ om_ax: np.ndarray,
118
+ s8_ax: np.ndarray,
119
+ *,
120
+ true_param1: float,
121
+ true_param2: float,
122
+ param1_label: str = r"$\Omega_m$",
123
+ param2_label: str = r"$\sigma_8$",
124
+ title: str = "",
125
+ figsize: Tuple[float, float] = (10, 4.2),
126
+ ):
127
+ """Two side-by-side panels (Ωm, σ8): marginal (sum) vs profile (max) curves, one panel per parameter."""
128
+ p = np.asarray(Wmap, dtype=np.float64)
129
+ p /= p.sum() + 1e-30
130
+ marg_om = p.sum(axis=1)
131
+ marg_s8 = p.sum(axis=0)
132
+ prof_om = np.max(p, axis=1)
133
+ prof_s8 = np.max(p, axis=0)
134
+ marg_om /= marg_om.max() + 1e-30
135
+ marg_s8 /= marg_s8.max() + 1e-30
136
+ prof_om /= prof_om.max() + 1e-30
137
+ prof_s8 /= prof_s8.max() + 1e-30
138
+
139
+ fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=False)
140
+ for ax, xaxis, marg, prof, xlab, xv in zip(
141
+ axes,
142
+ (om_ax, s8_ax),
143
+ (marg_om, marg_s8),
144
+ (prof_om, prof_s8),
145
+ (param1_label, param2_label),
146
+ (true_param1, true_param2),
147
+ ):
148
+ ax.plot(xaxis, marg, lw=2.0, ls="-", label="marginal")
149
+ ax.plot(xaxis, prof, lw=2.0, ls="--", label="profile")
150
+ ax.axvline(xv, color="crimson", ls=":", lw=1.8)
151
+ ax.set_xlabel(xlab, fontsize=12)
152
+ ax.set_ylabel("norm. density", fontsize=10)
153
+ ax.legend(fontsize=9)
154
+ ax.grid(True, alpha=0.3)
155
+ fig.suptitle(title, fontsize=12, fontweight="bold")
156
+ fig.tight_layout(rect=(0, 0, 1, 0.93))
157
+ return fig
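Note on marginal vs profile: the two 1D summaries used in this file differ only in whether the other parameter is summed out or maximized over. A small worked example makes the difference concrete. `marginal_and_profile` is a hypothetical helper written for this note, and the 2x3 grid is made up.

```python
import numpy as np


def marginal_and_profile(wmap):
    """Return peak-normalized (marginal, profile) curves along axis 0 of a 2D grid."""
    p = np.asarray(wmap, dtype=np.float64)
    p = p / p.sum()
    marginal = p.sum(axis=1)  # sum out the second parameter
    profile = p.max(axis=1)   # maximize over the second parameter
    return marginal / marginal.max(), profile / profile.max()


# Toy grid: row 0 is narrow and tall, row 1 is broad and low.
wmap = np.array([
    [0.0, 0.4, 0.0],
    [0.2, 0.2, 0.2],
])
marginal, profile = marginal_and_profile(wmap)
```

Here the marginal prefers row 1 (more total mass) while the profile prefers row 0 (higher peak), which is the kind of disagreement a marginal-vs-profile comparison plot is meant to expose.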
cross_model/scripts/run_ddpm_comparison.sh ADDED
@@ -0,0 +1,66 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=24:00:00
8
+ #SBATCH --job-name=ddpm_compare
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-ddpm-compare-%j.out
11
+ #SBATCH --error=slurm-ddpm-compare-%j.err
12
+
13
+ # DDPM-2 vs DDPM-6 comparison (same cluster layout as 6-param training — see reference below).
14
+ #
15
+ # Reference training script (Slurm + module + paths pattern):
16
+ # /scratch/mrpcol001/Diffusion_job/april_26/ddpm_hi_lh6/scripts/shell/train_conditional_lh6.sh
17
+ #
18
+ # Submit from anywhere:
19
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_comparison.sh
20
+ #
21
+ # Extra CLI args for compare_ddpm_models.py pass through, e.g. LHS off:
22
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_comparison.sh --skip-lhs-r2
23
+ #
24
+ # Override output dir (optional):
25
+ # sbatch --export=OUTPUT_DIR=/scratch/mrpcol001/Diffusion_job/Models/ddpm_comparison_out_ab \
26
+ # /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_comparison.sh
27
+ #
28
+ # Optional: override DDPM-2 train/val for the combined loss plot (default: bundled JSON in Models/scripts/):
29
+ # sbatch --export=SLURM_2PARAM=/path/to/slurm-2param-ddpm.out \
30
+ # /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_comparison.sh
31
+ #
32
+ # Interactive (same module as training script):
33
+ # module load python/miniconda3-py3.12-usr
34
+ # bash .../run_ddpm_comparison.sh --skip-lhs-r2
35
+
36
+ set -euo pipefail
37
+
38
+ ROOT="/scratch/mrpcol001/Diffusion_job/Models"
39
+ cd "$ROOT"
40
+
41
+ module load python/miniconda3-py3.12-usr
42
+
43
+ OUT="${OUTPUT_DIR:-${ROOT}/ddpm_comparison_out}"
44
+
45
+ echo "==============================================="
46
+ echo "Job ID: ${SLURM_JOB_ID:-local}"
47
+ echo "Job Name: ${SLURM_JOB_NAME:-run_ddpm_comparison}"
48
+ echo "Node: ${SLURM_NODELIST:-$(hostname)}"
49
+ echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
50
+ echo "Starting Time: $(date)"
51
+ echo "Comparison output: ${OUT}"
52
+ echo "Reference Slurm recipe: april_26/ddpm_hi_lh6/scripts/shell/train_conditional_lh6.sh"
53
+ echo "==============================================="
54
+
55
+ PY_ARGS=(python "$ROOT/scripts/compare_ddpm_models.py" --output-dir "$OUT")
56
+ if [[ -n "${SLURM_2PARAM:-}" ]]; then
57
+ PY_ARGS+=(--slurm-2param "${SLURM_2PARAM}")
58
+ fi
59
+ PY_ARGS+=("$@")
60
+
61
+ "${PY_ARGS[@]}"
62
+
63
+ echo "==============================================="
64
+ echo "Artifacts -> ${OUT}"
65
+ echo "Finished at: $(date)"
66
+ echo "==============================================="
cross_model/scripts/run_ddpm_figure6.sh ADDED
@@ -0,0 +1,27 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=24:00:00
8
+ #SBATCH --job-name=ddpm_fig6
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-ddpm-figure6-%j.out
11
+ #SBATCH --error=slurm-ddpm-figure6-%j.err
12
+
13
+ # Figure 6 style (arXiv:2409.09101-inspired) surrogate posteriors for DDPM-2 / DDPM-6.
14
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_figure6.sh
15
+ # sbatch --export=OUTPUT_DIR=/path/to/out,TEST_INDEX=42 .../run_ddpm_figure6.sh --no-six-grid
16
+ #
17
+ set -euo pipefail
18
+
19
+ ROOT="/scratch/mrpcol001/Diffusion_job/Models"
20
+ cd "$ROOT"
21
+ module load python/miniconda3-py3.12-usr
22
+
23
+ OUT="${OUTPUT_DIR:-${ROOT}/ddpm_figure6_out}"
24
+ IDX="${TEST_INDEX:-56}"
25
+
26
+ echo "Job=${SLURM_JOB_ID:-local} OUT=${OUT} TEST_INDEX=${IDX}"
27
+ python "${ROOT}/scripts/run_ddpm_figure6_suite.py" --output-dir "${OUT}" --test-index "${IDX}" "$@"
cross_model/scripts/run_ddpm_figure6_suite.py ADDED
@@ -0,0 +1,315 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Compute surrogate posteriors and emit Figure-6 style figures (arXiv:2409.09101-inspired).
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import argparse
9
+ import gc
10
+ import sys
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import torch
15
+
16
+ _SCRIPTS = Path(__file__).resolve().parent
17
+ MODELS_ROOT = Path(__file__).resolve().parents[1]
18
+ CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
19
+ if str(_SCRIPTS) not in sys.path:
20
+ sys.path.insert(0, str(_SCRIPTS))
21
+ if str(CODE_6.resolve()) not in sys.path:
22
+ sys.path.insert(0, str(CODE_6.resolve()))
23
+
24
+ import evaluate_conditional as ec # noqa: E402
25
+ import ddpm_posterior_six_anchors as dps # noqa: E402
26
+
27
+ from ddpm_figure6_integration import ( # noqa: E402
28
+ integrate_figure6_model_comparison,
29
+ integrate_figure6_with_ddpm2,
30
+ integrate_figure6_with_multi_anchor,
31
+ print_integration_guide,
32
+ )
33
+
34
+
35
+ def main() -> None:
36
+ p = argparse.ArgumentParser(description="DDPM Figure-6 style posterior suite.")
37
+ p.add_argument("--output-dir", type=Path, default=MODELS_ROOT / "ddpm_figure6_out")
38
+ p.add_argument("--data-2param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"))
39
+ p.add_argument("--data-6param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"))
40
+ p.add_argument(
41
+ "--bundle-2param",
42
+ type=Path,
43
+ default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
44
+ )
45
+ p.add_argument(
46
+ "--bundle-6param",
47
+ type=Path,
48
+ default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
49
+ )
50
+ p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
51
+ p.add_argument("--test-index", type=int, default=56, help="Index for single comparison + per-map fig6.")
52
+ p.add_argument("--grid", type=int, default=14)
53
+ p.add_argument("--ddim-steps", type=int, default=50)
54
+ p.add_argument("--batch-size", type=int, default=8)
55
+ p.add_argument(
56
+ "--six-anchors-only",
57
+ action="store_true",
58
+ help="Only 2×3 multi-anchor plots (skip the DDPM-2 vs DDPM-6 model comparison at --test-index).",
59
+ )
60
+ p.add_argument(
61
+ "--no-six-grid",
62
+ action="store_true",
63
+ help="Skip multi-anchor 2×3 panels.",
64
+ )
65
+ p.add_argument(
66
+ "--no-single-fig6",
67
+ action="store_true",
68
+ help="Skip per-map marginal/profile for test-index on DDPM-2 and DDPM-6 (truth tail).",
69
+ )
70
+ p.add_argument(
71
+ "--guide",
72
+ action="store_true",
73
+ help="Print markdown-style integration notes and exit.",
74
+ )
75
+ args = p.parse_args()
76
+
77
+ if args.guide:
78
+ print_integration_guide()
79
+ return
80
+
81
+ out = Path(args.output_dir).resolve()
82
+ out.mkdir(parents=True, exist_ok=True)
83
+
84
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
+ print("device:", device)
86
+
87
+ data2 = Path(args.data_2param)
88
+ data6 = Path(args.data_6param)
89
+ imgs2, lab2 = ec.load_split(data2, args.split)
90
+ imgs6, lab6 = ec.load_split(data6, args.split)
91
+ n = min(len(lab2), len(lab6))
92
+ anchor_ix = np.linspace(0, n - 1, num=6, dtype=int)
93
+
94
+ low_tail, hi_tail = dps.tail_lhs_bounds(data6)
95
+
96
+ ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
97
+ aj2 = args.bundle_2param / "args.json"
98
+ ck6 = args.bundle_6param / "best_model.pt"
99
+ aj6 = args.bundle_6param / "args.json"
100
+
101
+ mean2, std2 = ec.load_label_stats(data2)
102
+ mean6, std6 = ec.load_label_stats(data6)
103
+
104
+ ix = int(args.test_index)
105
+ if not (0 <= ix < n):
106
+ raise SystemExit(f"--test-index {ix} invalid (max {n - 1})")
107
+
108
+ lab_box = lab6[:, :2].copy()
109
+
110
+ if not args.six_anchors_only:
111
+ print(">>> Loading models for ix=", ix, "...")
112
+ m2, c2 = dps.load_model(aj2, ck2, device)
113
+ m6, c6 = dps.load_model(aj6, ck6, device)
114
+ normalize2 = bool(c2.get("normalize_labels", True))
115
+ normalize6 = bool(c6.get("normalize_labels", True))
116
+
117
+ obs2 = imgs2[ix]
118
+ obs6 = imgs6[ix]
119
+ lt2 = lab2[ix].astype(np.float64)
120
+ lt6 = lab6[ix].astype(np.float64)
121
+ ta2om, ta2s8 = float(lt2[0]), float(lt2[1])
122
+ tom, ts8 = float(lt6[0]), float(lt6[1])
123
+
124
+ full2, om_ax, s8_ax = dps.build_full_grid_2d(lab_box, args.grid, tail=None, lab_dim=2)
125
+ Wm2, _, _ = dps.posterior_weights(
126
+ obs2,
127
+ full2,
128
+ om_ax,
129
+ s8_ax,
130
+ mean2,
131
+ std2,
132
+ normalize2,
133
+ m2,
134
+ H=int(obs2.shape[-2]),
135
+ W=int(obs2.shape[-1]),
136
+ device=device,
137
+ grid=args.grid,
138
+ batch_sz=args.batch_size,
139
+ ddim_steps=args.ddim_steps,
140
+ )
141
+
142
+ full6truth, om6, s86 = dps.build_full_grid_2d(
143
+ lab6, args.grid, tail=lab6[ix, 2:6].astype(np.float32), lab_dim=6
144
+ )
145
+ Wm6t, _, _ = dps.posterior_weights(
146
+ obs6,
147
+ full6truth,
148
+ om6,
149
+ s86,
150
+ mean6,
151
+ std6,
152
+ normalize6,
153
+ m6,
154
+ H=int(obs6.shape[-2]),
155
+ W=int(obs6.shape[-1]),
156
+ device=device,
157
+ grid=args.grid,
158
+ batch_sz=args.batch_size,
159
+ ddim_steps=args.ddim_steps,
160
+ )
161
+
162
+ full6lo, om_b, s8_b = dps.build_full_grid_2d(lab6, args.grid, tail=low_tail, lab_dim=6)
163
+ Wm6lo, _, _ = dps.posterior_weights(
164
+ obs6,
165
+ full6lo,
166
+ om_b,
167
+ s8_b,
168
+ mean6,
169
+ std6,
170
+ normalize6,
171
+ m6,
172
+ H=int(obs6.shape[-2]),
173
+ W=int(obs6.shape[-1]),
174
+ device=device,
175
+ grid=args.grid,
176
+ batch_sz=args.batch_size,
177
+ ddim_steps=args.ddim_steps,
178
+ )
179
+
180
+ full6hi, om_c, s8_c = dps.build_full_grid_2d(lab6, args.grid, tail=hi_tail, lab_dim=6)
181
+ Wm6hi, _, _ = dps.posterior_weights(
182
+ obs6,
183
+ full6hi,
184
+ om_c,
185
+ s8_c,
186
+ mean6,
187
+ std6,
188
+ normalize6,
189
+ m6,
190
+ H=int(obs6.shape[-2]),
191
+ W=int(obs6.shape[-1]),
192
+ device=device,
193
+ grid=args.grid,
194
+ batch_sz=args.batch_size,
195
+ ddim_steps=args.ddim_steps,
196
+ )
197
+
198
+ if not all(np.allclose(om_ax, o, rtol=0, atol=1e-12) and np.allclose(s8_ax, s, rtol=0, atol=1e-12) for o, s in ((om6, s86), (om_b, s8_b), (om_c, s8_c))):
199
+ print("Warning: Ωm–σ8 grids differ between setups; plotting uses DDPM-2 Ωm/σ8 axes.")
200
+
201
+ integrate_figure6_model_comparison(
202
+ {
203
+ "DDPM-2": Wm2,
204
+ "DDPM-6 (truth-tail)": Wm6t,
205
+ "DDPM-6 (min-tail)": Wm6lo,
206
+ "DDPM-6 (max-tail)": Wm6hi,
207
+ },
208
+ om_ax,
209
+ s8_ax,
210
+ tom,
211
+ ts8,
212
+ ix,
213
+ out,
214
+ )
215
+
216
+ if not args.no_single_fig6:
217
+ integrate_figure6_with_ddpm2(Wm2, om_ax, s8_ax, ta2om, ta2s8, ix, out, model_name="DDPM-2")
218
+ integrate_figure6_with_ddpm2(Wm6t, om_ax, s8_ax, tom, ts8, ix, out, model_name="DDPM-6-truth")
219
+
220
+ del m2, m6
221
+ gc.collect()
222
+ if torch.cuda.is_available():
223
+ torch.cuda.empty_cache()
224
+
225
+ # --- Six anchors: multi grids for DDPM-2 + DDPM-6 truth tail ---
226
+ if not args.no_six_grid:
227
+ print(">>> Six-anchor Figure 6 grids...")
228
+ model2, cfg2 = dps.load_model(aj2, ck2, device)
229
+ model6, cfg6 = dps.load_model(aj6, ck6, device)
230
+ nz2 = bool(cfg2.get("normalize_labels", True))
231
+ nz6 = bool(cfg6.get("normalize_labels", True))
232
+
233
+ post2: list[np.ndarray] = []
234
+ post6: list[np.ndarray] = []
235
+ truths: list[tuple[float, float]] = []
236
+ indices: list[int] = []
237
+
238
+ for jx in anchor_ix.ravel():
239
+ indices.append(int(jx))
240
+ o2 = imgs2[jx]
241
+ lb2 = lab2[jx].astype(np.float64)
242
+ f2, oa, sa = dps.build_full_grid_2d(lab_box, args.grid, tail=None, lab_dim=2)
243
+ W2, _, _ = dps.posterior_weights(
244
+ o2,
245
+ f2,
246
+ oa,
247
+ sa,
248
+ mean2,
249
+ std2,
250
+ nz2,
251
+ model2,
252
+ H=int(o2.shape[-2]),
253
+ W=int(o2.shape[-1]),
254
+ device=device,
255
+ grid=args.grid,
256
+ batch_sz=args.batch_size,
257
+ ddim_steps=args.ddim_steps,
258
+ )
259
+ post2.append(W2)
260
+
261
+ o6 = imgs6[jx]
262
+ lb6 = lab6[jx]
263
+ tail_truth = lb6.astype(np.float32)[2:6]
264
+ f6, oa6, sa6 = dps.build_full_grid_2d(lab6, args.grid, tail=tail_truth, lab_dim=6)
265
+ W6, _, _ = dps.posterior_weights(
266
+ o6,
267
+ f6,
268
+ oa6,
269
+ sa6,
270
+ mean6,
271
+ std6,
272
+ nz6,
273
+ model6,
274
+ H=int(o6.shape[-2]),
275
+ W=int(o6.shape[-1]),
276
+ device=device,
277
+ grid=args.grid,
278
+ batch_sz=args.batch_size,
279
+ ddim_steps=args.ddim_steps,
280
+ )
281
+ post6.append(W6)
282
+ truths.append((float(lb2[0]), float(lb2[1])))
283
+
284
+ integrate_figure6_with_multi_anchor(
285
+ post2,
286
+ oa,
287
+ sa,
288
+ truths,
289
+ indices,
290
+ out,
291
+ model_name="DDPM-2",
292
+ )
293
+
294
+ truths6 = [(float(lab6[int(j)][0]), float(lab6[int(j)][1])) for j in anchor_ix]
295
+
296
+ integrate_figure6_with_multi_anchor(
297
+ post6,
298
+ oa6,
299
+ sa6,
300
+ truths6,
301
+ indices,
302
+ out,
303
+ model_name="DDPM-6-truth-tail",
304
+ )
305
+
306
+ del model2, model6
307
+ gc.collect()
308
+ if torch.cuda.is_available():
309
+ torch.cuda.empty_cache()
310
+
311
+ print(f"Done. Outputs in {out}")
312
+
313
+
314
+ if __name__ == "__main__":
315
+ main()
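The anchor selection and the "fixed tail" grids used in the suite can be sketched as follows. `grid_with_fixed_tail` is a hypothetical stand-in for `dps.build_full_grid_2d`, whose implementation is not shown in this diff section; taking the axis bounds from the label minima and maxima is an assumption, and the labels are random placeholders.

```python
import numpy as np

def grid_with_fixed_tail(labels, grid, tail=None):
    """Hypothetical stand-in for dps.build_full_grid_2d (the real helper is not shown here)."""
    # Assumption: axes span the min/max of the first two label columns.
    om_ax = np.linspace(labels[:, 0].min(), labels[:, 0].max(), grid)
    s8_ax = np.linspace(labels[:, 1].min(), labels[:, 1].max(), grid)
    om, s8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
    cols = [om.ravel(), s8.ravel()]
    if tail is not None:
        # Hold the remaining parameters constant at every grid point.
        cols += [np.full(om.size, t) for t in tail]
    return np.stack(cols, axis=1), om_ax, s8_ax

rng = np.random.default_rng(0)
labels = rng.uniform(0.1, 0.9, size=(100, 6))  # placeholder 6-parameter labels

# Six evenly spaced anchor indices, as in the script.
anchor_ix = np.linspace(0, len(labels) - 1, num=6, dtype=int)
full, om_ax, s8_ax = grid_with_fixed_tail(labels, grid=14, tail=labels[anchor_ix[0], 2:6])
print(anchor_ix, full.shape)
```

With `tail=None` the same helper yields a two-column grid, which corresponds to the DDPM-2 case in the script.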
cross_model/scripts/run_ddpm_posterior_corrected.sh ADDED
@@ -0,0 +1,58 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=48:00:00
8
+ #SBATCH --job-name=ddpm_post_corr
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-ddpm-post-corr-%j.out
11
+ #SBATCH --error=slurm-ddpm-post-corr-%j.err
12
+
13
+ # Prior / likelihood / posterior visualization pipeline (ddpm_posterior_corrected.py).
14
+ # Separate from poster.py — default output dir is ddpm_posterior_corrected_fullviz_out.
15
+ #
16
+ # Submit:
17
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_posterior_corrected.sh
18
+ #
19
+ # Extra args pass through to the Python script:
20
+ # sbatch .../run_ddpm_posterior_corrected.sh --ddpm2-only --grid 20 --n-ddpm-samples 4 --no-ppc
21
+ #
22
+ # Override dirs:
23
+ # sbatch --export=OUTPUT_DIR=/path/to/out,CUSTOM_LOG=/path/run.log \
24
+ # .../run_ddpm_posterior_corrected.sh
25
+ #
26
+ # Interactive:
27
+ # module load python/miniconda3-py3.12-usr
28
+ # bash .../run_ddpm_posterior_corrected.sh --help
29
+
30
+ set -euo pipefail
31
+
32
+ ROOT="/scratch/mrpcol001/Diffusion_job/Models"
33
+ cd "$ROOT"
34
+
35
+ module load python/miniconda3-py3.12-usr
36
+
37
+ OUT="${OUTPUT_DIR:-${ROOT}/ddpm_posterior_corrected_fullviz_out}"
38
+ mkdir -p "$OUT"
39
+
40
+ # Copy of stdout/stderr for progress (also appears in Slurm .out/.err)
41
+ RUN_LOG="${CUSTOM_LOG:-${OUT}/run_log.txt}"
42
+
43
+ echo "==============================================="
44
+ echo "Job ID: ${SLURM_JOB_ID:-local}"
45
+ echo "Node: ${SLURM_NODELIST:-$(hostname)}"
46
+ echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
47
+ echo "Started: $(date)"
48
+ echo "Output dir: ${OUT}"
49
+ echo "Progress log (tee): ${RUN_LOG}"
50
+ echo "==============================================="
51
+
52
+ set -o pipefail
53
+ python -u "${ROOT}/ddpm_posterior_corrected.py" --output-dir "${OUT}" "$@" 2>&1 | tee -a "${RUN_LOG}"
54
+
55
+ echo "==============================================="
56
+ echo "Finished: $(date)"
57
+ echo "Figures & log → ${OUT}"
58
+ echo "==============================================="
cross_model/scripts/run_ddpm_posterior_six_anchors.sh ADDED
@@ -0,0 +1,52 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=48:00:00
8
+ #SBATCH --job-name=ddpm_post6
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-ddpm-posterior-six-%j.out
11
+ #SBATCH --error=slurm-ddpm-posterior-six-%j.err
12
+
13
+ # Six-anchor surrogate posteriors (DDPM-2 + DDPM-6 with extra dims min/max).
14
+ #
15
+ # Submit from anywhere:
16
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_posterior_six_anchors.sh
17
+ #
18
+ # Override output directory:
19
+ # sbatch --export=OUTPUT_DIR=/scratch/mrpcol001/Diffusion_job/Models/my_post_out \
20
+ # /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_posterior_six_anchors.sh
21
+ #
22
+ # Extra CLI passes through to ddpm_posterior_six_anchors.py, e.g. only DDPM-6 panels:
23
+ # sbatch .../run_ddpm_posterior_six_anchors.sh --ddpm6-only --grid 12
24
+ #
25
+ # Interactive:
26
+ # module load python/miniconda3-py3.12-usr
27
+ # bash .../run_ddpm_posterior_six_anchors.sh
28
+
29
+ set -euo pipefail
30
+
31
+ ROOT="/scratch/mrpcol001/Diffusion_job/Models"
32
+ cd "$ROOT"
33
+
34
+ module load python/miniconda3-py3.12-usr
35
+
36
+ OUT="${OUTPUT_DIR:-${ROOT}/ddpm_posterior_six_anchors_out}"
37
+
38
+ echo "==============================================="
39
+ echo "Job ID: ${SLURM_JOB_ID:-local}"
40
+ echo "Job Name: ${SLURM_JOB_NAME:-run_ddpm_posterior_six_anchors}"
41
+ echo "Node: ${SLURM_NODELIST:-$(hostname)}"
42
+ echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
43
+ echo "Starting Time: $(date)"
44
+ echo "Posterior output: ${OUT}"
45
+ echo "==============================================="
46
+
47
+ python "$ROOT/scripts/ddpm_posterior_six_anchors.py" --output-dir "$OUT" "$@"
48
+
49
+ echo "==============================================="
50
+ echo "Artifacts -> ${OUT}"
51
+ echo "Finished at: $(date)"
52
+ echo "==============================================="
cross_model/scripts/run_poster.sh ADDED
@@ -0,0 +1,53 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=24:00:00
8
+ #SBATCH --job-name=ddpm_poster
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-ddpm-poster-%j.out
11
+ #SBATCH --error=slurm-ddpm-poster-%j.err
12
+
13
+ # Corrected six-anchor surrogate posteriors (poster.py): DDPM-2 + DDPM-6 with
14
+ # stochastic averaging, calibrated sigma_pk, MC marginalisation for 6-param, etc.
15
+ #
16
+ # Submit from anywhere:
17
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_poster.sh
18
+ #
19
+ # Override output directory:
20
+ # sbatch --export=OUTPUT_DIR=/scratch/mrpcol001/Diffusion_job/Models/my_poster_out \
21
+ # /scratch/mrpcol001/Diffusion_job/Models/scripts/run_poster.sh
22
+ #
23
+ # Extra CLI passes through to poster.py, e.g. DDPM-2 only (faster debug):
24
+ # sbatch .../run_poster.sh --ddpm2-only --grid 14 --n-pk-samples 4 --n-marg-samples 1 --no-ppc
25
+ #
26
+ # Interactive (same module as other Models scripts):
27
+ # module load python/miniconda3-py3.12-usr
28
+ # bash /scratch/mrpcol001/Diffusion_job/Models/scripts/run_poster.sh --help
29
+
30
+ set -euo pipefail
31
+
32
+ ROOT="/scratch/mrpcol001/Diffusion_job/Models"
33
+ cd "$ROOT"
34
+
35
+ module load python/miniconda3-py3.12-usr
36
+
37
+ OUT="${OUTPUT_DIR:-${ROOT}/ddpm_posterior_corrected_out}"
38
+
39
+ echo "==============================================="
40
+ echo "Job ID: ${SLURM_JOB_ID:-local}"
41
+ echo "Job Name: ${SLURM_JOB_NAME:-run_poster}"
42
+ echo "Node: ${SLURM_NODELIST:-$(hostname)}"
43
+ echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
44
+ echo "Starting Time: $(date)"
45
+ echo "Poster output: ${OUT}"
46
+ echo "==============================================="
47
+
48
+ python "$ROOT/poster.py" --output-dir "$OUT" "$@"
49
+
50
+ echo "==============================================="
51
+ echo "Artifacts -> ${OUT}"
52
+ echo "Finished at: $(date)"
53
+ echo "==============================================="
cross_model/scripts/run_posterior_inference.sh ADDED
@@ -0,0 +1,74 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=48:00:00
8
+ #SBATCH --job-name=vlb_infer
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-vlb-infer-%j.out
11
+ #SBATCH --error=slurm-vlb-infer-%j.err
12
+
13
+ # VLB / Mudur-style posterior_inference.py (pure inference-time L_t surfaces).
14
+ #
15
+ # Defaults match bundled 6-param checkpoint + LH test data (override via env).
16
+ #
17
+ # Submit:
18
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_posterior_inference.sh
19
+ #
20
+ # Defaults (posterior_inference.py): n_fields=9, grid_size=10000 (needs --allow_huge_grid),
21
+ # mosaic figure posterior_L0_mosaic_3x3.png at ~10000×10000 px.
22
+ # Override the grid to avoid the huge scan, e.g.: --grid_size 50 (--allow_huge_grid is then unnecessary, though this wrapper always passes it)
23
+ # Smoke test:
24
+ # sbatch .../run_posterior_inference.sh --n_fields 1 --grid_size 25 --t_subset 0 --batch_size 16
25
+ #
26
+ # Custom checkpoint / args / data:
27
+ # sbatch --export=CHECKPOINT=/path/best_model.pt,TRAINING_ARGS=/path/args.json,DATA_DIR=/path/params_6 \
28
+ # .../run_posterior_inference.sh --grid_size 40
29
+ #
30
+ # Logs: Slurm .out/.err plus OUTPUT_DIR/run_log.txt (override CUSTOM_LOG).
31
+
32
+ set -euo pipefail
33
+
34
+ ROOT="/scratch/mrpcol001/Diffusion_job/Models"
35
+ cd "$ROOT"
36
+
37
+ module load python/miniconda3-py3.12-usr
38
+
39
+ PY="${ROOT}/6param_ddpm_hi_lh6/posterior_inference.py"
40
+ OUT="${OUTPUT_DIR:-${ROOT}/vlb_inference_outputs}"
41
+
42
+ CHK="${CHECKPOINT:-${ROOT}/notebook_model_weights/6param_best/best_model.pt}"
43
+ ARGS="${TRAINING_ARGS:-${ROOT}/notebook_model_weights/6param_best/args.json}"
44
+ DATA="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"
45
+
46
+ mkdir -p "${OUT}"
47
+ RUN_LOG="${CUSTOM_LOG:-${OUT}/run_log.txt}"
48
+
49
+ echo "==============================================="
50
+ echo "Job ID: ${SLURM_JOB_ID:-local}"
51
+ echo "Node: ${SLURM_NODELIST:-$(hostname)}"
52
+ echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
53
+ echo "Started: $(date)"
54
+ echo "Python: ${PY}"
55
+ echo "checkpoint: ${CHK}"
56
+ echo "training_args: ${ARGS}"
57
+ echo "data_dir: ${DATA}"
58
+ echo "output_dir: ${OUT}"
59
+ echo "Progress log: ${RUN_LOG}"
60
+ echo "==============================================="
61
+
62
+ set -o pipefail
63
+ python -u "${PY}" \
64
+ --checkpoint "${CHK}" \
65
+ --training_args "${ARGS}" \
66
+ --data_dir "${DATA}" \
67
+ --output_dir "${OUT}" \
68
+ --allow_huge_grid \
69
+ "$@" 2>&1 | tee -a "${RUN_LOG}"
70
+
71
+ echo "==============================================="
72
+ echo "Finished: $(date)"
73
+ echo "Artifacts → ${OUT}"
74
+ echo "==============================================="
cross_model/scripts/run_triangle_ddpm_both.sh ADDED
@@ -0,0 +1,75 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=12:00:00
8
+ #SBATCH --job-name=ddpm_triangle
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-ddpm-triangle-%j.out
11
+ #SBATCH --error=slurm-ddpm-triangle-%j.err
12
+
13
+ # Run surrogate Ωm–σ8 chain export for DDPM-2 and DDPM-6, then a joint triangle plot.
14
+ # To use your own copies of ddpm_triangle_integration.py / triangle_plot_posterior.py
15
+ # (e.g. copied from ~/Downloads), set DDPM_TRIANGLE_INTEGRATION_PY / DDPM_TRIANGLE_POSTERIOR_PY; the defaults live under $ROOT/scripts.
16
+ #
17
+ # sbatch .../run_triangle_ddpm_both.sh
18
+ #
19
+ # sbatch --export=OUTPUT_DIR=/path/to/out,TEST_INDEX=56,GRID_POINTS=14 .../run_triangle_ddpm_both.sh
20
+ #
21
+ # Interactive:
22
+ # module load python/miniconda3-py3.12-usr
23
+ # bash .../run_triangle_ddpm_both.sh
24
+
25
+ set -euo pipefail
26
+
27
+ ROOT="/scratch/mrpcol001/Diffusion_job/Models"
28
+ cd "$ROOT"
29
+
30
+ module load python/miniconda3-py3.12-usr
31
+
32
+ OUT="${OUTPUT_DIR:-${ROOT}/ddpm_triangle_out}"
33
+ TEST_IX="${TEST_INDEX:-56}"
34
+ GRID="${GRID_POINTS:-14}"
35
+
36
+ mkdir -p "${OUT}"
37
+
38
+ INTEG="${DDPM_TRIANGLE_INTEGRATION_PY:-${ROOT}/scripts/ddpm_triangle_integration.py}"
39
+ PLOT="${DDPM_TRIANGLE_POSTERIOR_PY:-${ROOT}/scripts/triangle_plot_posterior.py}"
40
+
41
+ for f in "$INTEG" "$PLOT"; do
42
+ if [[ ! -f "$f" ]]; then
43
+ echo "Missing: $f" >&2
44
+ exit 1
45
+ fi
46
+ done
47
+
48
+ CHAIN2="${OUT}/chain_surrogate_ix${TEST_IX}_ddpm2.npz"
49
+ CHAIN6="${OUT}/chain_surrogate_ix${TEST_IX}_ddpm6_truth_tail.npz"
50
+
51
+ echo "==============================================="
52
+ echo "Job: ${SLURM_JOB_ID:-local} OUT=${OUT} test_ix=${TEST_IX}"
53
+ echo "==============================================="
54
+
55
+ python "$INTEG" \
56
+ --label-dim 2 \
57
+ --test-index "${TEST_IX}" \
58
+ --grid "${GRID}" \
59
+ -o "${CHAIN2}"
60
+
61
+ python "$INTEG" \
62
+ --label-dim 6 \
63
+ --test-index "${TEST_IX}" \
64
+ --six-tail-mode truth \
65
+ --grid "${GRID}" \
66
+ -o "${CHAIN6}"
67
+
68
+ python "$PLOT" \
69
+ -i "${CHAIN2}" "${CHAIN6}" \
70
+ --labels "DDPM-2" "DDPM-6" \
71
+ -o "${OUT}/triangle_ddpm2_ddpm6_ix${TEST_IX}.png"
72
+
73
+ echo "Chains: ${CHAIN2} ${CHAIN6}"
74
+ echo "Triangle: ${OUT}/triangle_ddpm2_ddpm6_ix${TEST_IX}.png"
75
+ echo "Finished: $(date)"
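For orientation, a surrogate "chain" compatible with the plotting step can be produced by resampling grid cells in proportion to their posterior weight. This is a minimal sketch, not the logic of `ddpm_triangle_integration.py` (which is not shown in this section): the Gaussian weight map, the axis ranges, and the truth values are placeholders, and only the `.npz` key names (`omega_m`, `sigma_8`, `truth_Omega_m`, `truth_sigma_8`) are taken from `triangle_plot_posterior.py`.

```python
import io
import numpy as np

rng = np.random.default_rng(0)

# Placeholder 14x14 posterior weight map on an (Omega_m, sigma_8) grid.
om_ax = np.linspace(0.1, 0.5, 14)
s8_ax = np.linspace(0.6, 1.0, 14)
OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
W = np.exp(-0.5 * (((OM - 0.3) / 0.05) ** 2 + ((S8 - 0.8) / 0.05) ** 2))
W /= W.sum()

# Draw grid cells with probability equal to their posterior weight.
flat = rng.choice(W.size, size=5000, p=W.ravel())
i, j = np.unravel_index(flat, W.shape)

buf = io.BytesIO()  # stands in for the .npz path passed via -o
np.savez(buf, omega_m=om_ax[i], sigma_8=s8_ax[j], truth_Omega_m=0.3, truth_sigma_8=0.8)
buf.seek(0)
chain = np.load(buf)
print(chain["omega_m"].mean(), chain["sigma_8"].mean())
```

Because samples land only on grid nodes, histograms of such a chain show the grid spacing; a coarse grid such as 14 points per axis limits the resolution of any triangle plot built from it.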
cross_model/scripts/sigma_contour_utils.py ADDED
@@ -0,0 +1,29 @@
1
+ """HDR-style contour levels for 2D probability maps on a grid."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+
8
+ def compute_sigma_levels(
9
+ posterior_norm: np.ndarray,
10
+ credibility_mass: tuple[float, ...] | list[float],
11
+ ) -> list[float]:
12
+ """
13
+ Highest-density containment: for each ``credibility_mass[j]``, find the density threshold at which
14
+ grid cells, accumulated in descending order of density, first cover that fraction of the total probability.
15
+
16
+ Returned levels are ascending (suitable order for matplotlib ``contour``).
17
+ """
18
+ p = np.asarray(posterior_norm, dtype=np.float64).ravel()
19
+ s = p.sum()
20
+ if s <= 0:
21
+ return [0.0 for _ in credibility_mass]
22
+ ps = np.sort(p / s)[::-1]
23
+ cdf = np.cumsum(ps)
24
+ out: list[float] = []
25
+ for cred in credibility_mass:
26
+ j = int(np.searchsorted(cdf, cred, side="left"))
27
+ j = min(max(j, 0), len(ps) - 1)
28
+ out.append(float(ps[j]))
29
+ return sorted(out)
cross_model/scripts/triangle_plot_posterior.py ADDED
@@ -0,0 +1,128 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Corner-style triangle plot for surrogate $(\\Omega_m,\\sigma_8)$ chains from ``ddpm_triangle_integration.py``.
4
+
5
+ Loads one or two ``.npz`` files (keys ``omega_m``, ``sigma_8`` / ``samples``, ``truth_*``) and draws
6
+ 1D marginals + 2D density. If you substitute a script from your Downloads, keep ``--inputs``
7
+ and the expected ``.npz`` keys compatible.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ from pathlib import Path
14
+
15
+ import matplotlib
16
+
17
+ matplotlib.use("Agg")
18
+ import matplotlib.pyplot as plt
19
+ import numpy as np
20
+
21
+
22
+ def _load_chain(path: Path) -> tuple[np.ndarray, np.ndarray, tuple[float, float] | None]:
23
+ d = np.load(path, allow_pickle=True)
24
+ if "samples" in d:
25
+ s = np.asarray(d["samples"], dtype=np.float64)
26
+ om, s8 = s[:, 0], s[:, 1]
27
+ else:
28
+ om = np.asarray(d["omega_m"], dtype=np.float64).ravel()
29
+ s8 = np.asarray(d["sigma_8"], dtype=np.float64).ravel()
30
+ truth = None
31
+ if "truth_Omega_m" in d.files and "truth_sigma_8" in d.files:
32
+ truth = (float(d["truth_Omega_m"]), float(d["truth_sigma_8"]))
33
+ return om, s8, truth
34
+
35
+
36
+ def main() -> None:
37
+ p = argparse.ArgumentParser(description="Triangle / corner plot for Ωm–σ8 surrogate chains.")
38
+ p.add_argument(
39
+ "--inputs",
40
+ "-i",
41
+ nargs="+",
42
+ type=Path,
43
+ required=True,
44
+ help="One or two .npz outputs from ddpm_triangle_integration.py",
45
+ )
46
+ p.add_argument(
47
+ "--labels",
48
+ nargs="*",
49
+ default=None,
50
+ help="Legend entries (default: paths' stems).",
51
+ )
52
+ p.add_argument(
53
+ "--output",
54
+ "-o",
55
+ type=Path,
56
+ default=None,
57
+ help="Output PNG (default: triangle_<input stems joined by '_'>.png next to first input).",
58
+ )
59
+ p.add_argument("--bins-1d", type=int, default=40)
60
+ p.add_argument("--bins-2d", type=int, default=45)
61
+ args = p.parse_args()
62
+
63
+ paths = [Path(x).resolve() for x in args.inputs]
64
+ names = args.labels if args.labels else [p.stem for p in paths]
65
+ if len(names) != len(paths):
66
+ raise SystemExit("--labels count must match --inputs")
67
+
68
+ colors = ("#1f77b4", "#d95f02", "#2ca02c")
69
+ fig = plt.figure(figsize=(8.2, 8.0))
70
+
71
+ ax00 = fig.add_axes([0.1, 0.55, 0.35, 0.35])
72
+ ax_cont = fig.add_axes([0.1, 0.1, 0.35, 0.35])
73
+ ax11 = fig.add_axes([0.55, 0.1, 0.35, 0.35])
74
+ ax_blank = fig.add_axes([0.55, 0.55, 0.35, 0.35])
75
+ ax_blank.axis("off")
76
+
77
+ for i, path in enumerate(paths):
78
+ om, s8, truth = _load_chain(path)
79
+ c = colors[i % len(colors)]
80
+ ax00.hist(
81
+ om,
82
+ bins=args.bins_1d,
83
+ density=True,
84
+ histtype="step",
85
+ color=c,
86
+ lw=2.0,
87
+ label=names[i],
88
+ )
89
+ ax11.hist(
90
+ s8,
91
+ bins=args.bins_1d,
92
+ density=True,
93
+ histtype="step",
94
+ color=c,
95
+ lw=2.0,
96
+ )
97
+ h2, xe, ye = np.histogram2d(om, s8, bins=args.bins_2d, density=True)
98
+ xc = 0.5 * (xe[1:] + xe[:-1])
99
+ yc = 0.5 * (ye[1:] + ye[:-1])
100
+ X, Y = np.meshgrid(xc, yc, indexing="xy")  # (ny, nx) layout, matching h2.T below
101
+ Z = np.ma.masked_where(h2.T <= 1e-20, h2.T)
102
+ if i == 0 and np.ma.count(Z) > 0:
103
+ cf = ax_cont.contourf(X, Y, Z, alpha=0.45, cmap="Blues")
104
+ fig.colorbar(cf, ax=ax_cont, fraction=0.046, pad=0.04)
105
+ elif np.ma.count(Z) > 0:
106
+ ax_cont.contour(X, Y, Z, colors=[c], linewidths=[1.85])
107
+ if truth:
108
+ tx, ty = truth
109
+ ax_cont.scatter(tx, ty, marker="x", s=88, color=c, zorder=6)
110
+
111
+ ax00.set_title(r"$P(\Omega_m)$ marginal")
112
+ ax00.set_ylabel("density")
113
+ ax00.legend(fontsize=8, loc="upper right")
114
+ ax_cont.set_title(r"$2D$ surrogate posterior density")
115
+ ax_cont.set_xlabel(r"$\Omega_m$")
116
+ ax_cont.set_ylabel(r"$\sigma_8$")
117
+ ax11.set_title(r"$P(\sigma_8)$ marginal")
118
+ ax11.set_xlabel(r"$\sigma_8$")
119
+
120
+ out = args.output or (paths[0].parent / ("triangle_" + "_".join(p.stem for p in paths) + ".png"))
121
+ out.parent.mkdir(parents=True, exist_ok=True)
122
+ fig.savefig(out, dpi=170, bbox_inches="tight")
123
+ plt.close(fig)
124
+ print("Saved", out)
125
+
126
+
127
+ if __name__ == "__main__":
128
+ main()
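When a 2D histogram is passed to a contour routine, array orientation is easy to get wrong: `np.histogram2d` returns counts indexed `[x_bin, y_bin]`, whereas a default (`"xy"`) meshgrid has shape `(len(y), len(x))`. The sketch below, with synthetic samples and deliberately unequal bin counts, shows one consistent pairing; the sample distributions are placeholders.

```python
import numpy as np

rng = np.random.default_rng(1)
om = rng.normal(0.3, 0.01, 20000)   # narrow in Omega_m
s8 = rng.normal(0.8, 0.10, 20000)   # wide in sigma_8

# Unequal bin counts make any orientation mix-up show up as a shape mismatch.
H, xe, ye = np.histogram2d(om, s8, bins=(30, 20), density=True)
xc = 0.5 * (xe[1:] + xe[:-1])
yc = 0.5 * (ye[1:] + ye[:-1])

X, Y = np.meshgrid(xc, yc)   # default "xy": shape (len(yc), len(xc))
Z = H.T                      # transpose so Z[j, i] is the density at (xc[i], yc[j])
assert X.shape == Z.shape == (20, 30)

j, i = np.unravel_index(np.argmax(Z), Z.shape)
print(f"peak near Omega_m={X[j, i]:.3f}, sigma_8={Y[j, i]:.3f}")
```

The equivalent alternative is `indexing="ij"` together with the untransposed `H`. With equal bin counts on both axes a mismatched pairing raises no error and silently plots a transposed density.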
cross_model/submit_vlb_1000grid.py ADDED
@@ -0,0 +1,106 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Submit 1000×1000 grid VLB inference job.
4
+ Usage: python3 submit_vlb_1000grid.py [--dry-run]
5
+ """
6
+
7
+ import subprocess
8
+ import argparse
9
+ from pathlib import Path
10
+
11
+ def main():
12
+ parser = argparse.ArgumentParser(
13
+ description="Submit high-resolution VLB inference (1000×1000 grid)"
14
+ )
15
+ parser.add_argument("--dry-run", action="store_true",
16
+ help="Print command without submitting")
17
+ parser.add_argument("--grid-size", type=int, default=1000,
18
+ help="Grid resolution (default: 1000)")
19
+ parser.add_argument("--n-fields", type=int, default=9,
20
+ help="Number of test fields (default: 9)")
21
+ parser.add_argument("--job-name", default="vlb-infer-1000",
22
+ help="SLURM job name")
23
+ parser.add_argument("--time", default="24:00:00",
24
+ help="SLURM time limit (default: 24:00:00)")
25
+ parser.add_argument("--mem", default="32G",
26
+ help="Memory requirement (default: 32G)")
27
+ parser.add_argument("--batch-size", type=int, default=32,
28
+ help="Batch size (default: 32, reduce if OOM)")
29
+ args = parser.parse_args()
30
+
31
+ script_dir = Path(__file__).parent.resolve()
32
+ checkpoint = script_dir / "notebook_model_weights/6param_best/best_model.pt"
33
+ training_args = script_dir / "notebook_model_weights/6param_best/args.json"
34
+ data_dir = script_dir.parent / "data/LH_data/params_6"
35
+ output_dir = script_dir / f"vlb_inference_outputs_{args.grid_size}grid"
36
+
37
+ # Validate paths
38
+ for path, name in [
39
+ (checkpoint, "checkpoint"),
40
+ (training_args, "training_args"),
41
+ (data_dir, "data_dir"),
42
+ ]:
43
+ if not path.exists():
44
+ print(f"❌ Error: {name} not found at {path}")
45
+ return 1
46
+
47
+ cmd = [
48
+ "sbatch",
49
+ f"--job-name={args.job_name}",
50
+ f"--time={args.time}",
51
+ f"--mem={args.mem}",
52
+ f"--output=slurm-vlb-infer-{args.grid_size}-%j.out",
53
+ f"--error=slurm-vlb-infer-{args.grid_size}-%j.err",
54
+ str(script_dir / "run_vlb_inference_1000grid.sh"),
55
+ ]
56
+
57
+ # Build environment variables to override defaults in script
58
+ env_cmd = [
59
+ f"GRID_SIZE={args.grid_size}",
60
+ f"N_FIELDS={args.n_fields}",
61
+ f"BATCH_SIZE={args.batch_size}",
62
+ f"OUTPUT_DIR={output_dir}",
63
+ ]
64
+
65
+ full_cmd = ["env", *env_cmd, *cmd]  # launch through env(1): a bare VAR=value argv entry would be run as the program name
66
+
67
+ print("=" * 60)
68
+ print(f"VLB Inference Submission: {args.grid_size}×{args.grid_size} grid")
69
+ print("=" * 60)
70
+ print(f"Grid size: {args.grid_size}×{args.grid_size}")
71
+ print(f"Number of fields: {args.n_fields}")
72
+ print(f"Batch size: {args.batch_size}")
73
+ print(f"Output dir: {output_dir}")
74
+ print(f"Memory: {args.mem}")
75
+ print(f"Time limit: {args.time}")
76
+ print("=" * 60)
77
+
78
+ # Compute estimates
79
+ grid_points = args.grid_size ** 2
80
+ timesteps = 8
81
+ seeds = 4
82
+ n_forward_passes = grid_points * timesteps * seeds * args.n_fields
83
+ print(f"\nComputation scale:")
84
+ print(f" Total forward passes: {n_forward_passes:,}")
85
+ print(f" Per field: {n_forward_passes // args.n_fields:,}")
86
+ print(f" Est. time per field: ~{(n_forward_passes // args.n_fields) // 10_000:,} min")
87
+ print(f" Est. total time: ~{n_forward_passes / 10_000_000:.1f}-{n_forward_passes / 7_000_000:.1f} hours")
88
+
89
+ print(f"\nCommand: {' '.join(full_cmd)}\n")
90
+
91
+ if args.dry_run:
92
+ print("✓ Dry-run (not submitted)")
93
+ return 0
94
+
95
+ try:
96
+ result = subprocess.run(full_cmd, check=True, capture_output=True, text=True)
97
+ print("✓ Job submitted!")
98
+ print(result.stdout)
99
+ return 0
100
+ except subprocess.CalledProcessError as e:
101
+ print(f"❌ Submission failed: {e}")
102
+ print(e.stderr)
103
+ return 1
104
+
105
+ if __name__ == "__main__":
106
+ raise SystemExit(main())
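Note on passing the overrides to sbatch in submit_vlb_1000grid.py: `subprocess.run` treats the first list element as the executable, so `VAR=value` strings work as a shell prefix but not as leading argv entries. A minimal sketch of one alternative, handing the variables over through the `env=` argument, follows. It uses the current Python interpreter as a stand-in for `sbatch` so that it runs anywhere; the variable names mirror the script, while `run_with_overrides` is a hypothetical helper.

```python
import os
import subprocess
import sys


def run_with_overrides(cmd, overrides):
    """Run cmd with extra environment variables layered over the current environment."""
    env = os.environ.copy()
    env.update({key: str(value) for key, value in overrides.items()})
    return subprocess.run(cmd, env=env, check=True, capture_output=True, text=True)


# Stand-in for ["sbatch", ..., "run_vlb_inference_1000grid.sh"]: a child process
# that reports what it sees in its environment.
child = [
    sys.executable,
    "-c",
    "import os; print(os.environ['GRID_SIZE'], os.environ['N_FIELDS'])",
]
result = run_with_overrides(child, {"GRID_SIZE": 1000, "N_FIELDS": 9})
seen = result.stdout.strip()
```

With a real scheduler this relies on sbatch forwarding the submitting environment to the job, which is its documented default (`--export=ALL`); a site that changes that default would need the variables listed explicitly in `--export`.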
scripts/shell/evaluate_conditional_lh6.sh ADDED
@@ -0,0 +1,61 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=04:00:00
8
+ #SBATCH --job-name=ddpm_hi_lh6_eval
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-eval-%j.out
11
+ #SBATCH --error=slurm-eval-%j.err
12
+
13
+ # Evaluate conditional DDPM (6 CAMELS LH parameters).
14
+ # Submit:
15
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6/scripts/shell/evaluate_conditional_lh6.sh
16
+ #
17
+ # Optional overrides (example):
18
+ # sbatch --export=CHECKPOINT=/path/to/best_model.pt,OUTPUT_DIR=/path/to/eval_out evaluate_conditional_lh6.sh
19
+
20
+ REPO="/scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6"
21
+ cd "${REPO}" || exit 1
22
+
23
+ module load python/miniconda3-py3.12-usr
24
+
25
+ DATA_DIR="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"
26
+ # Default: trained run kept under april_26 (large artifacts not duplicated here).
27
+ CHECKPOINT="${CHECKPOINT:-/scratch/mrpcol001/Diffusion_job/april_26/ddpm_hi_lh6/outputs_conditional_6param_20260413_132226/checkpoints/best_model.pt}"
28
+ OUTPUT_DIR="${OUTPUT_DIR:-${REPO}/evaluation_outputs_6param}"
29
+ TRAINING_ARGS="${TRAINING_ARGS:-}"
30
+
31
+ echo "==============================================="
32
+ echo "Job ID: ${SLURM_JOB_ID:-local}"
33
+ echo "Job Name: ${SLURM_JOB_NAME:-evaluate_conditional_lh6}"
34
+ echo "Node: ${SLURM_NODELIST:-$(hostname)}"
35
+ echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
36
+ echo "Starting Time: $(date)"
37
+ echo "CHECKPOINT: ${CHECKPOINT}"
38
+ echo "DATA_DIR: ${DATA_DIR}"
39
+ echo "OUTPUT_DIR: ${OUTPUT_DIR}"
40
+ echo "==============================================="
41
+
42
+ EVAL_ARGS=(
43
+ python evaluate_conditional.py
44
+ --checkpoint "${CHECKPOINT}"
45
+ --data_dir "${DATA_DIR}"
46
+ --output_dir "${OUTPUT_DIR}"
47
+ --split test
48
+ --num_samples 8
49
+ --ddim_steps 50
50
+ )
51
+
52
+ if [[ -n "${TRAINING_ARGS}" ]]; then
53
+ EVAL_ARGS+=(--training_args "${TRAINING_ARGS}")
54
+ fi
55
+
56
+ "${EVAL_ARGS[@]}"
57
+
58
+ echo "==============================================="
59
+ echo "Evaluation completed at: $(date)"
60
+ echo "Plots and evaluation_data.npz under: ${OUTPUT_DIR}"
61
+ echo "==============================================="
scripts/shell/plot_r2_cosmology_lhs.sh ADDED
@@ -0,0 +1,72 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=12:00:00
8
+ #SBATCH --job-name=ddpm_r2_lhs
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-r2-lhs-%j.out
11
+ #SBATCH --error=slurm-r2-lhs-%j.err
12
+
13
+ # Latin-hypercube R² figure (plot_r2_cosmology_lhs.py): μ(P) and σ(P) vs (Ωm, σ8).
14
+ #
15
+ # Submit (full DDIM run — slow):
16
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6/scripts/shell/plot_r2_cosmology_lhs.sh
17
+ #
18
+ # Plot only from saved NPZ (fast):
19
+ # sbatch --export=FROM_NPZ=/path/to/r2_lhs_data.npz /scratch/.../plot_r2_cosmology_lhs.sh
20
+ #
21
+ # Optional env vars:
22
+ # CHECKPOINT, DATA_DIR, OUTPUT_PNG, SAVE_NPZ, LHS_N, MAPS_PER_POINT, DDIM_STEPS, SEED
23
+
24
+ REPO="/scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6"
25
+ cd "${REPO}" || exit 1
26
+
27
+ module load python/miniconda3-py3.12-usr
28
+
29
+ DATA_DIR="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"
30
+ CHECKPOINT="${CHECKPOINT:-/scratch/mrpcol001/Diffusion_job/april_26/ddpm_hi_lh6/outputs_conditional_6param_20260413_132226/checkpoints/best_model.pt}"
31
+ OUTPUT_PNG="${OUTPUT_PNG:-${REPO}/ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png}"
32
+ FROM_NPZ="${FROM_NPZ:-}"
33
+ SAVE_NPZ="${SAVE_NPZ:-}"
34
+ LHS_N="${LHS_N:-50}"
35
+ MAPS_PER_POINT="${MAPS_PER_POINT:-15}"
36
+ DDIM_STEPS="${DDIM_STEPS:-50}"
37
+ SEED="${SEED:-42}"
38
+
39
+ echo "==============================================="
40
+ echo "Job ID: ${SLURM_JOB_ID:-local}"
41
+ echo "Job Name: ${SLURM_JOB_NAME:-plot_r2_cosmology_lhs}"
42
+ echo "Node: ${SLURM_NODELIST:-$(hostname)}"
43
+ echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
44
+ echo "Starting Time: $(date)"
45
+ echo "OUTPUT_PNG: ${OUTPUT_PNG}"
46
+ echo "FROM_NPZ: ${FROM_NPZ:-(none — full compute)}"
47
+ echo "==============================================="
48
+
49
+ PY_ARGS=(
50
+ python plot_r2_cosmology_lhs.py
51
+ --output "${OUTPUT_PNG}"
52
+ --lhs-n "${LHS_N}"
53
+ --maps-per-point "${MAPS_PER_POINT}"
54
+ --ddim-steps "${DDIM_STEPS}"
55
+ --seed "${SEED}"
56
+ )
57
+
58
+ if [[ -n "${FROM_NPZ}" ]]; then
59
+ PY_ARGS+=(--from-npz "${FROM_NPZ}")
60
+ else
61
+ PY_ARGS+=(--checkpoint "${CHECKPOINT}" --data-dir "${DATA_DIR}")
62
+ if [[ -n "${SAVE_NPZ}" ]]; then
63
+ PY_ARGS+=(--save-npz "${SAVE_NPZ}")
64
+ fi
65
+ fi
66
+
67
+ "${PY_ARGS[@]}"
68
+
69
+ echo "==============================================="
70
+ echo "Finished at: $(date)"
71
+ echo "Figure: ${OUTPUT_PNG}"
72
+ echo "==============================================="
scripts/shell/train_conditional_lh6.sh ADDED
@@ -0,0 +1,60 @@
1
+ #!/bin/bash
2
+ #SBATCH --account=l40sfree
3
+ #SBATCH --partition=l40s
4
+ #SBATCH --nodes=1
5
+ #SBATCH --ntasks=8
6
+ #SBATCH --gres=gpu:l40s:1
7
+ #SBATCH --time=48:00:00
8
+ #SBATCH --job-name=ddpm_hi_lh6
9
+ #SBATCH --mail-user=mrpcol001@myuct.ac.za
10
+ #SBATCH --output=slurm-%j.out
11
+ #SBATCH --error=slurm-%j.err
12
+
13
+ # Conditional DDPM training — 6 CAMELS LH parameters (ddpm_hi_lh6).
14
+ # Submit from anywhere:
15
+ # sbatch /scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6/scripts/shell/train_conditional_lh6.sh
16
+ #
17
+ # Override data path (optional): any folder containing *_LH_6.npy and *_labels_LH.npy
18
+ # sbatch --export=DATA_DIR=/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6 train_conditional_lh6.sh
19
+
20
+ cd /scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6 || exit 1
21
+
22
+ module load python/miniconda3-py3.12-usr
23
+
24
+ # Same LH_data layout as DDPM_HI_Emulation_improved (params_2 for 2 labels → params_6 here).
25
+ DATA_DIR="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"
26
+
27
+ echo "==============================================="
28
+ echo "Job ID: $SLURM_JOB_ID"
29
+ echo "Job Name: $SLURM_JOB_NAME"
30
+ echo "Node: $SLURM_NODELIST"
31
+ echo "GPU: $CUDA_VISIBLE_DEVICES"
32
+ echo "Starting Time: $(date)"
33
+ echo "Conditional diffusion training (ddpm_hi_lh6, 6 labels)"
34
+ echo "DATA_DIR: ${DATA_DIR}"
35
+ echo "==============================================="
36
+
37
+ python train_conditional.py \
38
+ --label_dim 6 \
39
+ --timesteps 1500 \
40
+ --use_ddim \
41
+ --ddim_steps 50 \
42
+ --normalize_labels \
43
+ --batch_size 8 \
44
+ --epochs 200 \
45
+ --lr 2e-4 \
46
+ --early_stop_patience 100 \
47
+ --sample_every 10 \
48
+ --base_channels 64 \
49
+ --channel_multipliers 1 2 4 8 \
50
+ --attention_levels 2 3 \
51
+ --data_dir "${DATA_DIR}" \
52
+ --output_dir outputs_conditional_6param \
53
+ --use_amp
54
+
55
+ # To resume: point --resume at checkpoints/checkpoint_epoch_N.pt and set --epochs to the
56
+ # new total; add --resume_refresh_scheduler if extending past the original epoch count.
57
+
58
+ echo "==============================================="
59
+ echo "Training completed at: $(date)"
60
+ echo "==============================================="
src/eval_model.py ADDED
@@ -0,0 +1,86 @@
1
+ """
2
+ Helpers for ddpm_cond_eval.ipynb: R², P(k) on log N_HI fields, DDIM batches.
3
+
4
+ Uses evaluate_conditional for PowerSpectrum, sampling, and label z-scoring.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import numpy as np
9
+ import torch
10
+ from matplotlib.colors import LinearSegmentedColormap
11
+
12
+ import evaluate_conditional as ec
13
+
14
+ LO_LOG, HI_LOG = 14.0, 22.0
15
+
16
+
17
+ def r2_score_1d(y_true: np.ndarray, y_pred: np.ndarray) -> float:
18
+ """Univariate R² (matches sklearn.metrics.r2_score for 1D arrays, except when y_true is constant)."""
19
+ y_true = np.asarray(y_true, dtype=np.float64).ravel()
20
+ y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
21
+ ss_res = np.sum((y_true - y_pred) ** 2)
22
+ ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
23
+ if ss_tot < 1e-30:
24
+ return 0.0 if ss_res < 1e-30 else float("-inf")
25
+ return float(1.0 - ss_res / ss_tot)
26
+
27
+
28
+ def cmap_r2_hiflow() -> LinearSegmentedColormap:
29
+ """Green → yellow → red → purple → dark blue (HIFlow Fig. 7 style)."""
30
+ return LinearSegmentedColormap.from_list(
31
+ "r2_hiflow",
32
+ ["#00a651", "#ffcc00", "#e74c3c", "#7d3c98", "#0d1b5c"],
33
+ N=256,
34
+ )
35
+
36
+
37
+ def images01_to_log_nhi(img01: np.ndarray, lo: float = LO_LOG, hi: float = HI_LOG) -> np.ndarray:
38
+ """Maps scaled to [0, 1], linear in log10(N_HI) between lo and hi → log10(N_HI / cm^-2)."""
39
+ return lo + (hi - lo) * np.clip(img01, 0.0, 1.0).astype(np.float64)
40
+
41
+
42
+ def per_map_power_spectra_log(
43
+ images_01: np.ndarray, box_size: float = 25.0, lo: float = LO_LOG, hi: float = HI_LOG
44
+ ) -> tuple[np.ndarray, np.ndarray]:
45
+ """Return (dk, Pk) with Pk shape (N, n_bins) using log10 N_HI field."""
46
+ logf = images01_to_log_nhi(images_01, lo, hi)
47
+ n = logf.shape[0]
48
+ npix = logf.shape[-1]
49
+ dl = box_size / npix
50
+ dk, _ = ec.PowerSpectrum(logf[0], N=npix, dl=dl)
51
+ pks = np.stack([ec.PowerSpectrum(logf[i], N=npix, dl=dl)[1] for i in range(n)])
52
+ return dk, pks
53
+
54
+
55
+ def sample_batch(
56
+ model: torch.nn.Module,
57
+ labels_np: np.ndarray,
58
+ label_mean: np.ndarray,
59
+ label_std: np.ndarray,
60
+ normalize_labels: bool,
61
+ height: int,
62
+ width: int,
63
+ device: torch.device,
64
+ ddim_steps: int,
65
+ progress: bool,
66
+ ) -> np.ndarray:
67
+ """DDIM sample batch; labels_np shape (B, label_dim). mean/std same length as label_dim."""
68
+ labels_np = np.asarray(labels_np, dtype=np.float32)
69
+ mean = np.asarray(label_mean, dtype=np.float32)
70
+ std = np.asarray(label_std, dtype=np.float32)
71
+ if normalize_labels:
72
+ t = ec.prepare_labels_for_model(labels_np, mean, std).to(device)
73
+ else:
74
+ t = torch.from_numpy(labels_np).float().to(device)
75
+ with torch.no_grad():
76
+ out = model.sample(
77
+ labels=t,
78
+ channels=1,
79
+ height=height,
80
+ width=width,
81
+ device=device,
82
+ progress=progress,
83
+ use_ddim=True,
84
+ ddim_steps=ddim_steps,
85
+ )
86
+ return ec.from_model_output(out)
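Note on the two scalar helpers in eval_model.py: `r2_score_1d` is the usual R² = 1 − SS_res / SS_tot with a guard for a constant target, and `images01_to_log_nhi` is a clipped linear map from [0, 1] to [lo, hi]. The sketch below restates both with plain floats and lists so the edge cases can be checked by hand. It is a simplified stand-in (no NumPy, scalar pixel instead of an array), not the repository code.

```python
import math


def r2_score_1d(y_true, y_pred):
    """R^2 = 1 - SS_res / SS_tot, with the same constant-target guard as eval_model.py."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    if ss_tot < 1e-30:
        return 0.0 if ss_res < 1e-30 else -math.inf
    return 1.0 - ss_res / ss_tot


def image01_to_log_nhi(value, lo=14.0, hi=22.0):
    """Map a pixel in [0, 1] back to log10(N_HI), clipping out-of-range inputs."""
    return lo + (hi - lo) * min(max(value, 0.0), 1.0)


perfect = r2_score_1d([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
mean_only = r2_score_1d([1.0, 2.0, 3.0], [2.0, 2.0, 2.0])
worse_than_mean = r2_score_1d([1.0, 2.0, 3.0], [3.0, 2.0, 1.0])
constant_target = r2_score_1d([5.0, 5.0], [5.0, 6.0])
mid_pixel = image01_to_log_nhi(0.5)
clipped_pixel = image01_to_log_nhi(1.7)
```

Predicting the mean everywhere gives R² = 0, and a prediction that is worse than the mean goes negative, which is why the colour scale in the R² figure clips at a lower bound rather than assuming values in [0, 1].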
src/figure9_posterior.py ADDED
@@ -0,0 +1,33 @@
1
+ """
2
+ Figure-9 style surrogate posteriors: build (Ωm, σ8) grids and log P(k) for observed maps.
3
+
4
+ Used by ddpm_cond_eval.ipynb. Sampling and P(k) live in eval_model.py.
5
+ """
6
+ from __future__ import annotations
7
+
8
+ import numpy as np
9
+
10
+ import eval_model as em
11
+
12
+
13
+ def build_cosmo_grid(
14
+ g: int,
15
+ om_lo: float,
16
+ om_hi: float,
17
+ s8_lo: float,
18
+ s8_hi: float,
19
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
20
+ om_axis = np.linspace(om_lo, om_hi, g, dtype=np.float64)
21
+ s8_axis = np.linspace(s8_lo, s8_hi, g, dtype=np.float64)
22
+ og, sg = np.meshgrid(om_axis, s8_axis, indexing="ij")
23
+ grid_labels = np.stack([og.ravel(), sg.ravel()], axis=1).astype(np.float32)
24
+ return om_axis, s8_axis, og, sg, grid_labels
25
+
26
+
27
+ def log_pk_observed(img01: np.ndarray, box_size: float, dk: np.ndarray) -> np.ndarray:
28
+ """Single map → log P(k) on bins where dk > 0."""
29
+ _, pk = em.per_map_power_spectra_log(img01[np.newaxis, ...], box_size)
30
+ valid = dk > 0
31
+ if pk.shape[1] != len(dk):
32
+ raise ValueError("P(k) bin count mismatch vs dk")
33
+ return np.log(pk[0, valid] + 1e-30)
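Note on the grid ordering in figure9_posterior.py: `build_cosmo_grid` flattens an `indexing="ij"` mesh, so row `k = i * g + j` of `grid_labels` holds `(om_axis[i], s8_axis[j])`, and any per-row quantity can be reshaped back to `(g, g)` with Ωm on the first axis. The sketch below shows the same ordering with the standard library only. `linspace` and `build_grid` are illustrative stand-ins, and the axis ranges are made-up numbers.

```python
from itertools import product


def linspace(lo, hi, n):
    """n evenly spaced values from lo to hi inclusive (assumes n >= 2)."""
    step = (hi - lo) / (n - 1)
    return [lo + k * step for k in range(n)]


def build_grid(g, om_lo, om_hi, s8_lo, s8_hi):
    om_axis = linspace(om_lo, om_hi, g)
    s8_axis = linspace(s8_lo, s8_hi, g)
    # product() varies the last axis fastest, matching a C-order ravel of an "ij" mesh.
    rows = list(product(om_axis, s8_axis))
    return om_axis, s8_axis, rows


g = 3
om_axis, s8_axis, rows = build_grid(g, 0.1, 0.5, 0.6, 1.0)

# Recover (i, j) from a flat row index.
k = 5
i, j = divmod(k, g)

# Reshape a per-row score (here just the row index) into a g-by-g table.
scores = list(range(g * g))
table = [scores[r * g:(r + 1) * g] for r in range(g)]
```

Keeping this ordering in mind matters when the reshaped table is plotted: with Ωm on the first axis, the table lines up with "ij" coordinate grids directly.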
src/plot_r2_cosmology_lhs.py ADDED
@@ -0,0 +1,316 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Reproduce the Latin-hypercube R² figure (μ(P) and σ(P)) in (Ωm, σ8) with a layout
4
+ that avoids colorbar / suptitle overlap.
5
+
6
+ Usage (full run — slow):
7
+ python plot_r2_cosmology_lhs.py --output ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png
8
+
9
+ Replay plot only from saved arrays:
10
+ python plot_r2_cosmology_lhs.py --from-npz r2_lhs_data.npz --output out.png
11
+
12
+ Defaults match ddpm_conditional / evaluate_conditional 6-param setup.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import argparse
17
+ from pathlib import Path
18
+
19
+ import matplotlib
20
+
21
+ matplotlib.use("Agg")
22
+ import matplotlib.pyplot as plt
23
+ import numpy as np
24
+ import torch
25
+ from matplotlib.cm import ScalarMappable
26
+ from matplotlib.colors import Normalize
27
+ from matplotlib.gridspec import GridSpec
28
+
29
+ import evaluate_conditional as ec
30
+ import eval_model as em
31
+
32
+ _SCRIPT_DIR = Path(__file__).resolve().parent
33
+ _DEFAULT_CKPT = _SCRIPT_DIR / "outputs_conditional_6param_20260413_132226/checkpoints/best_model.pt"
34
+ _DEFAULT_DATA = "/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"
35
+
36
+
37
+ def latin_hypercube_scaled(
38
+ n: int, lo: np.ndarray, hi: np.ndarray, rng: np.random.Generator
39
+ ) -> np.ndarray:
40
+ """n points in [lo, hi] per dimension (classic LHS)."""
41
+ d = int(lo.shape[0])
42
+ u = rng.random((n, d))
43
+ cut = np.linspace(0.0, 1.0, n + 1)
44
+ a, b = cut[:-1], cut[1:]
45
+ width = (b - a)[:, np.newaxis]
46
+ rd = a[:, np.newaxis] + u * width
47
+ for j in range(d):
48
+ rng.shuffle(rd[:, j])
49
+ span = (hi - lo).astype(np.float64)
50
+ return (lo + rd * span).astype(np.float32)
51
+
52
+
53
+ def compute_lhs_r2(
54
+ model: torch.nn.Module,
55
+ images_split: np.ndarray,
56
+ labels_split: np.ndarray,
57
+ label_mean: np.ndarray,
58
+ label_std: np.ndarray,
59
+ device: torch.device,
60
+ lhs_n: int,
61
+ maps_per_point: int,
62
+ batch_size: int,
63
+ box_size_mpc: float,
64
+ ddim_steps: int,
65
+ seed: int,
66
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
67
+ """Returns lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b."""
68
+ lo_b = labels_split.min(axis=0)
69
+ hi_b = labels_split.max(axis=0)
70
+ rng = np.random.default_rng(seed)
71
+ lhs_pts = latin_hypercube_scaled(lhs_n, lo_b, hi_b, rng)
72
+
73
+ ldim = labels_split.shape[1]
74
+ h, w = int(images_split.shape[-2]), int(images_split.shape[-1])
75
+ bs = min(batch_size, maps_per_point)
76
+ npix = int(images_split.shape[-1])
77
+ dl = box_size_mpc / npix
78
+
79
+ def pk_stack(imgs: np.ndarray) -> np.ndarray:
80
+ return np.stack([ec.PowerSpectrum(im, N=npix, dl=dl)[1] for im in imgs], axis=0)
81
+
82
+ r2_mu_arr = np.full(lhs_n, np.nan, dtype=np.float64)
83
+ r2_sig_arr = np.full(lhs_n, np.nan, dtype=np.float64)
84
+ model.eval()
85
+
86
+ for ti in range(lhs_n):
87
+ theta = lhs_pts[ti]
88
+ dist = np.linalg.norm(labels_split - theta, axis=1)
89
+ nn_idx = np.argsort(dist)[:maps_per_point]
90
+ real_batch = images_split[nn_idx]
91
+ rep = np.tile(theta[None, :], (maps_per_point, 1))
92
+ gen_chunks = []
93
+ for j in range(0, maps_per_point, bs):
94
+ chunk = rep[j : j + bs]
95
+ bt = ec.prepare_labels_for_model(chunk, label_mean, label_std).to(device)
96
+ with torch.no_grad():
97
+ g = model.sample(
98
+ labels=bt,
99
+ channels=1,
100
+ height=h,
101
+ width=w,
102
+ device=device,
103
+ progress=False,
104
+ use_ddim=True,
105
+ ddim_steps=ddim_steps,
106
+ )
107
+ gen_chunks.append(ec.from_model_output(g))
108
+ gen_batch = np.concatenate(gen_chunks, axis=0)
109
+
110
+ pk_r = pk_stack(real_batch)
111
+ pk_g = pk_stack(gen_batch)
112
+ km = np.arange(pk_r.shape[1], dtype=int) > 0
113
+ mu_r, mu_g = pk_r.mean(axis=0), pk_g.mean(axis=0)
114
+ sr, sg = pk_r.std(axis=0), pk_g.std(axis=0)
115
+ r2_mu_arr[ti] = em.r2_score_1d(mu_r[km], mu_g[km])
116
+ r2_sig_arr[ti] = em.r2_score_1d(sr[km], sg[km])
117
+
118
+ return lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b
119
+
120
+
121
+ def plot_r2_cosmology_figure(
122
+ lhs_pts: np.ndarray,
123
+ r2_mu_arr: np.ndarray,
124
+ r2_sig_arr: np.ndarray,
125
+ lo_b: np.ndarray,
126
+ hi_b: np.ndarray,
127
+ out_path: Path,
128
+ r2_vmin: float = 0.90,
129
+ r2_vmax: float = 1.0,
130
+ lhs_n: int | None = None,
131
+ maps_per_point: int | None = None,
132
+ dpi: int = 160,
133
+ ) -> None:
134
+ """
135
+ Two-panel scatter in (Ωm, σ8) with a dedicated colorbar column (no overlap with heatmap).
136
+ """
137
+ lhs_n = lhs_n if lhs_n is not None else len(r2_mu_arr)
138
+ maps_per_point = maps_per_point if maps_per_point is not None else 15
139
+
140
+ ldim = lhs_pts.shape[1]
141
+ om_plot = lhs_pts[:, 0]
142
+ s8_plot = lhs_pts[:, 1] if ldim >= 2 else np.zeros(lhs_n)
143
+
144
+ cmap = em.cmap_r2_hiflow()
145
+ norm = Normalize(vmin=r2_vmin, vmax=r2_vmax)
146
+ sm = ScalarMappable(norm=norm, cmap=cmap)
147
+ sm.set_array([])
148
+
149
+ fig = plt.figure(figsize=(11.5, 4.9))
150
+ # Left: data panels; narrow right strip: colorbar only (avoids fig.colorbar + tight_layout clash)
151
+ gs = GridSpec(
152
+ nrows=1,
153
+ ncols=3,
154
+ figure=fig,
155
+ width_ratios=[1.0, 1.0, 0.065],
156
+ wspace=0.26,
157
+ left=0.07,
158
+ right=0.98,
159
+ top=0.82,
160
+ bottom=0.14,
161
+ )
162
+ ax0 = fig.add_subplot(gs[0, 0])
163
+ ax1 = fig.add_subplot(gs[0, 1], sharey=ax0)
164
+ cax = fig.add_subplot(gs[0, 2])
165
+
166
+ pad_x = 0.02 * (float(hi_b[0] - lo_b[0]) + 1e-6)
167
+ ax0.set_xlim(float(lo_b[0]) - pad_x, float(hi_b[0]) + pad_x)
168
+ ax1.set_xlim(float(lo_b[0]) - pad_x, float(hi_b[0]) + pad_x)
169
+ if ldim >= 2:
170
+ pad_y = 0.02 * (float(hi_b[1] - lo_b[1]) + 1e-6)
171
+ ax0.set_ylim(float(lo_b[1]) - pad_y, float(hi_b[1]) + pad_y)
172
+
173
+ for ax, r2v, subtitle in zip(
174
+ (ax0, ax1),
175
+ (r2_mu_arr, r2_sig_arr),
176
+ (r"$R^2$ for $\mu(P)$", r"$R^2$ for $\sigma(P)$"),
177
+ ):
178
+ ok = np.isfinite(r2v)
179
+ ax.scatter(
180
+ om_plot[ok],
181
+ s8_plot[ok],
182
+ c=np.clip(r2v[ok], r2_vmin, r2_vmax),
183
+ cmap=cmap,
184
+ norm=norm,
185
+ s=52,
186
+ alpha=0.92,
187
+ edgecolors="k",
188
+ linewidths=0.35,
189
+ )
190
+ ax.set_xlabel(r"$\Omega_m$", fontsize=12)
191
+ ax.set_title(subtitle, fontsize=11)
192
+ ax.grid(True, alpha=0.25)
193
+
194
+ ax0.set_ylabel(r"$\sigma_8$", fontsize=12)
195
+ plt.setp(ax1.get_yticklabels(), visible=False)
196
+
197
+ cb = fig.colorbar(sm, cax=cax)
198
+ cb.set_label(r"$R^2$", fontsize=11)
199
+ cax.tick_params(labelsize=9)
200
+
201
+ fig.suptitle(
202
+ r"Visual summary of $R^2$ (CAMELS vs conditional DDPM) vs cosmology — "
203
+ + f"{lhs_n} Latin Hypercube samples; {maps_per_point} maps / point",
204
+ fontsize=11,
205
+ fontweight="bold",
206
+ y=0.96,
207
+ )
208
+
209
+ out_path = Path(out_path)
210
+ out_path.parent.mkdir(parents=True, exist_ok=True)
211
+ # Do not use bbox_inches="tight": it rebalances the axes and can squeeze the colorbar into the panels.
212
+ fig.savefig(out_path, dpi=dpi)
213
+ plt.close(fig)
214
+
215
+
216
+ def _resolve_training_args(checkpoint: Path) -> Path | None:
217
+ run = checkpoint.parent.parent if checkpoint.parent.name == "checkpoints" else checkpoint.parent
218
+ for name in ("args.json", "args.txt"):
219
+ p = run / name
220
+ if p.is_file():
221
+ return p
222
+ return None
223
+
224
+
225
+ def parse_args() -> argparse.Namespace:
226
+ p = argparse.ArgumentParser(description="LHS R² cosmology figure (fixed colorbar layout)")
227
+ p.add_argument("--checkpoint", type=str, default=str(_DEFAULT_CKPT))
228
+ p.add_argument("--data-dir", type=str, default=_DEFAULT_DATA)
229
+ p.add_argument("--split", type=str, default="test", choices=("train", "val", "test"))
230
+ p.add_argument("--output", type=str, default=str(_SCRIPT_DIR / "ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png"))
231
+ p.add_argument("--from-npz", type=str, default=None, help="Load lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b")
232
+ p.add_argument("--save-npz", type=str, default=None, help="After compute, save arrays for --from-npz replot")
233
+ p.add_argument("--lhs-n", type=int, default=50)
234
+ p.add_argument("--maps-per-point", type=int, default=15)
235
+ p.add_argument("--batch-size", type=int, default=8)
236
+ p.add_argument("--ddim-steps", type=int, default=50)
237
+ p.add_argument("--box-size-mpc", type=float, default=25.0)
238
+ p.add_argument("--seed", type=int, default=42)
239
+ p.add_argument("--r2-vmin", type=float, default=0.90)
240
+ p.add_argument("--r2-vmax", type=float, default=1.0)
241
+ p.add_argument("--dpi", type=int, default=160)
242
+ return p.parse_args()
243
+
244
+
245
+ def main() -> None:
246
+ args = parse_args()
247
+ out_path = Path(args.output)
248
+
249
+ if args.from_npz:
250
+ z = np.load(args.from_npz, allow_pickle=False)
251
+ lhs_pts = z["lhs_pts"]
252
+ r2_mu_arr = z["r2_mu_arr"]
253
+ r2_sig_arr = z["r2_sig_arr"]
254
+ lo_b = z["lo_b"]
255
+ hi_b = z["hi_b"]
256
+ else:
257
+ ckpt = Path(args.checkpoint).expanduser().resolve()
258
+ if not ckpt.is_file():
259
+ raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
260
+
261
+ ta = _resolve_training_args(ckpt)
262
+ config: dict = {}
263
+ if ta is not None:
264
+ config = ec.load_training_config(str(ta))
265
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
266
+ model = ec.build_model(config, device)
267
+ ec.load_checkpoint(model, str(ckpt), device)
268
+
269
+ data_dir = Path(args.data_dir)
270
+ images_split, labels_split = ec.load_split(data_dir, args.split)
271
+ label_mean, label_std = ec.load_label_stats(data_dir)
272
+
273
+ lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b = compute_lhs_r2(
274
+ model,
275
+ images_split,
276
+ labels_split,
277
+ label_mean,
278
+ label_std,
279
+ device,
280
+ lhs_n=args.lhs_n,
281
+ maps_per_point=args.maps_per_point,
282
+ batch_size=args.batch_size,
283
+ box_size_mpc=args.box_size_mpc,
284
+ ddim_steps=args.ddim_steps,
285
+ seed=args.seed,
286
+ )
287
+
288
+ if args.save_npz:
289
+ np.savez(
290
+ args.save_npz,
291
+ lhs_pts=lhs_pts,
292
+ r2_mu_arr=r2_mu_arr,
293
+ r2_sig_arr=r2_sig_arr,
294
+ lo_b=lo_b,
295
+ hi_b=hi_b,
296
+ )
297
+ print("Saved", args.save_npz)
298
+
299
+ plot_r2_cosmology_figure(
300
+ lhs_pts,
301
+ r2_mu_arr,
302
+ r2_sig_arr,
303
+ lo_b,
304
+ hi_b,
305
+ out_path,
306
+ r2_vmin=args.r2_vmin,
307
+ r2_vmax=args.r2_vmax,
308
+ lhs_n=len(r2_mu_arr),  # actual point count, so --from-npz replots are labelled correctly
309
+ maps_per_point=args.maps_per_point,
310
+ dpi=args.dpi,
311
+ )
312
+ print("Saved", out_path.resolve())
313
+
314
+
315
+ if __name__ == "__main__":
316
+ main()
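Note on `latin_hypercube_scaled` in plot_r2_cosmology_lhs.py: the function splits [0, 1) into n equal strata per dimension, draws one point inside each stratum, shuffles each column independently, and rescales to [lo, hi]. The sketch below mirrors those steps with the standard library and then checks the defining property, one sample per stratum along every axis. `latin_hypercube` and `strata` are illustrative names, and the bounds are made-up (Ωm, σ8)-like ranges.

```python
import random


def latin_hypercube(n, lo, hi, rng):
    """Return n points with one sample per stratum in every dimension."""
    d = len(lo)
    columns = []
    for dim in range(d):
        # One uniform draw inside each of the n equal-width strata of [0, 1).
        unit = [(k + rng.random()) / n for k in range(n)]
        rng.shuffle(unit)  # decouple the stratum order between dimensions
        columns.append([lo[dim] + u * (hi[dim] - lo[dim]) for u in unit])
    return [tuple(col[row] for col in columns) for row in range(n)]


def strata(points, dim, n, lo, hi):
    """Sorted stratum indices occupied by the points along one dimension."""
    return sorted(
        int((p[dim] - lo[dim]) / (hi[dim] - lo[dim]) * n) for p in points
    )


rng = random.Random(42)
lo, hi = (0.1, 0.6), (0.5, 1.0)
points = latin_hypercube(8, lo, hi, rng)
```

Calling `strata(points, 0, 8, lo, hi)` should list each stratum index once, which is what distinguishes this design from plain uniform sampling with the same number of points.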
src/posterior_inference.py ADDED
@@ -0,0 +1,895 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ posterior_inference.py — VLB-based cosmological inference (Mudur et al. 2023 §4 style).
4
+ Pure inference-time; frozen DDPM weights. Script lives next to diffusion_conditional.py.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import ast
11
+ import json
12
+ import os
13
+ import sys
14
+ import time
15
+ from pathlib import Path
16
+ from typing import Dict, List, Optional, Tuple
17
+
18
+ import matplotlib
19
+ matplotlib.use("Agg")
20
+ import matplotlib.gridspec as gridspec
21
+ import matplotlib.pyplot as plt
22
+ import matplotlib.patheffects as mpathe
23
+ import numpy as np
24
+ import torch
25
+
26
+ # ── Project imports ────────────────────────────────────────────────────────────
27
+ _ROOT = Path(__file__).resolve().parent
28
+ if (_ROOT / "diffusion_conditional.py").is_file():
29
+ sys.path.insert(0, str(_ROOT))
30
+
31
+ from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel
32
+ from unet_conditional import ConditionalUNet
33
+
34
+ plt.rcParams.update({
35
+ "figure.facecolor": "white", "axes.facecolor": "white",
36
+ "axes.edgecolor": "#222", "axes.linewidth": 0.7,
37
+ "axes.spines.top": False, "axes.spines.right": False,
38
+ "font.family": "DejaVu Sans", "font.size": 9.5,
39
+ "savefig.facecolor": "white",
40
+ })
41
+
42
+ REAL_COLOR = "#CC3333"
43
+ GEN_COLOR = "#2266BB"
44
+ SIGMA_LEVELS = [2.30, 6.17, 11.83]
45
+ SIGMA_COLORS = ["#1a5c9e", "#5590d0", "#99c0ea"]
46
+ SIGMA_LABELS = {2.30: r"$1\sigma$", 6.17: r"$2\sigma$", 11.83: r"$3\sigma$"}
47
+
48
+
49
+ def load_config(path: str) -> Dict:
50
+ p = Path(path)
51
+ if p.suffix == ".json":
52
+ with open(p) as f:
53
+ return json.load(f)
54
+ cfg = {}
55
+ with open(p) as f:
56
+ for line in f:
57
+ if ":" not in line:
58
+ continue
59
+ k, v = line.strip().split(":", 1)
60
+ try:
61
+ cfg[k.strip()] = ast.literal_eval(v.strip())
62
+ except Exception:
63
+ cfg[k.strip()] = v.strip()
64
+ return cfg
65
+
+
+ def load_model(ckpt: str, cfg: Dict, device: torch.device) -> ConditionalDiffusionModel:
+     unet = ConditionalUNet(
+         in_channels=1, out_channels=1,
+         label_dim=int(cfg.get("label_dim", 2)),
+         base_channels=int(cfg.get("base_channels", 64)),
+         channel_multipliers=list(cfg.get("channel_multipliers", [1, 2, 4, 8])),
+         attention_levels=list(cfg.get("attention_levels", [2, 3])),
+         dropout=float(cfg.get("dropout", 0.1)),
+     )
+     diff = GaussianDiffusion(
+         timesteps=int(cfg.get("timesteps", 1500)),
+         beta_start=float(cfg.get("beta_start", 1e-4)),
+         beta_end=float(cfg.get("beta_end", 0.02)),
+         schedule_type=str(cfg.get("schedule_type", "linear")),
+     )
+     model = ConditionalDiffusionModel(unet, diff).to(device)
+     ck = torch.load(ckpt, map_location=device, weights_only=False)
+     if isinstance(ck, dict) and "ema_shadow" in ck:
+         cur = model.state_dict()
+         for k, v in ck["ema_shadow"].items():
+             if k in cur:
+                 cur[k] = v
+         model.load_state_dict(cur)
+         print(" Loaded EMA weights")
+     elif isinstance(ck, dict) and "model_state_dict" in ck:
+         model.load_state_dict(ck["model_state_dict"])
+     else:
+         model.load_state_dict(ck)
+     model.eval()
+     for p in model.parameters():
+         p.requires_grad_(False)
+     return model
+
+
+ def load_test_data(
+     data_dir: str, n_fields: int, seed: int = 42
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+     dp = Path(data_dir)
+     lsuf = "_2" if (dp / "train_labels_LH_2.npy").exists() else ""
+     isuf = "" if (dp / "train_LH.npy").exists() else "_6"
+
+     imgs = np.load(dp / f"test_LH{isuf}.npy").astype(np.float32)
+     labels = np.load(dp / f"test_labels_LH{lsuf}.npy").astype(np.float32)
+     tr_lab = np.load(dp / f"train_labels_LH{lsuf}.npy").astype(np.float32)
+
+     rng = np.random.default_rng(seed)
+     idx = rng.choice(len(imgs), n_fields, replace=False)
+
+     label_mu = tr_lab.mean(0)
+     label_std = np.where(tr_lab.std(0) == 0, 1.0, tr_lab.std(0))
+     return imgs[idx], labels[idx], label_mu, label_std
+
+
+ def normal_kl(mean1, log_var1, mean2, log_var2):
+     # Element-wise KL( N(mean1, exp(log_var1)) || N(mean2, exp(log_var2)) ) in nats
+     return 0.5 * (
+         -1.0 + log_var2 - log_var1
+         + torch.exp(log_var1 - log_var2)
+         + ((mean1 - mean2) ** 2) * torch.exp(-log_var2)
+     )
+
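As a sanity check on the closed form used in `normal_kl`, the same expression can be evaluated for scalars with the standard library. This is a sketch for illustration; `gaussian_kl` is a hypothetical scalar stand-in for the tensor version above.

```python
import math

def gaussian_kl(mean1, log_var1, mean2, log_var2):
    """KL( N(mean1, exp(log_var1)) || N(mean2, exp(log_var2)) ) for scalars, in nats."""
    return 0.5 * (
        -1.0 + log_var2 - log_var1
        + math.exp(log_var1 - log_var2)
        + (mean1 - mean2) ** 2 * math.exp(-log_var2)
    )

kl_same = gaussian_kl(0.3, -1.0, 0.3, -1.0)  # identical distributions
kl_shift = gaussian_kl(0.0, 0.0, 1.0, 0.0)   # unit variance, means one sigma apart
print(kl_same, kl_shift)
```

Identical distributions give zero, and two unit-variance Gaussians whose means differ by one standard deviation give 0.5 nats, which is the familiar (Δμ)²/(2σ²) term.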
+
+ def _approx_standard_normal_cdf(x):
+     # tanh-based approximation to the standard normal CDF
+     return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))
+
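To get a feel for how close the tanh approximation is, it can be compared against the exact CDF written with `math.erf`. The sketch below is a scalar restatement for illustration only; `approx_cdf` and `exact_cdf` are hypothetical names.

```python
import math

def approx_cdf(x: float) -> float:
    """Scalar version of the tanh approximation used above."""
    return 0.5 * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def exact_cdf(x: float) -> float:
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

xs = [i / 100.0 for i in range(-600, 601)]  # scan [-6, 6] in steps of 0.01
max_err = max(abs(approx_cdf(x) - exact_cdf(x)) for x in xs)
print(f"max abs error on [-6, 6]: {max_err:.2e}")
```

The absolute error stays well below 1e-3 over this range, which is small compared with the `1e-12` clamps and the bin width used in the discretised likelihood.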
+
+ def discretised_gaussian_log_likelihood(x_0, mean, log_var):
+     # Log-probability of x_0 under a Gaussian integrated over bins of half-width 1/255;
+     # the outermost bins (|x_0| > 0.999) are open-ended.
+     centered_x = x_0 - mean
+     inv_stdv = torch.exp(-0.5 * log_var)
+     plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+     min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+     cdf_plus = _approx_standard_normal_cdf(plus_in)
+     cdf_min = _approx_standard_normal_cdf(min_in)
+     log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+     log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
+     cdf_delta = (cdf_plus - cdf_min).clamp(min=1e-12)
+     log_probs = torch.where(
+         x_0 < -0.999,
+         log_cdf_plus,
+         torch.where(x_0 > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta)),
+     )
+     return log_probs
+
+
+ def predict_x_start_from_eps(diff: GaussianDiffusion, x_t: torch.Tensor,
+                              t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
+     # Matches GaussianDiffusion._predict_xstart_from_noise (diffusion_conditional.py)
+     return (
+         diff._extract(diff.recip_sqrt_alphas_cumprod, t, x_t.shape) * x_t
+         - diff._extract(diff.sqrt_recip_minus_one, t, x_t.shape) * eps
+     )
+
+
+ def q_posterior_mean_var(diff: GaussianDiffusion, x_start: torch.Tensor,
+                          x_t: torch.Tensor, t: torch.Tensor):
+     # Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_start)
+     mean = (
+         diff._extract(diff.posterior_mean_coef1, t, x_t.shape) * x_start
+         + diff._extract(diff.posterior_mean_coef2, t, x_t.shape) * x_t
+     )
+     var = diff._extract(diff.posterior_variance, t, x_t.shape)
+     log_var_c = diff._extract(diff.posterior_log_variance_clipped, t, x_t.shape)
+     return mean, var, log_var_c
+
+
+ @torch.no_grad()
+ def compute_L_t(
+     model: ConditionalDiffusionModel,
+     x_0: torch.Tensor,
+     labels_n: torch.Tensor,
+     t: int,
+     fixed_eps: torch.Tensor,
+ ) -> torch.Tensor:
+     diff = model.diffusion
+     device = x_0.device
+     B = x_0.shape[0]
+     t_vec = torch.full((B,), t, device=device, dtype=torch.long)
+
+     if t == 0:
+         # Decoder term: noise x_0 to timestep index 1, then score x_0 under the
+         # discretised Gaussian built from the model's predicted posterior.
+         t1 = torch.full((B,), 1, device=device, dtype=torch.long)
+         ab1 = diff._extract(diff.alphas_cumprod, t1, x_0.shape)
+         x_1 = torch.sqrt(ab1) * x_0 + torch.sqrt(1.0 - ab1) * fixed_eps
+         eps_pred = model(x_1, t1, labels_n)
+         x_start_pred = predict_x_start_from_eps(diff, x_1, t1, eps_pred).clamp(-1, 1)
+         mean, _, log_var = q_posterior_mean_var(diff, x_start_pred, x_1, t1)
+         log_p = discretised_gaussian_log_likelihood(x_0, mean, log_var)
+         return -log_p.sum(dim=(1, 2, 3))
+
+     # t > 0: KL between the true posterior q(x_{t-1} | x_t, x_0) and the model's
+     # posterior, summed over pixels. Both use the same fixed posterior variance.
+     ab_t = diff._extract(diff.alphas_cumprod, t_vec, x_0.shape)
+     x_t = torch.sqrt(ab_t) * x_0 + torch.sqrt(1.0 - ab_t) * fixed_eps
+     true_mean, _, true_log_var = q_posterior_mean_var(diff, x_0, x_t, t_vec)
+     eps_pred = model(x_t, t_vec, labels_n)
+     x_start_pred = predict_x_start_from_eps(diff, x_t, t_vec, eps_pred).clamp(-1, 1)
+     model_mean, _, model_log_var = q_posterior_mean_var(diff, x_start_pred, x_t, t_vec)
+     kl = normal_kl(true_mean, true_log_var, model_mean, model_log_var)
+     return kl.sum(dim=(1, 2, 3))
+
+
+ @torch.no_grad()
+ def compute_L_T_analytic(diff: GaussianDiffusion, x_0: torch.Tensor) -> torch.Tensor:
+     # Prior term: KL( q(x_T | x_0) || N(0, I) ), summed over pixels
+     T = diff.timesteps
+     t_vec = torch.full((x_0.shape[0],), T - 1, device=x_0.device, dtype=torch.long)
+     abar_T = diff._extract(diff.alphas_cumprod, t_vec, x_0.shape)
+     mean1 = torch.sqrt(abar_T) * x_0
+     log_var1 = torch.log((1.0 - abar_T).clamp(min=1e-30))
+     kl = normal_kl(mean1, log_var1, torch.zeros_like(mean1), torch.zeros_like(log_var1))
+     return kl.sum(dim=(1, 2, 3))
+
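`compute_L_T_analytic` takes no labels, so the prior term is the same at every point of the (Ωm, σ8) grid and cancels in any −2Δ ln L surface. It is also tiny in absolute terms. The sketch below estimates its per-pixel size for the default schedule read by `load_model` (T = 1500, β from 1e-4 to 0.02), assuming that `schedule_type="linear"` means evenly spaced betas; that detail lives in `GaussianDiffusion`, which is not shown here.

```python
import math

T, beta_start, beta_end = 1500, 1e-4, 0.02  # defaults read by load_model
# Assumption: the "linear" schedule is evenly spaced between beta_start and beta_end.
betas = [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

# abar_T = prod_t (1 - beta_t), accumulated in log space for numerical stability
log_abar_T = sum(math.log1p(-b) for b in betas)
abar_T = math.exp(log_abar_T)

def prior_kl_per_pixel(x0: float) -> float:
    """KL( N(sqrt(abar_T) * x0, 1 - abar_T) || N(0, 1) ) for one pixel, in nats."""
    var = 1.0 - abar_T
    return 0.5 * (-1.0 - math.log1p(-abar_T) + var + abar_T * x0 ** 2)

kl_edge = prior_kl_per_pixel(1.0)  # largest value for data scaled to [-1, 1]
print(f"abar_T = {abar_T:.3e}, per-pixel prior KL at |x0| = 1: {kl_edge:.3e}")
```

Under these assumptions both numbers are below 1e-6, so leaving L_T out of the inference surfaces changes nothing of practical size.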
+
+ def build_eval_grid(
+     Om_true: float,
+     s8_true: float,
+     grid_size: int,
+     span: float = 0.1,
+     Om_range: Tuple[float, float] = (0.10, 0.50),
+     s8_range: Tuple[float, float] = (0.60, 1.00),
+ ) -> Tuple[np.ndarray, np.ndarray]:
+     # Window of +/- span around the true values, clipped to the parameter ranges
+     Om_lo = max(Om_true - span, Om_range[0])
+     Om_hi = min(Om_true + span, Om_range[1])
+     s8_lo = max(s8_true - span, s8_range[0])
+     s8_hi = min(s8_true + span, s8_range[1])
+     Om_1d = np.linspace(Om_lo, Om_hi, grid_size)
+     s8_1d = np.linspace(s8_lo, s8_hi, grid_size)
+     return Om_1d, s8_1d
+
+
+ @torch.no_grad()
+ def evaluate_vlb_surface(
+     model: ConditionalDiffusionModel,
+     x_0: torch.Tensor,
+     Om_grid: np.ndarray,
+     s8_grid: np.ndarray,
+     label_mu: np.ndarray,
+     label_std: np.ndarray,
+     t_values: List[int],
+     n_seeds: int = 4,
+     batch_size: int = 32,
+     label_dim: int = 2,
+     fixed_seed: int = 0,
+     device: Optional[torch.device] = None,
+ ) -> Dict[int, np.ndarray]:
+     device = device or x_0.device
+     nO, nS = len(Om_grid), len(s8_grid)
+     n_pts = nO * nS
+
+     Omg, s8g = np.meshgrid(Om_grid, s8_grid, indexing="ij")
+     raw_labels = np.column_stack([Omg.ravel(), s8g.ravel()])
+     if label_dim > 2:
+         # Parameters beyond (Om, s8) are held at their training-set mean
+         pad = np.zeros((n_pts, label_dim - 2), dtype=np.float32)
+         for i in range(label_dim - 2):
+             pad[:, i] = label_mu[2 + i]
+         raw_labels = np.concatenate([raw_labels, pad], axis=1)
+     norm_labels = (raw_labels - label_mu) / label_std
+     norm_labels_t = torch.from_numpy(norm_labels.astype(np.float32)).to(device)
+
+     L_surfaces = {t: np.zeros(n_pts, dtype=np.float64) for t in t_values}
+
+     # One noise draw per seed, shared by every grid point and timestep, so the
+     # surfaces differ only through the conditioning labels.
+     H, W = x_0.shape[-2], x_0.shape[-1]
+     rng_torch = torch.Generator(device=device).manual_seed(fixed_seed)
+     seeds_eps = [
+         torch.randn(1, 1, H, W, generator=rng_torch, device=device)
+         for _ in range(n_seeds)
+     ]
+
+     for _, fixed_eps in enumerate(seeds_eps):
+         for t in t_values:
+             for start in range(0, n_pts, batch_size):
+                 end = min(start + batch_size, n_pts)
+                 bsz = end - start
+                 x_b = x_0.expand(bsz, -1, -1, -1)
+                 lbl_b = norm_labels_t[start:end]
+                 eps_b = fixed_eps.expand(bsz, -1, -1, -1)
+                 L_t = compute_L_t(model, x_b, lbl_b, t=t, fixed_eps=eps_b)
+                 L_surfaces[t][start:end] += L_t.cpu().numpy() / n_seeds
+
+     return {t: L_surfaces[t].reshape(nO, nS) for t in t_values}
+
+
+ def marginal_from_neg2dL(
+     neg2dL: np.ndarray, Om_grid: np.ndarray, s8_grid: np.ndarray
+ ) -> Tuple[np.ndarray, np.ndarray, Tuple[float, float]]:
+     # P is proportional to exp(-neg2dL / 2); each marginal sums over the other axis,
+     # and the point estimates are the modes of the two marginals.
+     L = -0.5 * neg2dL
+     L = L - L.max()
+     P = np.exp(L)
+     Om_marginal = P.sum(axis=1)
+     Om_marginal /= Om_marginal.sum()
+     s8_marginal = P.sum(axis=0)
+     s8_marginal /= s8_marginal.sum()
+     Om_pred = float(Om_grid[np.argmax(Om_marginal)])
+     s8_pred = float(s8_grid[np.argmax(s8_marginal)])
+     return Om_marginal, s8_marginal, (Om_pred, s8_pred)
+
+
+ def credible_interval_68(values: np.ndarray, probs: np.ndarray) -> Tuple[float, float, float]:
+     # Median and 16th/84th percentiles from the cumulative sum of the marginal
+     cdf = np.cumsum(probs)
+     cdf /= cdf[-1]
+     median = float(np.interp(0.50, cdf, values))
+     lo = float(np.interp(0.16, cdf, values))
+     hi = float(np.interp(0.84, cdf, values))
+     return median, lo, hi
+
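The interval logic in `credible_interval_68` can be checked on a case with a known answer. The sketch below is a pure-Python restatement for illustration: `interp` is a hypothetical stand-in for `np.interp`, and `credible_interval_68_py` mirrors the cumulative-sum approach on a Gaussian marginal with mean 0.30 and standard deviation 0.02.

```python
import math

def interp(q, xp, fp):
    """Linear interpolation of fp over xp at q; xp must be non-decreasing."""
    if q <= xp[0]:
        return fp[0]
    for i in range(1, len(xp)):
        if q <= xp[i]:
            span = xp[i] - xp[i - 1]
            w = 0.0 if span == 0.0 else (q - xp[i - 1]) / span
            return fp[i - 1] + w * (fp[i] - fp[i - 1])
    return fp[-1]

def credible_interval_68_py(values, probs):
    """Median, 16th and 84th percentiles from a gridded marginal."""
    cdf, total = [], 0.0
    for p in probs:
        total += p
        cdf.append(total)
    cdf = [c / total for c in cdf]
    return tuple(interp(q, cdf, values) for q in (0.50, 0.16, 0.84))

mu, sigma = 0.30, 0.02
grid = [0.20 + 0.001 * i for i in range(201)]  # 0.20 ... 0.40
probs = [math.exp(-0.5 * ((g - mu) / sigma) ** 2) for g in grid]
median, lo, hi = credible_interval_68_py(grid, probs)
print(f"median={median:.4f}  lo={lo:.4f}  hi={hi:.4f}")
```

The recovered interval is close to 0.30 ± 0.02. Because the cumulative sum assigns each bin's full mass to its grid point, the quantiles sit about half a grid step low, which matters only on coarse grids.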
+
+ def fig_contours_per_t(
+     surfaces: Dict[int, np.ndarray],
+     Om_grid: np.ndarray,
+     s8_grid: np.ndarray,
+     Om_true: float,
+     s8_true: float,
+     out_path: Path,
+     dpi: int = 200,
+ ) -> None:
+     fig, ax = plt.subplots(figsize=(7, 6.5), dpi=dpi)
+     cmap = plt.cm.viridis
+     n_t = len(surfaces)
+     colors = cmap(np.linspace(0.05, 0.95, n_t))
+
+     for (t, L_surf), col in zip(sorted(surfaces.items()), colors):
+         neg2dL = 2.0 * (L_surf - L_surf.min())
+         ax.contour(
+             Om_grid, s8_grid, neg2dL.T,
+             levels=[2.30], colors=[col], linewidths=[1.6], linestyles=["-"],
+         )
+         ax.plot([], [], color=col, lw=1.8, label=f"t={t}")
+
+     ax.plot(Om_true, s8_true, "r+", ms=18, mew=2.5, label="True", zorder=10)
+     ax.set_xlabel(r"$\Omega_m$", fontsize=12)
+     ax.set_ylabel(r"$\sigma_8$", fontsize=12)
+     ax.set_title(
+         r"$-2\Delta\ln\hat{L}_t$ — $1\sigma$ contour per timestep"
+         "\n(Mudur-style) smaller $t$ → tighter constraint",
+         fontweight="bold", fontsize=10,
+     )
+     ax.legend(fontsize=8, loc="best", ncol=2, framealpha=0.92)
+     ax.grid(alpha=0.18)
+     ax.set_xlim(Om_grid[0], Om_grid[-1])
+     ax.set_ylim(s8_grid[0], s8_grid[-1])
+     fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
+     plt.close(fig)
+     print(f" Saved -> {out_path}")
+
+
+ def _L0_posterior_smoothed(
+     L0_surface: np.ndarray, smooth_sigma: float = 0.6,
+ ):
+     from scipy.ndimage import gaussian_filter as gf
+
+     neg2dL = 2.0 * (L0_surface - L0_surface.min())
+     surface_sm = gf(neg2dL, sigma=smooth_sigma)
+     return surface_sm, neg2dL
+
+
+ def draw_L0_posterior_main_panel(
+     ax,
+     surface_sm: np.ndarray,
+     Om_grid: np.ndarray,
+     s8_grid: np.ndarray,
+     Om_true: float,
+     s8_true: float,
+     Om_pred: float,
+     s8_pred: float,
+     *,
+     clabel_fontsize: float = 9.5,
+     marker_ms: float = 16,
+ ) -> None:
+     ax.contourf(Om_grid, s8_grid, surface_sm.T, levels=60, cmap="Blues_r",
+                 vmin=0, vmax=SIGMA_LEVELS[-1] * 3, extend="max", alpha=0.55)
+     for lv, co in zip(reversed(SIGMA_LEVELS), reversed(SIGMA_COLORS)):
+         ax.contourf(Om_grid, s8_grid, surface_sm.T,
+                     levels=[0, lv], colors=[co], alpha=0.78)
+
+     cs = ax.contour(
+         Om_grid, s8_grid, surface_sm.T,
+         levels=SIGMA_LEVELS,
+         colors=["white", "white", "white"],
+         linewidths=[2.2, 1.6, 1.2],
+         linestyles=["-", "--", "-."],
+     )
+     ax.clabel(cs, fmt=SIGMA_LABELS, inline=True, fontsize=clabel_fontsize, colors="white")
+
+     ax.axvline(Om_true, color="red", lw=0.7, ls=":", alpha=0.6)
+     ax.axhline(s8_true, color="red", lw=0.7, ls=":", alpha=0.6)
+     ax.plot(Om_true, s8_true, "r+", ms=marker_ms, mew=2.5, zorder=6, label="True")
+     ax.plot(Om_pred, s8_pred, "w^", ms=max(6, marker_ms * 0.55), mew=1.2, zorder=6, label="MAP")
+     ax.set_xlim(Om_grid[0], Om_grid[-1])
+     ax.set_ylim(s8_grid[0], s8_grid[-1])
+     ax.grid(alpha=0.18)
+
+
+ def fig_posterior_L0_mosaic_3x3(
+     out_dir: Path,
+     n_fields: int,
+     out_path: Path,
+     mosaic_side_px: int = 10_000,
+     panel_inches: float = 4.0,
+ ) -> None:
+     from matplotlib.patches import Patch
+
+     n_plot = min(n_fields, 9)
+     fig_side = panel_inches * 3
+     dpi = mosaic_side_px / fig_side
+     fig, axes = plt.subplots(
+         3, 3, figsize=(fig_side, fig_side), dpi=dpi,
+         squeeze=False,
+     )
+     for idx in range(9):
+         r, c = divmod(idx, 3)
+         ax = axes[r][c]
+         if idx >= n_plot:
+             ax.set_visible(False)
+             continue
+         nz = np.load(out_dir / f"field{idx:02d}_surfaces.npz")
+         L0 = np.asarray(nz["L_t0"])
+         Om_grid = np.asarray(nz["Om_grid"])
+         s8_grid = np.asarray(nz["s8_grid"])
+         Om_true = float(nz["Om_true"])
+         s8_true = float(nz["s8_true"])
+         surface_sm, _ = _L0_posterior_smoothed(L0, smooth_sigma=0.6)
+         _, _, (Om_pred, s8_pred) = marginal_from_neg2dL(surface_sm, Om_grid, s8_grid)
+         draw_L0_posterior_main_panel(
+             ax, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
+             clabel_fontsize=7.0, marker_ms=11,
+         )
+         if idx == 0:
+             legend_patches = [
+                 Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"),
+                 Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"),
+                 Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"),
+             ]
+             hs, ls_ = ax.get_legend_handles_labels()
+             ax.legend(
+                 handles=legend_patches + hs,
+                 labels=[p.get_label() for p in legend_patches] + ls_,
+                 fontsize=6, loc="upper right", framealpha=0.9,
+             )
+         else:
+             leg = ax.get_legend()
+             if leg is not None:
+                 leg.remove()
+         ax.set_title(
+             rf"field {idx}: $\Omega_m^{{\rm true}}={Om_true:.3f}$, $\sigma_8^{{\rm true}}={s8_true:.3f}$",
+             fontsize=8,
+         )
+         ax.set_xlabel(r"$\Omega_m$", fontsize=8)
+         ax.set_ylabel(r"$\sigma_8$", fontsize=8)
+
+     fig.suptitle(
+         r"VLB $L_0$ posterior (2D) — 9 test fields",
+         fontsize=11, fontweight="bold", y=0.995,
+     )
+     fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
+     plt.close(fig)
+     print(f" Saved -> {out_path} (≈ {mosaic_side_px}×{mosaic_side_px} px)")
+
+
+ def fig_main_posterior(
+     L0_surface: np.ndarray,
+     Om_grid: np.ndarray,
+     s8_grid: np.ndarray,
+     Om_true: float,
+     s8_true: float,
+     out_path: Path,
+     dpi: int = 200,
+ ):
+     from matplotlib.patches import Patch
+
+     surface_sm, _ = _L0_posterior_smoothed(L0_surface, smooth_sigma=0.6)
+
+     Om_marg, s8_marg, (Om_pred, s8_pred) = marginal_from_neg2dL(
+         surface_sm, Om_grid, s8_grid
+     )
+     Om_med, Om_lo, Om_hi = credible_interval_68(Om_grid, Om_marg)
+     s8_med, s8_lo, s8_hi = credible_interval_68(s8_grid, s8_marg)
+
+     fig = plt.figure(figsize=(8.5, 8.5), dpi=dpi)
+     gs = gridspec.GridSpec(2, 2, width_ratios=[4, 1], height_ratios=[1, 4],
+                            hspace=0.05, wspace=0.05,
+                            left=0.10, right=0.95, top=0.95, bottom=0.08)
+     ax_main = fig.add_subplot(gs[1, 0])
+     ax_top = fig.add_subplot(gs[0, 0], sharex=ax_main)
+     ax_rt = fig.add_subplot(gs[1, 1], sharey=ax_main)
+
+     draw_L0_posterior_main_panel(
+         ax_main, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
+     )
+
+     ax_main.set_xlabel(r"$\Omega_m$", fontsize=11)
+     ax_main.set_ylabel(r"$\sigma_8$", fontsize=11)
+
+     ax_top.fill_between(Om_grid, 0, Om_marg, color=SIGMA_COLORS[1], alpha=0.6)
+     ax_top.plot(Om_grid, Om_marg, color=SIGMA_COLORS[0], lw=1.4)
+     ax_top.axvline(Om_true, color="red", lw=1.0, ls=":")
+     ax_top.axvline(Om_pred, color="white", lw=1.5, ls="--",
+                    path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
+     ax_top.axvspan(Om_lo, Om_hi, color=SIGMA_COLORS[0], alpha=0.18, label=r"68% CI")
+     ax_top.set_yticks([])
+     ax_top.tick_params(labelbottom=False)
+     ax_top.set_title(
+         rf"$\Omega_m={Om_med:.3f}^{{+{Om_hi-Om_med:.3f}}}_{{-{Om_med-Om_lo:.3f}}}$"
+         rf" (true: {Om_true:.3f})",
+         fontsize=9,
+     )
+
+     ax_rt.fill_betweenx(s8_grid, 0, s8_marg, color=SIGMA_COLORS[1], alpha=0.6)
+     ax_rt.plot(s8_marg, s8_grid, color=SIGMA_COLORS[0], lw=1.4)
+     ax_rt.axhline(s8_true, color="red", lw=1.0, ls=":")
+     ax_rt.axhline(s8_pred, color="white", lw=1.5, ls="--",
+                   path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
+     ax_rt.axhspan(s8_lo, s8_hi, color=SIGMA_COLORS[0], alpha=0.18)
+     ax_rt.set_xticks([])
+     ax_rt.tick_params(labelleft=False)
+     ax_rt.set_ylabel(
+         rf"$\sigma_8={s8_med:.3f}^{{+{s8_hi-s8_med:.3f}}}_{{-{s8_med-s8_lo:.3f}}}$"
+         rf" (true: {s8_true:.3f})",
+         fontsize=9, rotation=270, labelpad=15,
+     )
+     ax_rt.yaxis.set_label_position("right")
+
+     legend_patches = [
+         Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"),
+         Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"),
+         Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"),
+     ]
+     hs, ls_ = ax_main.get_legend_handles_labels()
+     ax_main.legend(
+         handles=legend_patches + hs,
+         labels=[p.get_label() for p in legend_patches] + ls_,
+         fontsize=8, loc="upper right", framealpha=0.92,
+     )
+
+     fig.suptitle(
+         r"VLB Posterior using $L_0$ — joint and marginal distributions",
+         fontsize=10, fontweight="bold", y=0.99,
+     )
+     fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
+     plt.close(fig)
+     print(f" Saved -> {out_path}")
+     return Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi)
+
+
+ def fig_pred_vs_true(pred_results: List[Dict], out_path: Path, dpi: int = 200) -> None:
+     Om_true = np.array([r["Om_true"] for r in pred_results])
+     s8_true = np.array([r["s8_true"] for r in pred_results])
+     Om_pred = np.array([r["Om_pred"] for r in pred_results])
+     s8_pred = np.array([r["s8_pred"] for r in pred_results])
+
+     Om_err_lo = np.array([r["Om_pred"] - r["Om_lo"] for r in pred_results])
+     Om_err_hi = np.array([r["Om_hi"] - r["Om_pred"] for r in pred_results])
+     s8_err_lo = np.array([r["s8_pred"] - r["s8_lo"] for r in pred_results])
+     s8_err_hi = np.array([r["s8_hi"] - r["s8_pred"] for r in pred_results])
+
+     rmse_Om = np.sqrt(((Om_pred - Om_true) ** 2).mean())
+     rmse_s8 = np.sqrt(((s8_pred - s8_true) ** 2).mean())
+
+     fig, axes = plt.subplots(1, 2, figsize=(11, 5), dpi=dpi)
+
+     for ax, (true, pred, err_lo, err_hi, name, prange, rmse) in zip(axes, [
+         (Om_true, Om_pred, Om_err_lo, Om_err_hi, r"$\Omega_m$", (0.10, 0.50), rmse_Om),
+         (s8_true, s8_pred, s8_err_lo, s8_err_hi, r"$\sigma_8$", (0.60, 1.00), rmse_s8),
+     ]):
+         ax.errorbar(
+             true, pred, yerr=[np.maximum(err_lo, 0), np.maximum(err_hi, 0)],
+             fmt="o", color=GEN_COLOR, ecolor=SIGMA_COLORS[1],
+             elinewidth=1.2, capsize=3, ms=6,
+             label="DDPM-VLB inference (68% CI)",
+         )
+         ax.plot(prange, prange, "k--", lw=1.0, alpha=0.5, label="Identity")
+         ax.set_xlabel(f"True {name}", fontsize=11)
+         ax.set_ylabel(f"Predicted {name}", fontsize=11)
+         ax.set_xlim(*prange)
+         ax.set_ylim(*prange)
+         ax.grid(alpha=0.2)
+         ax.legend(fontsize=9, loc="lower right")
+         ax.text(
+             0.04, 0.92, f"RMSE = {rmse:.4f}",
+             transform=ax.transAxes, fontsize=10,
+             bbox=dict(facecolor="white", edgecolor="#ccc", alpha=0.92, pad=4),
+         )
+         ax.set_title(f"{name}: predicted vs true", fontweight="bold", fontsize=10)
+
+     fig.suptitle(
+         "VLB Parameter Inference: predicted vs true\n"
+         r"Error bars = 68% CI from $L_0$ marginal posterior",
+         fontsize=10, fontweight="bold", y=1.01,
+     )
+     plt.tight_layout()
+     fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
+     plt.close(fig)
+     print(f" Saved -> {out_path}")
+     print(f" RMSE: Omega_m={rmse_Om:.4f} sigma_8={rmse_s8:.4f}")
+
+
+ def fig_posterior_and_contours_combined(
+     surfaces: Dict[int, np.ndarray],
+     L0_surface: np.ndarray,
+     Om_grid: np.ndarray,
+     s8_grid: np.ndarray,
+     Om_true: float,
+     s8_true: float,
+     out_path: Path,
+     dpi: int = 200,
+ ) -> Tuple[float, float, Tuple[float, float], Tuple[float, float]]:
+     """
+     Create a combined figure with contours_per_t on left and posterior on right.
+     """
+     from matplotlib.patches import Patch
+
+     surface_sm, _ = _L0_posterior_smoothed(L0_surface, smooth_sigma=0.6)
+     Om_marg, s8_marg, (Om_pred, s8_pred) = marginal_from_neg2dL(
+         surface_sm, Om_grid, s8_grid
+     )
+     Om_med, Om_lo, Om_hi = credible_interval_68(Om_grid, Om_marg)
+     s8_med, s8_lo, s8_hi = credible_interval_68(s8_grid, s8_marg)
+
+     fig = plt.figure(figsize=(16, 7), dpi=dpi)
+     gs = gridspec.GridSpec(2, 4, width_ratios=[4, 0.3, 4, 1], height_ratios=[1, 4],
+                            hspace=0.08, wspace=0.10,
+                            left=0.08, right=0.96, top=0.94, bottom=0.08)
+
+     # Left panel: Contours per timestep
+     ax_contours = fig.add_subplot(gs[:, 0])
+     cmap = plt.cm.viridis
+     n_t = len(surfaces)
+     colors = cmap(np.linspace(0.05, 0.95, n_t))
+
+     for (t, L_surf), col in zip(sorted(surfaces.items()), colors):
+         neg2dL = 2.0 * (L_surf - L_surf.min())
+         ax_contours.contour(
+             Om_grid, s8_grid, neg2dL.T,
+             levels=[2.30], colors=[col], linewidths=[1.6], linestyles=["-"],
+         )
+         ax_contours.plot([], [], color=col, lw=1.8, label=f"t={t}")
+
+     ax_contours.plot(Om_true, s8_true, "r+", ms=18, mew=2.5, label="True", zorder=10)
+     ax_contours.set_xlabel(r"$\Omega_m$", fontsize=12)
+     ax_contours.set_ylabel(r"$\sigma_8$", fontsize=12)
+     ax_contours.set_title(
+         r"$-2\Delta\ln\hat{L}_t$ — $1\sigma$ contours per timestep",
+         fontweight="bold", fontsize=11,
+     )
+     ax_contours.legend(fontsize=8, loc="best", ncol=1, framealpha=0.92)
+     ax_contours.grid(alpha=0.18)
+     ax_contours.set_xlim(Om_grid[0], Om_grid[-1])
+     ax_contours.set_ylim(s8_grid[0], s8_grid[-1])
+
+     # Right panel: Posterior L_0 (similar to fig_main_posterior layout)
+     ax_main = fig.add_subplot(gs[1, 2])
+     ax_top = fig.add_subplot(gs[0, 2], sharex=ax_main)
+     ax_rt = fig.add_subplot(gs[1, 3], sharey=ax_main)
+
+     draw_L0_posterior_main_panel(
+         ax_main, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
+     )
+
+     ax_main.set_xlabel(r"$\Omega_m$", fontsize=11)
+     ax_main.set_ylabel(r"$\sigma_8$", fontsize=11)
+
+     ax_top.fill_between(Om_grid, 0, Om_marg, color=SIGMA_COLORS[1], alpha=0.6)
+     ax_top.plot(Om_grid, Om_marg, color=SIGMA_COLORS[0], lw=1.4)
+     ax_top.axvline(Om_true, color="red", lw=1.0, ls=":")
+     ax_top.axvline(Om_pred, color="white", lw=1.5, ls="--",
+                    path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
+     ax_top.axvspan(Om_lo, Om_hi, color=SIGMA_COLORS[0], alpha=0.18, label=r"68% CI")
+     ax_top.set_yticks([])
+     ax_top.tick_params(labelbottom=False)
+     ax_top.set_title(
+         rf"$\Omega_m={Om_med:.3f}^{{+{Om_hi-Om_med:.3f}}}_{{-{Om_med-Om_lo:.3f}}}$"
+         rf" (true: {Om_true:.3f})",
+         fontsize=9,
+     )
+
+     ax_rt.fill_betweenx(s8_grid, 0, s8_marg, color=SIGMA_COLORS[1], alpha=0.6)
+     ax_rt.plot(s8_marg, s8_grid, color=SIGMA_COLORS[0], lw=1.4)
+     ax_rt.axhline(s8_true, color="red", lw=1.0, ls=":")
+     ax_rt.axhline(s8_pred, color="white", lw=1.5, ls="--",
+                   path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
+     ax_rt.axhspan(s8_lo, s8_hi, color=SIGMA_COLORS[0], alpha=0.18)
+     ax_rt.set_xticks([])
+     ax_rt.tick_params(labelleft=False)
+     ax_rt.set_ylabel(
+         rf"$\sigma_8={s8_med:.3f}^{{+{s8_hi-s8_med:.3f}}}_{{-{s8_med-s8_lo:.3f}}}$"
+         rf" (true: {s8_true:.3f})",
+         fontsize=9, rotation=270, labelpad=15,
+     )
+     ax_rt.yaxis.set_label_position("right")
+
+     legend_patches = [
+         Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"),
+         Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"),
+         Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"),
+     ]
+     hs, ls_ = ax_main.get_legend_handles_labels()
+     ax_main.legend(
+         handles=legend_patches + hs,
+         labels=[p.get_label() for p in legend_patches] + ls_,
+         fontsize=8, loc="upper right", framealpha=0.92,
+     )
+
+     fig.suptitle(
+         r"VLB Inference: $L_t$ contours (left) and $L_0$ posterior (right)",
+         fontsize=12, fontweight="bold", y=0.99,
+     )
+     fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
+     plt.close(fig)
+     print(f" Saved -> {out_path}")
+     return Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi)
+
+
+ def parse_args() -> argparse.Namespace:
+     p = argparse.ArgumentParser(
+         description="VLB-based parameter inference for trained conditional DDPM.",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+     p.add_argument("--checkpoint", required=True)
+     p.add_argument("--training_args", default=None)
+     p.add_argument("--data_dir", default="./data/params_2")
+     p.add_argument("--output_dir", default="vlb_inference_outputs")
+     p.add_argument("--n_fields", type=int, default=9)
+     # Default kept below the 300 guard in main() so a plain invocation runs.
+     p.add_argument(
+         "--grid_size", type=int, default=50,
+         help="Ωm×σ8 evaluation grid resolution (each side). Values > 300 require "
+              "--allow_huge_grid (very long runs for large grid_size).",
+     )
+     p.add_argument(
+         "--allow_huge_grid", action="store_true",
+         help="Required when --grid_size > 300 (avoids accidental multi-week GPU jobs).",
+     )
+     p.add_argument(
+         "--mosaic_side_px", type=int, default=10_000,
+         help="Pixel width/height of posterior_L0_mosaic_3x3.png (square).",
+     )
+     p.add_argument(
+         "--mosaic_panel_inches", type=float, default=4.0,
+         help="Matplotlib size (inches) of each 3×3 panel; dpi = mosaic_side_px / (3× this).",
+     )
+     p.add_argument("--span", type=float, default=0.10)
+     p.add_argument("--t_subset", type=int, nargs="+",
+                    default=[0, 1, 2, 5, 8, 10, 15, 20])
+     p.add_argument("--n_seeds", type=int, default=4)
+     p.add_argument("--batch_size", type=int, default=32)
+     p.add_argument("--device", default="auto")
+     p.add_argument("--seed", type=int, default=42)
+     p.add_argument("--dpi", type=int, default=200)
+     return p.parse_args()
+
+
+ def autodetect_args() -> Optional[str]:
+     for pat in ["outputs_conditional_*/args.json", "outputs_conditional_*/args.txt"]:
+         cands = sorted(Path(".").glob(pat), key=os.path.getctime, reverse=True)
+         if cands:
+             return str(cands[0])
+     return None
+
+
+ def main() -> None:
+     args = parse_args()
+     if args.grid_size > 300 and not args.allow_huge_grid:
+         print(
+             "\nRefusing --grid_size={} (> 300) without --allow_huge_grid.\n"
+             "A 10_000×10_000 grid is roughly (200)^2 ≈ 40_000× more forward passes per\n"
+             "field than 50×50. For a quick run use e.g. --grid_size 50; for a high-res\n"
+             "summary figure use the default --mosaic_side_px without increasing grid_size.\n"
+             "To proceed anyway: add --allow_huge_grid\n".format(args.grid_size)
+         )
+         raise SystemExit(2)
+     torch.manual_seed(args.seed)
+     np.random.seed(args.seed)
+
+     device = (
+         torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         if args.device == "auto"
+         else torch.device(args.device)
+     )
+     print(f"\nDevice: {device}")
+
+     out = Path(args.output_dir)
+     out.mkdir(parents=True, exist_ok=True)
+
+     if args.training_args is None:
+         args.training_args = autodetect_args()
+         if args.training_args is None:
+             raise FileNotFoundError("Cannot find args.json — pass --training_args")
+         print(f" Auto-detected args: {args.training_args}")
+
+     cfg = load_config(args.training_args)
+     print("\nLoading model ...")
+     model = load_model(args.checkpoint, cfg, device)
+     n_p = sum(p.numel() for p in model.parameters())
+     print(f" Parameters: {n_p:,} T={model.diffusion.timesteps}")
+
+     print(f"\nLoading {args.n_fields} test fields ...")
+     test_imgs, test_labels, label_mu, label_std = load_test_data(
+         args.data_dir, args.n_fields, seed=args.seed,
+     )
+     print(f" Image shape: {test_imgs.shape[1:]}")
+     print(f" Label dim: {test_labels.shape[1]}")
+     print(f" Label μ/σ: {label_mu} / {label_std}")
+
+     print(
+         f"\nEvaluating L_t on {args.grid_size}x{args.grid_size} grid for "
+         f"{len(args.t_subset)} timesteps × {args.n_seeds} seeds ..."
+     )
+     print(
+         f" -> {args.grid_size ** 2 * len(args.t_subset) * args.n_seeds:,} "
+         f"model evaluations per field (grid points × timesteps × seeds)"
+     )
+
+     pred_results = []
+     label_dim = int(cfg.get("label_dim", 2))
+
+     for fi in range(args.n_fields):
+         Om_true = float(test_labels[fi, 0])
+         s8_true = float(test_labels[fi, 1])
+         print(f"\n [{fi+1}/{args.n_fields}] field with "
+               f"Om={Om_true:.3f}, s8={s8_true:.3f}")
+
+         # Rescale the field from [0, 1] to [-1, 1] and add a channel axis
+         x_0 = torch.from_numpy(test_imgs[fi:fi + 1] * 2.0 - 1.0).unsqueeze(1).to(device)
+
+         Om_grid, s8_grid = build_eval_grid(Om_true, s8_true, args.grid_size, args.span)
+
+         t_start = time.time()
+         surfaces = evaluate_vlb_surface(
+             model=model,
+             x_0=x_0,
+             Om_grid=Om_grid,
+             s8_grid=s8_grid,
+             label_mu=label_mu,
+             label_std=label_std,
+             t_values=args.t_subset,
+             n_seeds=args.n_seeds,
+             batch_size=args.batch_size,
+             label_dim=label_dim,
+             fixed_seed=args.seed + fi,
+             device=device,
+         )
+         elapsed = time.time() - t_start
+         print(f" Evaluation time: {elapsed:.1f}s")
+
+         np.savez(
+             out / f"field{fi:02d}_surfaces.npz",
+             **{f"L_t{t}": s for t, s in surfaces.items()},
+             Om_grid=Om_grid, s8_grid=s8_grid,
+             Om_true=Om_true, s8_true=s8_true,
+         )
+
+         if 0 in surfaces:
+             # Combined figure: contours_per_t + posterior on same plot
+             Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi) = fig_posterior_and_contours_combined(
+                 surfaces, surfaces[0], Om_grid, s8_grid, Om_true, s8_true,
+                 out / f"field{fi:02d}_combined.png", dpi=args.dpi,
+             )
+             pred_results.append(dict(
+                 Om_true=Om_true, s8_true=s8_true,
+                 Om_pred=Om_pred, s8_pred=s8_pred,
+                 Om_lo=Om_lo, Om_hi=Om_hi,
+                 s8_lo=s8_lo, s8_hi=s8_hi,
+             ))
+
+         # Also save individual figures for detailed inspection
+         fig_contours_per_t(
+             surfaces, Om_grid, s8_grid, Om_true, s8_true,
+             out / f"field{fi:02d}_contours_per_t.png", dpi=args.dpi,
+         )
+
+         if 0 in surfaces:
+             fig_main_posterior(
+                 surfaces[0], Om_grid, s8_grid, Om_true, s8_true,
+                 out / f"field{fi:02d}_posterior_L0.png", dpi=args.dpi,
+             )
+
+     if len(pred_results) >= 2:
+         fig_pred_vs_true(pred_results, out / "summary_pred_vs_true.png", dpi=args.dpi)
+         np.savez(
+             out / "summary.npz",
+             **{k: np.array([r[k] for r in pred_results])
+                for k in pred_results[0].keys()},
+         )
+
+     if args.n_fields >= 9 and all((out / f"field{i:02d}_surfaces.npz").is_file() for i in range(9)):
+         fig_posterior_L0_mosaic_3x3(
+             out, args.n_fields, out / "posterior_L0_mosaic_3x3.png",
+             mosaic_side_px=args.mosaic_side_px,
+             panel_inches=args.mosaic_panel_inches,
+         )
+
+     print(f"\nAll outputs -> {out.resolve()}/")
+     for f in sorted(out.glob("*.png")):
+         print(f" {f.name}")
+
+
+ if __name__ == "__main__":
+     main()
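The `--allow_huge_grid` guard in `main()` exists because cost grows with the square of `--grid_size`. The following back-of-the-envelope sketch makes that scaling concrete; both helper functions are hypothetical, and only the formula (grid points × timesteps × seeds) comes from the script's own progress message.

```python
def n_model_evaluations(grid_size: int, n_timesteps: int, n_seeds: int) -> int:
    """Conditioned noise predictions per field: one per (grid point, timestep, seed)."""
    return grid_size ** 2 * n_timesteps * n_seeds

def n_network_calls(grid_size: int, n_timesteps: int, n_seeds: int, batch_size: int) -> int:
    """Batched forward passes per field when grid points are processed in chunks."""
    batches_per_sweep = -(-grid_size ** 2 // batch_size)  # ceiling division
    return batches_per_sweep * n_timesteps * n_seeds

# Script defaults: 8 timesteps in --t_subset, 4 seeds, batch size 32
quick = n_model_evaluations(50, 8, 4)
huge = n_model_evaluations(10_000, 8, 4)
ratio = huge // quick
calls_quick = n_network_calls(50, 8, 4, 32)
print(f"{quick:,} vs {huge:,} evaluations per field ({ratio:,}x)")
```

A 50×50 grid needs 80,000 evaluations per field with the default timesteps and seeds, while a 10,000×10,000 grid needs 3.2 billion, the 40,000× factor quoted in the refusal message.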
src/train_conditional.py ADDED
@@ -0,0 +1,447 @@
+ """
+ Training script for conditional diffusion on CAMELS LH (6 cosmological parameters).
+
+ Same training approach as DDPM_HI_Emulation_improved (2-label): DDPM noise prediction,
+ DDIM sampling, ConditionalUNet with time + label embeddings, label z-scoring from the
+ train split, EMA, optional AMP, cosine LR, early stopping.
+
+ Changes from original:
+ - EMA weights are now applied before validation and sampling
+ - Training args are saved to args.txt and args.json for the evaluation script
+ - Fixed --normalize_labels and --use_ddim flags (previously they could not be disabled)
+ - Added mixed-precision (AMP) training support
+ - Fixed loss averaging to be per-sample rather than per-batch
+ - torch.load now sets weights_only explicitly (the --resume path passes
+   weights_only=False, so resume only from trusted checkpoints)
+ """
+
+ import argparse
+ import json
+ import os
+ import random
+ import time
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import torch.optim as optim
+ from tqdm import tqdm
+
+ from dataset_conditional import DEFAULT_DATA_DIR, get_conditional_dataloaders
+ from diffusion_conditional import ConditionalDiffusionModel, GaussianDiffusion
+ from unet_conditional import ConditionalUNet
+
+ # Weights & Biases (optional)
+ try:
+     import wandb
+
+     WANDB_AVAILABLE = True
+ except ImportError:
+     WANDB_AVAILABLE = False
+     print("Warning: wandb not available. Install with: pip install wandb")
+
+
+ class EMA:
+     """Exponential Moving Average for model parameters."""
+
+     def __init__(self, model, decay=0.9999):
+         self.model = model
+         self.decay = decay
+         self.shadow = {}
+         for name, param in model.named_parameters():
+             if param.requires_grad:
+                 self.shadow[name] = param.data.clone()
+
+     def update(self):
+         for name, param in self.model.named_parameters():
+             if param.requires_grad:
+                 self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * param.data
+
+     def apply_shadow(self):
+         self.backup = {
+             name: param.data.clone() for name, param in self.model.named_parameters() if param.requires_grad
+         }
+         for name, param in self.model.named_parameters():
+             if param.requires_grad:
+                 param.data = self.shadow[name]
+
+     def restore(self):
+         for name, param in self.model.named_parameters():
+             if param.requires_grad:
+                 param.data = self.backup[name]
+         self.backup = {}
+
+
+ def train_epoch(model, dataloader, optimizer, device, epoch, ema=None, use_wandb=False, scaler=None):
+     model.train()
+     total_loss = 0.0
+     total_samples = 0
+     pbar = tqdm(dataloader, desc=f"Epoch {epoch}")
+
+     for batch_idx, (images, labels) in enumerate(pbar):
+         images = images.to(device)
+         labels = labels.to(device)
+         batch_size = images.shape[0]
+
+         optimizer.zero_grad()
+
+         if scaler is not None:
+             with torch.amp.autocast("cuda"):
+                 loss = model.get_loss(images, labels)
+             scaler.scale(loss).backward()
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             scaler.step(optimizer)
+             scaler.update()
+         else:
+             loss = model.get_loss(images, labels)
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             optimizer.step()
+
+         if ema is not None:
+             ema.update()
+
+         total_loss += loss.item() * batch_size
+         total_samples += batch_size
+         pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+         if use_wandb and batch_idx % 10 == 0:
+             wandb.log({"batch_loss": loss.item(), "epoch": epoch, "batch": epoch * len(dataloader) + batch_idx})
+
+     return total_loss / total_samples
+
+
+ def validate(model, dataloader, device):
+     model.eval()
+     total_loss = 0.0
+     total_samples = 0
+     with torch.no_grad():
+         for images, labels in tqdm(dataloader, desc="Validating"):
+             images = images.to(device)
+             labels = labels.to(device)
+             batch_size = images.shape[0]
+             loss = model.get_loss(images, labels)
+             total_loss += loss.item() * batch_size
+             total_samples += batch_size
+     return total_loss / total_samples
+
+
+ def save_checkpoint(model, optimizer, ema, epoch, loss, save_dir, is_best=False, last_improvement_epoch=None, scheduler=None):
+     checkpoint = {
+         "epoch": epoch,
+         "model_state_dict": model.state_dict(),
+         "optimizer_state_dict": optimizer.state_dict(),
+         "loss": loss,
+     }
+     if ema is not None:
+         checkpoint["ema_shadow"] = ema.shadow
+     if last_improvement_epoch is not None:
+         checkpoint["last_improvement_epoch"] = last_improvement_epoch
+     if scheduler is not None:
+         checkpoint["scheduler_state_dict"] = scheduler.state_dict()
+
+     torch.save(checkpoint, os.path.join(save_dir, "checkpoint_latest.pt"))
+     if is_best:
+         torch.save(checkpoint, os.path.join(save_dir, "best_model.pt"))
+         print(f"Saved best model at epoch {epoch+1}")
+
+     if (epoch + 1) % 20 == 0:
+         torch.save(checkpoint, os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pt"))
+
+     print(f"Saved checkpoint at epoch {epoch+1}")
+
+
+ def sample_images(model, diffusion, device, save_path, test_labels, ema=None, n_samples=8, epoch=0, use_ddim=True, ddim_steps=50, use_wandb=False):
+     if ema is not None:
+         ema.apply_shadow()
+
+     model.eval()
+     labels = test_labels[:n_samples].to(device)
+
+     with torch.no_grad():
+         samples = diffusion.sample(
+             model,
+             labels=labels,
+             channels=1,
+             height=256,
+             width=256,
+             device=device,
+             progress=True,
+             use_ddim=use_ddim,
+             ddim_steps=ddim_steps,
+             eta=0.0,
+         )
+
+     if ema is not None:
+         ema.restore()
+
+     n_cols = min(n_samples, 4)
+     n_rows = (n_samples + n_cols - 1) // n_cols
+     fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5 * n_cols, 4.5 * n_rows))
+     if n_rows == 1 and n_cols == 1:
+         axes = np.array([[axes]])
+     elif n_rows == 1:
+         axes = axes[np.newaxis, :]
+     elif n_cols == 1:
+         axes = axes[:, np.newaxis]
+     for i in range(n_rows * n_cols):
+         ax = axes[i // n_cols, i % n_cols]
+         if i < n_samples:
+             img = samples[i, 0].cpu().numpy()
+             label_vals = labels[i].cpu().tolist()
+             label_str = ", ".join(f"{v:.2f}" for v in label_vals)
+             ax.imshow(img, vmin=-1, vmax=1)
+             ax.set_title(label_str, fontsize=10)
+         ax.axis("off")
+
+     plt.suptitle(f"Generated Samples - Epoch {epoch}", fontsize=14)
+     plt.tight_layout()
+     plt.savefig(save_path, dpi=150, bbox_inches="tight")
+
+     if use_wandb:
+         wandb.log({"generated_samples": wandb.Image(save_path), "epoch": epoch})
+     plt.close()
+     print(f"Saved samples to {save_path}")
+
+
+ def save_training_args(args, output_dir):
+     """Save training arguments so the evaluation script can reconstruct the model."""
+     args_path = os.path.join(output_dir, "args.txt")
+     with open(args_path, "w", encoding="utf-8") as f:
+         for key, value in vars(args).items():
+             f.write(f"{key}: {value}\n")
+     args_json_path = os.path.join(output_dir, "args.json")
+     with open(args_json_path, "w", encoding="utf-8") as f:
+         json.dump(vars(args), f, indent=2)
+     print(f"Saved training args to {args_path} and {args_json_path}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Train conditional diffusion (LH 6-parameter)")
+     # Model
+     parser.add_argument("--label_dim", type=int, default=6)
+     parser.add_argument("--base_channels", type=int, default=64)
+     parser.add_argument("--channel_multipliers", type=int, nargs="+", default=[1, 2, 4, 8])
+     parser.add_argument("--attention_levels", type=int, nargs="+", default=[2, 3])
+     parser.add_argument("--dropout", type=float, default=0.1)
+     # Diffusion
+     parser.add_argument("--timesteps", type=int, default=1500)
+     parser.add_argument("--beta_start", type=float, default=1e-4)
+     parser.add_argument("--beta_end", type=float, default=0.02)
+     parser.add_argument("--schedule_type", type=str, default="linear")
+     # Training
+     parser.add_argument("--epochs", type=int, default=100)
+     parser.add_argument("--batch_size", type=int, default=8)
+     parser.add_argument("--lr", type=float, default=2e-4)
+     parser.add_argument("--ema_decay", type=float, default=0.9999)
+     parser.add_argument("--num_workers", type=int, default=4)
+     parser.add_argument("--early_stop_patience", type=int, default=30)
+     parser.add_argument(
+         "--use_amp",
+         action="store_true",
+         default=False,
+         help="Enable mixed-precision training (recommended for GPU)",
+     )
+     # Data
+     parser.add_argument(
+         "--data_dir",
+         type=str,
+         default=DEFAULT_DATA_DIR,
+         help="Directory with *_LH_6.npy and *_labels_LH.npy (same rule as improved repo: e.g. .../LH_data/params_6)",
+     )
+     parser.add_argument("--normalize_labels", action=argparse.BooleanOptionalAction, default=True)
+     # Output
+     parser.add_argument("--output_dir", type=str, default="outputs_conditional_6param")
+     parser.add_argument("--resume", type=str, default="")
+     parser.add_argument(
+         "--resume_refresh_scheduler",
+         action="store_true",
+         help="On resume, rebuild cosine LR scheduler for --epochs (last_epoch=start-1) instead of loading saved scheduler; use when extending training beyond the original epoch count",
+     )
+     parser.add_argument("--sample_every", type=int, default=10)
+     parser.add_argument("--use_ddim", action=argparse.BooleanOptionalAction, default=True)
+     parser.add_argument("--ddim_steps", type=int, default=50)
+     # WandB
+     parser.add_argument("--use_wandb", action="store_true", default=False)
+     parser.add_argument("--wandb_project", type=str, default="ddpm_cosmology")
+     parser.add_argument("--wandb_entity", type=str, default="")
+     parser.add_argument("--wandb_run_name", type=str, default="")
+
+     args = parser.parse_args()
+
+     seed = 42
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     use_wandb = args.use_wandb and WANDB_AVAILABLE
+     if use_wandb:
+         run_name = args.wandb_run_name or f"conditional_diffusion_{time.strftime('%Y%m%d_%H%M%S')}"
+         wandb.init(project=args.wandb_project, entity=args.wandb_entity or None, name=run_name, config=vars(args))
+         print(f"W&B run: {run_name}")
+
+     timestamp = time.strftime("%Y%m%d_%H%M%S")
+     output_dir = f"{args.output_dir}_{timestamp}"
+     os.makedirs(output_dir, exist_ok=True)
+     os.makedirs(os.path.join(output_dir, "checkpoints"), exist_ok=True)
+     os.makedirs(os.path.join(output_dir, "samples"), exist_ok=True)
+
+     save_training_args(args, output_dir)
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     scaler = torch.amp.GradScaler("cuda") if args.use_amp and torch.cuda.is_available() else None
+     if scaler:
+         print("Mixed-precision training enabled (AMP)")
+
+     print("\nLoading data...")
+     train_loader, val_loader, test_loader = get_conditional_dataloaders(
+         data_dir=args.data_dir,
+         batch_size=args.batch_size,
+         num_workers=args.num_workers,
+         normalize_labels=args.normalize_labels,
+         label_dim=args.label_dim,
+     )
+     _, test_labels = next(iter(test_loader))
+
+     print("\nCreating model...")
+     unet = ConditionalUNet(
+         in_channels=1,
+         out_channels=1,
+         label_dim=args.label_dim,
+         base_channels=args.base_channels,
+         channel_multipliers=args.channel_multipliers,
+         attention_levels=args.attention_levels,
+         dropout=args.dropout,
+     )
+     diffusion = GaussianDiffusion(
+         timesteps=args.timesteps,
+         beta_start=args.beta_start,
+         beta_end=args.beta_end,
+         schedule_type=args.schedule_type,
+     )
+
+     model = ConditionalDiffusionModel(unet, diffusion).to(device)
+     print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
+
+     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
+     ema = EMA(model, decay=args.ema_decay)
+     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+
+     start_epoch = 0
+     best_val_loss = float("inf")
+     last_improvement_epoch = -1
+     if args.resume:
+         print(f"Resuming from {args.resume}")
+         checkpoint = torch.load(args.resume, map_location=device, weights_only=False)
+         model.load_state_dict(checkpoint["model_state_dict"])
+         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+         if "ema_shadow" in checkpoint:
+             ema.shadow = checkpoint["ema_shadow"]
+         start_epoch = checkpoint["epoch"] + 1
+         best_val_loss = checkpoint.get("loss", float("inf"))
+         last_improvement_epoch = checkpoint.get("last_improvement_epoch", -1)
+         if args.resume_refresh_scheduler:
+             scheduler = optim.lr_scheduler.CosineAnnealingLR(
+                 optimizer, T_max=args.epochs, last_epoch=start_epoch - 1
+             )
+             print(
+                 f"Rebuilt LR scheduler for extended run: T_max={args.epochs}, "
+                 f"resume at epoch {start_epoch + 1} (last_epoch={start_epoch - 1})"
+             )
+         elif "scheduler_state_dict" in checkpoint:
+             scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+
+     print("\nStarting training...")
+     losses = {"train": [], "val": []}
+
+     for epoch in range(start_epoch, args.epochs):
+         train_loss = train_epoch(model, train_loader, optimizer, device, epoch, ema, use_wandb, scaler=scaler)
+
+         if ema is not None:
+             ema.apply_shadow()
+         val_loss = validate(model, val_loader, device)
+         if ema is not None:
+             ema.restore()
+
+         losses["train"].append(train_loss)
+         losses["val"].append(val_loss)
+         scheduler.step()
+
+         if use_wandb:
+             wandb.log(
+                 {
+                     "epoch": epoch + 1,
+                     "train_loss": train_loss,
+                     "val_loss": val_loss,
+                     "learning_rate": optimizer.param_groups[0]["lr"],
+                 }
+             )
+
+         print(
+             f"\nEpoch {epoch+1}/{args.epochs} | Train: {train_loss:.6f} | Val: {val_loss:.6f} | "
+             f"LR: {optimizer.param_groups[0]['lr']:.6e}"
+         )
+
+         is_best = val_loss < best_val_loss
+         if is_best:
+             best_val_loss = val_loss
+             last_improvement_epoch = epoch
+
+         save_checkpoint(
+             model,
+             optimizer,
+             ema,
+             epoch,
+             val_loss,
+             os.path.join(output_dir, "checkpoints"),
+             is_best=is_best,
+             last_improvement_epoch=last_improvement_epoch,
+             scheduler=scheduler,
+         )
+
+         if epoch - last_improvement_epoch >= args.early_stop_patience:
+             print(f"Early stopping at epoch {epoch+1}")
+             break
+
+         if (epoch + 1) % args.sample_every == 0:
+             sample_path = os.path.join(output_dir, "samples", f"samples_epoch_{epoch+1}.png")
+             sample_images(
+                 model,
+                 diffusion,
+                 device,
+                 sample_path,
+                 test_labels,
+                 ema=ema,
+                 epoch=epoch + 1,
+                 use_ddim=args.use_ddim,
+                 ddim_steps=args.ddim_steps,
+                 use_wandb=use_wandb,
+             )
+
+         if (epoch + 1) % 5 == 0:
+             plt.figure(figsize=(10, 5))
+             plt.plot(losses["train"], label="Train Loss")
+             plt.plot(losses["val"], label="Val Loss")
+             plt.yscale("log")
+             plt.xlabel("Epoch")
+             plt.ylabel("Loss")
+             plt.title("Training Progress")
+             plt.legend()
+             plt.grid(True, alpha=0.3)
+             plt.savefig(os.path.join(output_dir, "losses.png"), dpi=150)
+             plt.close()
+
+     print(f"\nTraining completed! Best val loss: {best_val_loss:.6f}")
+     print(f"Results saved to: {output_dir}")
+     if use_wandb:
+         wandb.finish()
+
+
+ if __name__ == "__main__":
+     main()
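
Note on the "per-sample rather than per-batch" loss averaging mentioned in the module docstring of src/train_conditional.py: train_epoch and validate accumulate loss.item() * batch_size and divide by the total sample count. The sketch below illustrates why that matters when the last batch is smaller than the rest. It assumes model.get_loss returns the mean loss over the batch (not confirmed by this diff); the helper names and the numbers are hypothetical and are not part of the commit.

```python
def per_batch_mean(batch_losses):
    """Unweighted mean of per-batch mean losses (skewed when batch sizes differ)."""
    return sum(loss for loss, _ in batch_losses) / len(batch_losses)


def per_sample_mean(batch_losses):
    """Mean weighted by batch size, i.e. total_loss / total_samples."""
    total_loss = sum(loss * n for loss, n in batch_losses)
    total_samples = sum(n for _, n in batch_losses)
    return total_loss / total_samples


# Hypothetical (mean_loss, batch_size) pairs: two full batches of 8
# and a final partial batch of 2 with an unusually high loss.
batches = [(0.10, 8), (0.20, 8), (0.90, 2)]

print(f"per-batch mean:  {per_batch_mean(batches):.4f}")
print(f"per-sample mean: {per_sample_mean(batches):.4f}")
```

With equal batch sizes the two averages coincide; they diverge only when a partial batch is present, in which case the per-batch mean gives the two samples in the short batch the same weight as eight samples in a full one.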