Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
Browse files- README.md +45 -4
- cross_model/README.md +17 -0
- cross_model/check_poster_env.py +78 -0
- cross_model/compare_posterior_inference.py +699 -0
- cross_model/ddpm_posterior_corrected.py +867 -0
- cross_model/poster.py +1112 -0
- cross_model/run_compare_posterior.sh +52 -0
- cross_model/run_vlb_inference_1000grid.sh +81 -0
- cross_model/run_vlb_inference_200grid.sh +78 -0
- cross_model/scripts/compare_ddpm_models.py +855 -0
- cross_model/scripts/compare_ddpm_training_curves.py +45 -0
- cross_model/scripts/ddpm_figure6_integration.py +271 -0
- cross_model/scripts/ddpm_posterior_six_anchors.py +451 -0
- cross_model/scripts/ddpm_triangle_integration.py +194 -0
- cross_model/scripts/figure6_2409_style.py +157 -0
- cross_model/scripts/run_ddpm_comparison.sh +66 -0
- cross_model/scripts/run_ddpm_figure6.sh +27 -0
- cross_model/scripts/run_ddpm_figure6_suite.py +315 -0
- cross_model/scripts/run_ddpm_posterior_corrected.sh +58 -0
- cross_model/scripts/run_ddpm_posterior_six_anchors.sh +52 -0
- cross_model/scripts/run_poster.sh +53 -0
- cross_model/scripts/run_posterior_inference.sh +74 -0
- cross_model/scripts/run_triangle_ddpm_both.sh +75 -0
- cross_model/scripts/sigma_contour_utils.py +29 -0
- cross_model/scripts/triangle_plot_posterior.py +128 -0
- cross_model/submit_vlb_1000grid.py +106 -0
- scripts/shell/evaluate_conditional_lh6.sh +61 -0
- scripts/shell/plot_r2_cosmology_lhs.sh +72 -0
- scripts/shell/train_conditional_lh6.sh +60 -0
- src/eval_model.py +86 -0
- src/figure9_posterior.py +33 -0
- src/plot_r2_cosmology_lhs.py +316 -0
- src/posterior_inference.py +895 -0
- src/train_conditional.py +447 -0
README.md
CHANGED
|
@@ -26,17 +26,58 @@ This is the **best-validation checkpoint** from the training run under
|
|
| 26 |
|
| 27 |
## Files in this repo
|
| 28 |
|
|
|
|
|
|
|
| 29 |
| File | Purpose |
|
| 30 |
|------|---------|
|
| 31 |
| `model.pt` | PyTorch checkpoint (state-dict for `ConditionalDiffusionModel`) |
|
| 32 |
| `args.json` / `args.txt` | Training hyper-parameters and U-Net configuration |
|
| 33 |
| `config.json` | Architecture summary (for Hub discoverability) |
|
| 34 |
-
| `src/unet_conditional.py` | `ConditionalUNet` module |
|
| 35 |
-
| `src/diffusion_conditional.py` | `GaussianDiffusion` (DDPM + DDIM) and the wrapping `ConditionalDiffusionModel` |
|
| 36 |
-
| `src/dataset_conditional.py` | Helper for loading CAMELS LH data + label normalisation stats |
|
| 37 |
-
| `src/evaluate_conditional.py` | Reference evaluation pipeline (samples + metrics) |
|
| 38 |
| `inference_example.py` | Runnable example: downloads weights and generates a sample |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
## Architecture
|
| 41 |
|
| 42 |
Conditional U-Net + Gaussian diffusion process. Hyper-parameters (taken from
|
|
|
|
| 26 |
|
| 27 |
## Files in this repo
|
| 28 |
|
| 29 |
+
**Top level**
|
| 30 |
+
|
| 31 |
| File | Purpose |
|
| 32 |
|------|---------|
|
| 33 |
| `model.pt` | PyTorch checkpoint (state-dict for `ConditionalDiffusionModel`) |
|
| 34 |
| `args.json` / `args.txt` | Training hyper-parameters and U-Net configuration |
|
| 35 |
| `config.json` | Architecture summary (for Hub discoverability) |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
| `inference_example.py` | Runnable example: downloads weights and generates a sample |
|
| 37 |
|
| 38 |
+
**`src/` — per-model Python**
|
| 39 |
+
|
| 40 |
+
| File | Purpose |
|
| 41 |
+
|------|---------|
|
| 42 |
+
| `train_conditional.py` | Training entry point (`label_dim=6`, mixed-precision) |
|
| 43 |
+
| `evaluate_conditional.py` | Held-out evaluation: samples + metrics |
|
| 44 |
+
| `eval_model.py` | Lightweight evaluation helper used by the figure scripts |
|
| 45 |
+
| `posterior_inference.py` | Full posterior-inference pipeline (likelihood / sampling) |
|
| 46 |
+
| `figure9_posterior.py` | Paper figure 9 (posterior triangle for the 6-param model) |
|
| 47 |
+
| `plot_r2_cosmology_lhs.py` | Latin-hypercube R² map (μ, σ vs cosmology) |
|
| 48 |
+
| `unet_conditional.py` | `ConditionalUNet` module |
|
| 49 |
+
| `diffusion_conditional.py` | `GaussianDiffusion` (DDPM + DDIM) and the wrapping `ConditionalDiffusionModel` |
|
| 50 |
+
| `dataset_conditional.py` | CAMELS LH dataset loader + label normalisation |
|
| 51 |
+
|
| 52 |
+
**`scripts/shell/` — SLURM launchers**
|
| 53 |
+
|
| 54 |
+
| File | Purpose |
|
| 55 |
+
|------|---------|
|
| 56 |
+
| `train_conditional_lh6.sh` | Submit a training job (`label_dim=6`) |
|
| 57 |
+
| `evaluate_conditional_lh6.sh` | Submit evaluation against the held-out test split |
|
| 58 |
+
| `plot_r2_cosmology_lhs.sh` | Generate the R² cosmology figure |
|
| 59 |
+
|
| 60 |
+
**`cross_model/` — posterior + comparison scripts that use BOTH models**
|
| 61 |
+
|
| 62 |
+
| File | Purpose |
|
| 63 |
+
|------|---------|
|
| 64 |
+
| `compare_posterior_inference.py` (+ `run_compare_posterior.sh`) | End-to-end posterior comparison between 2-param and 6-param emulators |
|
| 65 |
+
| `ddpm_posterior_corrected.py` (+ `scripts/run_ddpm_posterior_corrected.sh`) | Corrected DDPM posterior inference |
|
| 66 |
+
| `poster.py` / `check_poster_env.py` (+ `scripts/run_poster.sh`) | Posterior orchestration and environment check |
|
| 67 |
+
| `submit_vlb_1000grid.py` / `run_vlb_inference_*.sh` | Variational-lower-bound grid inference (200 / 1000 grid) |
|
| 68 |
+
| `scripts/compare_ddpm_models.py` (+ `run_ddpm_comparison.sh`) | DDPM-2 vs DDPM-6 comparison figures |
|
| 69 |
+
| `scripts/ddpm_posterior_six_anchors.py` (+ `run_ddpm_posterior_six_anchors.sh`) | Six-anchor posterior visualisation |
|
| 70 |
+
| `scripts/ddpm_figure6_integration.py`, `figure6_2409_style.py`, `run_ddpm_figure6_suite.py` (+ `run_ddpm_figure6.sh`) | Figure 6 generation pipeline |
|
| 71 |
+
| `scripts/ddpm_triangle_integration.py`, `triangle_plot_posterior.py` (+ `run_triangle_ddpm_both.sh`) | Triangle-plot posterior figures |
|
| 72 |
+
| `scripts/sigma_contour_utils.py` | Confidence-contour helper used by the figure scripts |
|
| 73 |
+
| `scripts/compare_ddpm_training_curves.py` | Parses SLURM logs for combined train/val loss plots |
|
| 74 |
+
| `README.md` | How to point these scripts at locally-downloaded weights/data |
|
| 75 |
+
|
| 76 |
+
These cross-model scripts default to the original cluster paths (e.g.
|
| 77 |
+
`/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6`). After downloading
|
| 78 |
+
this repo, supply `--bundle-2param`, `--bundle-6param`, `--data-2param`,
|
| 79 |
+
`--data-6param` to override.
|
| 80 |
+
|
| 81 |
## Architecture
|
| 82 |
|
| 83 |
Conditional U-Net + Gaussian diffusion process. Hyper-parameters (taken from
|
cross_model/README.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Cross-model scripts
|
| 2 |
+
|
| 3 |
+
These posterior-inference and comparison scripts use BOTH the
|
| 4 |
+
2-parameter and 6-parameter DDPM checkpoints. Their default
|
| 5 |
+
paths assume the original cluster layout:
|
| 6 |
+
|
| 7 |
+
Models/
|
| 8 |
+
2param_DDPM_HI_Emulation/ <- code
|
| 9 |
+
6param_ddpm_hi_lh6/ <- code
|
| 10 |
+
notebook_model_weights/
|
| 11 |
+
2param_epoch200/ <- checkpoint + args.json
|
| 12 |
+
6param_best/ <- checkpoint + args.json
|
| 13 |
+
|
| 14 |
+
When running these scripts from a local download of this HF repo,
|
| 15 |
+
pass `--bundle-2param`, `--bundle-6param`, `--data-2param`,
|
| 16 |
+
`--data-6param` (etc.) to point at the locations where you placed
|
| 17 |
+
the weights and the CAMELS LH data. See each script's `--help`.
|
cross_model/check_poster_env.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""One-shot env check for poster.py — logs NDJSON for debug session."""
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
_LOG = "/scratch/mrpcol001/Diffusion_job/Models/.cursor/debug-a1359c.log"
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _log(hypothesis_id: str, location: str, message: str, data: dict) -> None:
    """Append one debug record to the NDJSON session log at ``_LOG``.

    Each record is a single JSON object on its own line containing the fixed
    session id, a run id taken from the ``DEBUG_POSTER_RUN`` environment
    variable (``"pre-fix"`` when unset), the caller-supplied fields, and a
    millisecond Unix timestamp.

    Args:
        hypothesis_id: Label of the hypothesis this record relates to.
        location: Free-form "file:function" tag identifying the call site.
        message: Short human-readable description of the check.
        data: JSON-serialisable details of the check.
    """
    record = dict(
        sessionId="a1359c",
        runId=os.environ.get("DEBUG_POSTER_RUN", "pre-fix"),
        hypothesisId=hypothesis_id,
        location=location,
        message=message,
        data=data,
        timestamp=int(time.time() * 1000),
    )
    # The log directory may not exist yet on a fresh checkout.
    os.makedirs(os.path.dirname(_LOG), exist_ok=True)
    with open(_LOG, "a", encoding="utf-8") as sink:
        print(json.dumps(record), file=sink)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def main() -> None:
    """Record a fixed series of filesystem checks about ``poster.py``.

    Five records (hypotheses H2, H1, H3, H4, H5 — in that order) are appended
    to the debug log via ``_log``: whether the working directory is the
    expected Models root, whether ``poster.py`` exists (and its size), whether
    a capitalised ``Poster.py`` exists, which directory entries contain
    "poster" (case-insensitive), and whether a ``scripts`` sub-directory
    exists. The log path is printed at the end.
    """
    # region agent log
    root = "/scratch/mrpcol001/Diffusion_job/Models"
    site = "check_poster_env.py:main"
    cwd = os.getcwd()
    poster_path = os.path.join(root, "poster.py")
    poster_ci = os.path.join(root, "Poster.py")

    same_dir = os.path.abspath(cwd) == os.path.abspath(root)
    _log(
        "H2",
        site,
        "cwd vs expected Models root",
        {"cwd": cwd, "root": root, "cwd_equals_root": same_dir},
    )

    present = os.path.isfile(poster_path)
    _log(
        "H1",
        site,
        "poster.py presence",
        {
            "poster_py_exists": present,
            "poster_path": poster_path,
            "size_if_exists": os.path.getsize(poster_path) if present else None,
        },
    )

    _log("H3", site, "case variant", {"Poster_py_exists": os.path.isfile(poster_ci)})

    # Listing the root can fail (missing directory, permissions); record the
    # error text instead of aborting the remaining checks.
    list_err = None
    try:
        entries = sorted(os.listdir(root))
    except OSError as exc:
        entries, list_err = [], str(exc)
    poster_like = [entry for entry in entries if "poster" in entry.lower()]
    _log(
        "H4",
        site,
        "Models directory poster-related names",
        {
            "list_error": list_err,
            "poster_like_filenames": poster_like,
            "total_entries": len(entries),
        },
    )

    scripts_dir = os.path.join(root, "scripts")
    _log(
        "H5",
        site,
        "alternate runnable scripts hint",
        {"scripts_dir_exists": os.path.isdir(scripts_dir)},
    )
    # endregion agent log
    print("check_poster_env: wrote logs to", _LOG)


if __name__ == "__main__":
    main()
|
cross_model/compare_posterior_inference.py
ADDED
|
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
compare_posterior_inference.py
|
| 4 |
+
==============================
|
| 5 |
+
Side-by-side corrected surrogate posteriors on (Omega_m, sigma_8) for
|
| 6 |
+
DDPM-2 and DDPM-6 using a JOINT P(k) + log-N_HI PDF Gaussian likelihood.
|
| 7 |
+
|
| 8 |
+
Why a joint summary statistic?
|
| 9 |
+
------------------------------
|
| 10 |
+
P(k) alone leaves Omega_m and sigma_8 strongly degenerate, which is why
|
| 11 |
+
prior runs returned sigma_8 ~ 0.80 +/- 0.12 regardless of truth. The
|
| 12 |
+
column-density PDF carries information that P(k) misses (it is sensitive
|
| 13 |
+
to non-Gaussian amplitude features), so combining the two breaks the
|
| 14 |
+
degeneracy along the S_8 direction.
|
| 15 |
+
|
| 16 |
+
What this script does
|
| 17 |
+
---------------------
|
| 18 |
+
1. Loads both DDPM-2 (epoch 200) and DDPM-6 (best) bundles from
|
| 19 |
+
notebook_model_weights/.
|
| 20 |
+
2. Calibrates per-model sigma_pk and sigma_pdf from validation-set DDPM
|
| 21 |
+
pair scatter (no hard-coded noise).
|
| 22 |
+
3. For each anchor in the test split:
|
| 23 |
+
- DDPM-2: direct (Omega_m, sigma_8) grid posterior.
|
| 24 |
+
- DDPM-6: Monte-Carlo marginalisation over the four astrophysical
|
| 25 |
+
nuisance parameters with a uniform prior (independent uniform draws between the supplied bounds, not Latin-hypercube sampled).
|
| 26 |
+
Both posteriors use the JOINT P(k) + PDF likelihood.
|
| 27 |
+
4. Saves per-anchor posterior arrays as .npz so plots can be re-rendered
|
| 28 |
+
cheaply later.
|
| 29 |
+
5. Emits a single comparison figure: rows = anchors, columns =
|
| 30 |
+
DDPM-2 | DDPM-6, with 68/95% credible contours, true value, posterior
|
| 31 |
+
mean, and posterior summary annotation.
|
| 32 |
+
|
| 33 |
+
Defaults
|
| 34 |
+
--------
|
| 35 |
+
--grid 30 --ddim-steps 50 --batch-size 8
|
| 36 |
+
--n-pk-samples 8 --n-marg-samples 20 --n-anchors 4
|
| 37 |
+
|
| 38 |
+
Quick smoke test
|
| 39 |
+
----------------
|
| 40 |
+
python compare_posterior_inference.py --grid 16 --n-pk-samples 4 \\
|
| 41 |
+
--n-marg-samples 5 --n-anchors 2
|
| 42 |
+
"""
|
| 43 |
+
from __future__ import annotations
|
| 44 |
+
|
| 45 |
+
import argparse
|
| 46 |
+
import gc
|
| 47 |
+
import sys
|
| 48 |
+
from pathlib import Path
|
| 49 |
+
from typing import Dict, List, Optional, Tuple
|
| 50 |
+
|
| 51 |
+
import matplotlib
|
| 52 |
+
matplotlib.use("Agg")
|
| 53 |
+
import matplotlib.pyplot as plt
|
| 54 |
+
import numpy as np
|
| 55 |
+
import torch
|
| 56 |
+
|
| 57 |
+
MODELS_ROOT = Path(__file__).resolve().parent
|
| 58 |
+
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
|
| 59 |
+
if str(CODE_6.resolve()) not in sys.path:
|
| 60 |
+
sys.path.insert(0, str(CODE_6))
|
| 61 |
+
|
| 62 |
+
import evaluate_conditional as ec # noqa: E402
|
| 63 |
+
import eval_model as em # noqa: E402
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# =============================================================================
|
| 67 |
+
# 1. SUMMARY STATISTICS (P(k) and log-N_HI PDF, per map)
|
| 68 |
+
# =============================================================================
|
| 69 |
+
|
| 70 |
+
def per_map_log_pk(
    imgs: np.ndarray,
    box_size: float = 25.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Return per-map natural-log power spectra for a batch of maps.

    The spectra come from ``em.per_map_power_spectra_log``; only the bins
    whose first-returned value (presumably the wavenumber — confirm against
    ``eval_model``) is strictly positive are kept.

    Args:
        imgs: Batch of maps (the original docstring states values in [0, 1]).
        box_size: Forwarded unchanged to ``em.per_map_power_spectra_log``.

    Returns:
        ``(dk, valid_mask, log_pks)`` where ``valid_mask = dk > 0`` and
        ``log_pks`` has shape ``(N, valid_mask.sum())``.
    """
    k_bins, spectra = em.per_map_power_spectra_log(imgs, box_size)
    keep = k_bins > 0
    # The 1e-30 floor keeps the log finite for exactly-zero power.
    log_spectra = np.log(spectra[:, keep] + 1e-30)
    return k_bins, keep, log_spectra
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def per_map_log_pdf(
    imgs: np.ndarray,
    log_nhi_min: float = 14.0,
    log_nhi_max: float = 22.0,
    n_bins: int = 100,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-map log of the density-normalised pixel histogram.

    Pixel values are clipped to [0, 1] and mapped linearly onto
    ``[log_nhi_min, log_nhi_max]`` before histogramming.

    Args:
        imgs: Batch of maps; iterated over its first axis.
        log_nhi_min: Value that a pixel of 0 maps to (lower histogram edge).
        log_nhi_max: Value that a pixel of 1 maps to (upper histogram edge).
        n_bins: Number of histogram *edges*; the histogram therefore has
            ``n_bins - 1`` bins.

    Returns:
        ``(bin_centers, log_pdfs)`` with ``log_pdfs`` of shape
        ``(N, n_bins - 1)``.
    """
    edges = np.linspace(log_nhi_min, log_nhi_max, n_bins)
    centers = (edges[:-1] + edges[1:]) * 0.5
    span = log_nhi_max - log_nhi_min
    clipped = np.clip(imgs, 0.0, 1.0)
    densities = [
        np.histogram(log_nhi_min + span * frame.reshape(-1), bins=edges, density=True)[0]
        for frame in clipped
    ]
    # Empty bins become log(1e-30) rather than -inf.
    return centers, np.log(np.stack(densities) + 1e-30)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# =============================================================================
|
| 105 |
+
# 2. CALIBRATION OF sigma_pk AND sigma_pdf
|
| 106 |
+
# =============================================================================
|
| 107 |
+
|
| 108 |
+
def calibrate_summary_sigmas(
    model: torch.nn.Module,
    images_val: np.ndarray,
    labels_val: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    device: torch.device,
    box_size: float = 25.0,
    ddim_steps: int = 50,
    n_pairs: int = 30,
    seed: int = 0,
) -> Tuple[float, float]:
    """Estimate sigma_pk and sigma_pdf from sample-to-sample scatter.

    For each of up to ``n_pairs`` randomly chosen validation labels, two
    samples are drawn for the same label via ``em.sample_batch`` and the
    scatter ``std(summary_a - summary_b) / sqrt(2)`` is measured for both the
    log P(k) and log-PDF summaries. The median over pairs, floored at 0.01,
    is returned for each summary.

    Args:
        model: Diffusion model handed to ``em.sample_batch``.
        images_val: Validation maps; only the last two axes (H, W) are read.
        labels_val: Validation labels, one row per map.
        lab_mean: Label normalisation mean, forwarded to ``em.sample_batch``.
        lab_std: Label normalisation std, forwarded to ``em.sample_batch``.
        normalize: Forwarded to ``em.sample_batch``.
        device: Torch device used for sampling.
        box_size: Forwarded to ``per_map_log_pk``.
        ddim_steps: Forwarded to ``em.sample_batch``.
        n_pairs: Maximum number of validation labels to use.
        seed: Seed for the label-selection RNG.

    Returns:
        ``(sigma_pk, sigma_pdf)``.

    Raises:
        ValueError: If no pairs can be drawn (empty ``labels_val`` or
            ``n_pairs < 1``). Previously this produced NaN sigmas silently.
    """
    n_val = min(n_pairs, len(labels_val))
    if n_val < 1:
        raise ValueError(
            "calibrate_summary_sigmas needs at least one validation label and "
            f"n_pairs >= 1 (got len(labels_val)={len(labels_val)}, n_pairs={n_pairs})"
        )

    rng = np.random.default_rng(seed)
    idx = rng.choice(len(labels_val), size=n_val, replace=False)
    H, W = int(images_val.shape[-2]), int(images_val.shape[-1])

    sig_pk: List[float] = []
    sig_pdf: List[float] = []

    for i in idx:
        # Same label twice -> two independent draws for that label.
        lab_pair = np.repeat(labels_val[i:i + 1], 2, axis=0).astype(np.float32)
        pair = em.sample_batch(
            model, lab_pair, lab_mean, lab_std, normalize,
            H, W, device, ddim_steps, False,
        )
        _, _, lpk = per_map_log_pk(pair, box_size)
        _, lpdf = per_map_log_pdf(pair)
        # The difference of two i.i.d. draws has sqrt(2) times the per-draw std.
        sig_pk.append(float(np.std(lpk[0] - lpk[1]) / np.sqrt(2.0)))
        sig_pdf.append(float(np.std(lpdf[0] - lpdf[1]) / np.sqrt(2.0)))

    # Floor at 0.01 so a near-zero scatter cannot blow up the likelihood.
    s_pk = max(float(np.median(sig_pk)), 0.01)
    s_pdf = max(float(np.median(sig_pdf)), 0.01)
    print(f" sigma_pk median over {n_val} pairs = {s_pk:.4f}")
    print(f" sigma_pdf median over {n_val} pairs = {s_pdf:.4f}")
    return s_pk, s_pdf
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# =============================================================================
|
| 154 |
+
# 3. JOINT LOG-LIKELIHOOD (P(k) + PDF, averaged over DDPM stochasticity)
|
| 155 |
+
# =============================================================================
|
| 156 |
+
|
| 157 |
+
def joint_log_likelihood(
    obs: np.ndarray,
    full_labels: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module,
    device: torch.device,
    H: int,
    W: int,
    box_size: float,
    ddim_steps: int,
    batch_sz: int,
    n_pk_samples: int,
    sigma_pk: float,
    sigma_pdf: float,
) -> np.ndarray:
    """Joint P(k) + PDF Gaussian log-likelihood at every grid point.

    For each row of ``full_labels`` the log P(k) and log-PDF summaries are
    averaged over ``n_pk_samples`` sampled maps, then compared with the
    observed map's summaries:

        log L = -MSE_pk / (2 sigma_pk^2) - MSE_pdf / (2 sigma_pdf^2)

    where each MSE is the *mean* (not sum) of squared differences over bins.

    Args:
        obs: Observed map; a leading batch axis is added before summarising.
        full_labels: Label grid, one row per grid point.
        lab_mean: Forwarded to ``em.sample_batch``.
        lab_std: Forwarded to ``em.sample_batch``.
        normalize: Forwarded to ``em.sample_batch``.
        model: Diffusion model used for sampling.
        device: Torch device used for sampling.
        H: Map height passed to the sampler.
        W: Map width passed to the sampler.
        box_size: Forwarded to ``per_map_log_pk``.
        ddim_steps: Forwarded to ``em.sample_batch``.
        batch_sz: Number of grid points sampled per call.
        n_pk_samples: Number of sampled maps averaged per grid point.
        sigma_pk: Noise scale of the log P(k) term.
        sigma_pdf: Noise scale of the log-PDF term.

    Returns:
        Array of shape ``(ngrid,)`` with the unnormalised log-likelihood.

    Raises:
        ValueError: If ``n_pk_samples`` or ``batch_sz`` is not positive;
            these would otherwise give NaN or all-zero averages silently.
    """
    if n_pk_samples < 1:
        raise ValueError(f"n_pk_samples must be >= 1, got {n_pk_samples}")
    if batch_sz < 1:
        raise ValueError(f"batch_sz must be >= 1, got {batch_sz}")

    # Observed summaries
    _, _, log_pk_obs = per_map_log_pk(obs[np.newaxis], box_size)
    log_pk_obs = log_pk_obs[0]  # (n_valid_pk,)
    _, log_pdf_obs = per_map_log_pdf(obs[np.newaxis])
    log_pdf_obs = log_pdf_obs[0]  # (n_pdf_bins,)

    ngrid = full_labels.shape[0]
    sum_lpk = np.zeros((ngrid, log_pk_obs.size), dtype=np.float64)
    sum_lpdf = np.zeros((ngrid, log_pdf_obs.size), dtype=np.float64)

    for _s in range(n_pk_samples):
        for j0 in range(0, ngrid, batch_sz):
            chunk = full_labels[j0: j0 + batch_sz]
            imgs = em.sample_batch(
                model, chunk, lab_mean, lab_std, normalize,
                H, W, device, ddim_steps, False,
            )
            _, _, lpk = per_map_log_pk(imgs, box_size)
            _, lpdf = per_map_log_pdf(imgs)
            # len(chunk) handles the final, possibly shorter, batch.
            sum_lpk[j0: j0 + len(chunk)] += lpk
            sum_lpdf[j0: j0 + len(chunk)] += lpdf

    mean_lpk = sum_lpk / n_pk_samples
    mean_lpdf = sum_lpdf / n_pk_samples

    mse_pk = np.mean((log_pk_obs[None, :] - mean_lpk) ** 2, axis=1)
    mse_pdf = np.mean((log_pdf_obs[None, :] - mean_lpdf) ** 2, axis=1)

    return -mse_pk / (2.0 * sigma_pk ** 2) - mse_pdf / (2.0 * sigma_pdf ** 2)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# =============================================================================
|
| 210 |
+
# 4. GRIDS AND PER-MODEL POSTERIORS
|
| 211 |
+
# =============================================================================
|
| 212 |
+
|
| 213 |
+
def cosmo_grid_axes(
    labels_ref: np.ndarray, grid: int, pad_frac: float = 0.02,
) -> Tuple[np.ndarray, np.ndarray]:
    """Build evenly spaced float32 axes spanning the first two label columns.

    Each axis covers the column's [min, max] range widened on both sides by
    ``pad_frac`` times that range.

    Args:
        labels_ref: Reference labels; columns 0 and 1 define the two axes.
        grid: Number of points per axis.
        pad_frac: Fractional padding added on each side of the range.

    Returns:
        ``(om_ax, s8_ax)``: the axes for column 0 and column 1.
    """

    def _padded_axis(column: np.ndarray) -> np.ndarray:
        low, high = float(column.min()), float(column.max())
        # The 1e-12 keeps the padding non-zero for a constant column.
        margin = pad_frac * (high - low + 1e-12)
        return np.linspace(low - margin, high + margin, grid, dtype=np.float32)

    return _padded_axis(labels_ref[:, 0]), _padded_axis(labels_ref[:, 1])
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def posterior_ddpm2(
    obs: np.ndarray,
    labels_ref: np.ndarray,
    lab_mean: np.ndarray, lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module, device: torch.device,
    grid: int, batch_sz: int, ddim_steps: int,
    n_pk_samples: int,
    sigma_pk: float, sigma_pdf: float,
    box_size: float = 25.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Grid posterior over the two label dimensions for the 2-parameter model.

    A ``grid x grid`` lattice is built from ``cosmo_grid_axes``, the joint
    log-likelihood is evaluated at every lattice point, and the result is
    exponentiated (after subtracting the maximum for numerical stability) and
    normalised to sum to one.

    Returns:
        ``(w, OM, S8, om_ax, s8_ax)``: the normalised weights with shape
        ``(grid, grid)``, the two ``indexing="ij"`` meshgrids, and the 1-D axes.
    """
    om_ax, s8_ax = cosmo_grid_axes(labels_ref, grid)
    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
    theta = np.column_stack((OM.ravel(), S8.ravel())).astype(np.float32)
    height, width = int(obs.shape[-2]), int(obs.shape[-1])

    log_like = joint_log_likelihood(
        obs=obs,
        full_labels=theta,
        lab_mean=lab_mean,
        lab_std=lab_std,
        normalize=normalize,
        model=model,
        device=device,
        H=height,
        W=width,
        box_size=box_size,
        ddim_steps=ddim_steps,
        batch_sz=batch_sz,
        n_pk_samples=n_pk_samples,
        sigma_pk=sigma_pk,
        sigma_pdf=sigma_pdf,
    )
    # Shift so the largest weight is exp(0) = 1 before exponentiating.
    weights = np.exp(log_like - log_like.max()).reshape(grid, grid)
    weights = weights / weights.sum()
    return weights, OM, S8, om_ax, s8_ax
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def posterior_ddpm6(
    obs: np.ndarray,
    labels_ref: np.ndarray,
    lab_mean: np.ndarray, lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module, device: torch.device,
    lo_tail: np.ndarray, hi_tail: np.ndarray,
    grid: int, batch_sz: int, ddim_steps: int,
    n_pk_samples: int, n_marg_samples: int,
    sigma_pk: float, sigma_pdf: float,
    box_size: float = 25.0,
    seed: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Grid posterior for the 6-parameter model, marginalised by Monte Carlo.

    For each of ``n_marg_samples`` draws, the four trailing label entries
    (columns 2-5) are sampled uniformly between ``lo_tail`` and ``hi_tail``
    and held fixed across the whole ``grid x grid`` lattice of the first two
    labels. The per-draw likelihoods are averaged in log space
    (log-sum-exp minus ``log(n_marg_samples)``), then exponentiated and
    normalised to sum to one.

    Args:
        lo_tail: Lower bounds for label columns 2-5.
        hi_tail: Upper bounds for label columns 2-5.
        n_marg_samples: Number of Monte-Carlo draws of the trailing labels.
        seed: Seed for the RNG that draws the trailing labels.
        (Remaining arguments are forwarded to ``cosmo_grid_axes`` and
        ``joint_log_likelihood``.)

    Returns:
        ``(w, OM, S8, om_ax, s8_ax)``: normalised weights with shape
        ``(grid, grid)``, the two ``indexing="ij"`` meshgrids, and the 1-D axes.

    Raises:
        ValueError: If ``n_marg_samples < 1``; the accumulator would stay at
            ``-inf`` and the returned posterior would be all NaN.
    """
    if n_marg_samples < 1:
        raise ValueError(f"n_marg_samples must be >= 1, got {n_marg_samples}")

    rng = np.random.default_rng(seed)
    om_ax, s8_ax = cosmo_grid_axes(labels_ref, grid)
    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
    ngrid = OM.size
    H, W = int(obs.shape[-2]), int(obs.shape[-1])

    # Running log-sum-exp of the per-draw likelihoods; -inf is log(0).
    log_acc = np.full(ngrid, -np.inf, dtype=np.float64)
    for m in range(n_marg_samples):
        theta_extra = rng.uniform(lo_tail, hi_tail).astype(np.float32)
        full_6 = np.zeros((ngrid, 6), dtype=np.float32)
        full_6[:, 0] = OM.ravel()
        full_6[:, 1] = S8.ravel()
        full_6[:, 2:6] = theta_extra  # one draw shared by every grid point
        log_w = joint_log_likelihood(
            obs, full_6, lab_mean, lab_std, normalize, model, device,
            H, W, box_size, ddim_steps, batch_sz,
            n_pk_samples, sigma_pk, sigma_pdf,
        )
        log_acc = np.logaddexp(log_acc, log_w)
        if (m + 1) % 5 == 0 or (m + 1) == n_marg_samples:
            print(f" DDPM-6 marg draw {m + 1}/{n_marg_samples}")

    log_acc -= np.log(n_marg_samples)
    log_acc -= log_acc.max()
    w = np.exp(log_acc).reshape(grid, grid)
    w /= w.sum()
    return w, OM, S8, om_ax, s8_ax
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# =============================================================================
|
| 295 |
+
# 5. POSTERIOR DIAGNOSTICS
|
| 296 |
+
# =============================================================================
|
| 297 |
+
|
| 298 |
+
def credible_levels(
    w: np.ndarray, levels: Tuple[float, ...] = (0.68, 0.95),
) -> List[float]:
    """Return the weight thresholds that enclose the requested mass fractions.

    Weights are sorted in descending order and accumulated; for each level
    the threshold is the weight at the first position where the cumulative
    sum reaches ``level * total``. Cells with weight at or above the
    threshold therefore hold at least that fraction of the total.

    Args:
        w: Array of non-normalised or normalised weights (any shape).
        levels: Mass fractions to compute thresholds for.

    Returns:
        One threshold per entry of ``levels``, in the same order.
    """
    descending = np.sort(w.ravel())[::-1]
    running = np.cumsum(descending)
    mass = float(descending.sum())
    last = len(descending) - 1
    # Clamp to the last index in case rounding pushes the target past the end.
    return [
        float(descending[min(int(np.searchsorted(running, frac * mass)), last)])
        for frac in levels
    ]
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def posterior_summary(
    w: np.ndarray, OM: np.ndarray, S8: np.ndarray,
) -> Dict[str, float]:
    """Weighted means and standard deviations of a gridded posterior.

    The weights are renormalised to sum to one, then used to compute the mean
    and standard deviation of ``OM``, of ``S8``, and of the derived quantity
    ``S8 * (OM / 0.3) ** 0.5``. The effective number of grid cells is
    ``1 / sum(w ** 2)``.

    Returns:
        Dict with keys ``om_mean``, ``om_std``, ``s8_mean``, ``s8_std``,
        ``S8_mean``, ``S8_std`` and ``n_eff``.
    """
    prob = w / w.sum()

    def _moments(field: np.ndarray) -> Tuple[float, float]:
        centre = float((prob * field).sum())
        spread = float(np.sqrt((prob * (field - centre) ** 2).sum()))
        return centre, spread

    om_mean, om_std = _moments(OM)
    s8_mean, s8_std = _moments(S8)
    big_s8_mean, big_s8_std = _moments(S8 * (OM / 0.3) ** 0.5)
    n_eff = float(1.0 / (prob.ravel() ** 2).sum())

    return {
        "om_mean": om_mean,
        "om_std": om_std,
        "s8_mean": s8_mean,
        "s8_std": s8_std,
        "S8_mean": big_s8_mean,
        "S8_std": big_s8_std,
        "n_eff": n_eff,
    }
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
# =============================================================================
|
| 331 |
+
# 6. PLOTTING
|
| 332 |
+
# =============================================================================
|
| 333 |
+
|
| 334 |
+
def plot_panel(
    ax,
    w: np.ndarray, OM: np.ndarray, S8: np.ndarray,
    true_om: float, true_s8: float,
    title: str,
    summary: Dict[str, float],
    cmap: str,
) -> None:
    """Draw one posterior panel: filled map, credible contours, markers, text.

    Args:
        ax: Matplotlib axes to draw on.
        w: Posterior weights on the grid.
        OM: Grid of x-axis values, same shape as ``w``.
        S8: Grid of y-axis values, same shape as ``w``.
        true_om: x coordinate of the "true" marker (red ``x``).
        true_s8: y coordinate of the "true" marker.
        title: Axes title.
        summary: Mapping with keys ``om_mean``, ``om_std``, ``s8_mean``,
            ``s8_std``, ``S8_mean``, ``S8_std`` and ``n_eff``.
        cmap: Colormap name for the filled contours.

    The credible-region overlay is best-effort: if it cannot be drawn a
    ``RuntimeWarning`` is emitted and the rest of the panel is still produced.
    """
    import warnings

    cf = ax.contourf(OM, S8, w, levels=16, cmap=cmap)
    plt.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)

    try:
        thr68, thr95 = credible_levels(w, levels=(0.68, 0.95))
        # (level, colour, line width, line style), lowest level first.
        overlay = [
            (thr95, "#e07b39", 1.0, "--"),
            (thr68, "#c0392b", 1.6, "-"),
        ]
        if not thr95 < thr68:
            # Matplotlib requires strictly increasing levels; for a sharply
            # peaked posterior both thresholds coincide, so keep only 68%.
            overlay = overlay[1:]
        ax.contour(
            OM, S8, w,
            levels=[item[0] for item in overlay],
            colors=[item[1] for item in overlay],
            linewidths=[item[2] for item in overlay],
            linestyles=[item[3] for item in overlay],
        )
    except Exception as exc:  # best-effort overlay: report, do not abort the figure
        warnings.warn(
            f"plot_panel: credible contours not drawn for {title!r}: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )

    ax.scatter(summary["om_mean"], summary["s8_mean"],
               s=70, c="k", marker="+", zorder=8, lw=2.0)
    ax.scatter(true_om, true_s8,
               s=70, c="r", marker="x", zorder=8, lw=2.0)

    info = (
        f"$\\Omega_m$ = {summary['om_mean']:.3f} $\\pm$ {summary['om_std']:.3f}\n"
        f"$\\sigma_8$ = {summary['s8_mean']:.3f} $\\pm$ {summary['s8_std']:.3f}\n"
        f"$S_8$ = {summary['S8_mean']:.3f} $\\pm$ {summary['S8_std']:.3f}\n"
        f"$n_{{\\rm eff}}$ = {summary['n_eff']:.0f}"
    )
    ax.text(
        0.02, 0.02, info, transform=ax.transAxes, fontsize=7,
        va="bottom", color="#111",
        bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.78),
    )

    ax.set_xlabel(r"$\Omega_m$", fontsize=9)
    ax.set_ylabel(r"$\sigma_8$", fontsize=9)
    ax.set_title(title, fontsize=8.5)
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def make_comparison_figure(
    per_anchor: List[Dict],
    suptitle: str,
    out_path: Path,
) -> None:
    """Save a figure with one row per anchor: DDPM-2 (left) vs DDPM-6 (right).

    Each entry of ``per_anchor`` must provide ``ix``, ``w2``, ``w6``, ``OM``,
    ``S8``, ``true_om``, ``true_s8``, ``summ2`` and ``summ6``.

    The DDPM-6 panel additionally honours the optional keys ``OM6``,
    ``S8_6``, ``true_om6`` and ``true_s86``.  When present they are used for
    the right-hand panel's grid and truth marker; when absent the DDPM-2
    values are reused, which is the previous behaviour.  This matters when
    the two models' grids or ground-truth labels differ.

    Raises
    ------
    ValueError
        If ``per_anchor`` is empty (there is nothing to plot).
    """
    from matplotlib.lines import Line2D

    n = len(per_anchor)
    if n == 0:
        raise ValueError("make_comparison_figure: per_anchor is empty, nothing to plot")

    fig, axes = plt.subplots(n, 2, figsize=(11, 4.5 * n), squeeze=False)
    for i, p in enumerate(per_anchor):
        plot_panel(
            axes[i][0],
            p["w2"], p["OM"], p["S8"],
            p["true_om"], p["true_s8"],
            (
                f"DDPM-2 | ix={p['ix']} | "
                rf"$\Omega_m$={p['true_om']:.3f}, $\sigma_8$={p['true_s8']:.3f}"
            ),
            p["summ2"], cmap="Blues",
        )

        # Prefer the DDPM-6 specific grid/truth when the caller supplied it.
        om_grid6 = p.get("OM6", p["OM"])
        s8_grid6 = p.get("S8_6", p["S8"])
        true_om6 = p.get("true_om6", p["true_om"])
        true_s86 = p.get("true_s86", p["true_s8"])
        plot_panel(
            axes[i][1],
            p["w6"], om_grid6, s8_grid6,
            true_om6, true_s86,
            (
                f"DDPM-6 (MC marg.) | ix={p['ix']} | "
                rf"$\Omega_m$={true_om6:.3f}, $\sigma_8$={true_s86:.3f}"
            ),
            p["summ6"], cmap="Greens",
        )

    legend_h = [
        Line2D([], [], marker="x", color="r", ls="", ms=8, label="True"),
        Line2D([], [], marker="+", color="k", ls="", ms=8, label="Posterior mean"),
        Line2D([], [], color="#c0392b", lw=1.6, label="68% CR"),
        Line2D([], [], color="#e07b39", lw=1.0, ls="--", label="95% CR"),
    ]
    fig.legend(
        handles=legend_h, loc="upper center", ncol=4, fontsize=8.5,
        bbox_to_anchor=(0.5, 0.998), frameon=False,
    )
    # Address the figure explicitly rather than pyplot's implicit
    # "current figure", so this stays correct if other figures are open.
    fig.suptitle(suptitle, fontsize=11, y=0.992)
    fig.tight_layout(rect=(0, 0, 1, 0.97))
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved -> {out_path}")
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
# =============================================================================
|
| 426 |
+
# 7. MODEL LOADING
|
| 427 |
+
# =============================================================================
|
| 428 |
+
|
| 429 |
+
def load_model(
    args_json: Path, ckpt: Path, device: torch.device,
) -> Tuple[torch.nn.Module, Dict]:
    """Rebuild a trained network from its saved config and checkpoint.

    The training configuration is read from ``args_json``, the model is
    constructed on ``device``, the weights in ``ckpt`` are loaded into it,
    and the model is switched to evaluation mode.

    Returns
    -------
    (model, cfg) : the eval-mode model and the training config used to
        build it.
    """
    config = ec.load_training_config(str(args_json))
    network = ec.build_model(config, device)
    ec.load_checkpoint(network, str(ckpt), device)
    network.eval()
    return network, config
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
    """Return per-column (min, max) of label columns 2-5 of the training labels.

    The first existing file among ``train_labels_LH.npy`` and
    ``train_labels_LH_2.npy`` under ``data_dir`` is used.

    Returns
    -------
    (lo, hi) : two float32 arrays of shape (4,).

    Raises
    ------
    ValueError
        If the label array is not 2-D or has fewer than 6 columns.
    FileNotFoundError
        If neither label file exists under ``data_dir``.
    """
    for name in ("train_labels_LH.npy", "train_labels_LH_2.npy"):
        candidate = data_dir / name
        if not candidate.is_file():
            continue
        labels = np.load(candidate)
        # Check the rank first: indexing shape[1] on a 1-D array would raise
        # IndexError instead of the intended, descriptive ValueError.
        if labels.ndim != 2 or labels.shape[1] < 6:
            raise ValueError(
                f"Expected >=6 label columns in {candidate}, got {labels.shape}"
            )
        tail = labels[:, 2:6]
        return (
            tail.min(axis=0).astype(np.float32),
            tail.max(axis=0).astype(np.float32),
        )
    raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}")
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
# =============================================================================
|
| 456 |
+
# 8. CLI
|
| 457 |
+
# =============================================================================
|
| 458 |
+
|
| 459 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command line of the DDPM-2 vs DDPM-6 posterior comparison.

    Options are declared in a table and registered in order, so the
    ``--help`` output lists them in the same sequence as the table.
    """
    weights_root = MODELS_ROOT / "notebook_model_weights"

    option_table = [
        ("--output-dir", dict(
            type=Path,
            default=MODELS_ROOT / "ddpm_posterior_compare_pk_pdf_out",
        )),
        ("--data-2param", dict(
            type=Path,
            default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"),
        )),
        ("--data-6param", dict(
            type=Path,
            default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"),
        )),
        ("--bundle-2param", dict(
            type=Path,
            default=weights_root / "2param_epoch200",
        )),
        ("--bundle-6param", dict(
            type=Path,
            default=weights_root / "6param_best",
        )),
        ("--split", dict(default="test", choices=["train", "val", "test"])),
        ("--n-anchors", dict(
            type=int, default=4,
            help="Number of evenly-spaced test fields. 4 = compact figure; 6 = wider sweep.",
        )),
        ("--grid", dict(type=int, default=30)),
        ("--ddim-steps", dict(type=int, default=50)),
        ("--batch-size", dict(type=int, default=8)),
        ("--n-pk-samples", dict(
            type=int, default=8,
            help="DDPM draws averaged per grid point. Variance ~ 1/N. >=8 recommended.",
        )),
        ("--n-marg-samples", dict(
            type=int, default=20,
            help="MC draws over astrophysical params for DDPM-6. >=20 recommended.",
        )),
        ("--n-calib-pairs", dict(
            type=int, default=30,
            help="Validation pairs used to calibrate sigma_pk and sigma_pdf.",
        )),
        ("--sigma-pk", dict(
            type=float, default=None,
            help="Override calibrated sigma_pk (applied to BOTH models).",
        )),
        ("--sigma-pdf", dict(
            type=float, default=None,
            help="Override calibrated sigma_pdf (applied to BOTH models).",
        )),
        ("--seed", dict(type=int, default=0)),
    ]

    parser = argparse.ArgumentParser(
        description=(
            "DDPM-2 vs DDPM-6 corrected posteriors on (Omega_m, sigma_8) "
            "with a JOINT P(k) + log-N_HI PDF Gaussian likelihood. "
            "sigma_pk and sigma_pdf are calibrated from DDPM aleatoric noise."
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
# =============================================================================
|
| 521 |
+
# 9. MAIN
|
| 522 |
+
# =============================================================================
|
| 523 |
+
|
| 524 |
+
def main() -> None:
    """Run the DDPM-2 vs DDPM-6 posterior comparison end to end.

    Steps: load both datasets and models, obtain the likelihood noise
    scales (calibrated on the validation split unless both are overridden on
    the command line), compute both posteriors for a set of evenly spaced
    anchor fields, save one ``.npz`` per anchor and finally the comparison
    figure.

    Raises
    ------
    ValueError
        If the selected split is empty for either dataset.
    """
    args = parse_args()
    out_dir = args.output_dir.resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Device : {device}")
    print(f"Output : {out_dir}")
    print()

    # ── Data ─────────────────────────────────────────────────────────────
    imgs2, labs2 = ec.load_split(args.data_2param, args.split)
    mean2, std2 = ec.load_label_stats(args.data_2param)
    imgs6, labs6 = ec.load_split(args.data_6param, args.split)
    mean6, std6 = ec.load_label_stats(args.data_6param)
    lo_tail, hi_tail = tail_lhs_bounds(args.data_6param)

    print(f"DDPM-2 {args.split}: {len(labs2)} maps label_dim={labs2.shape[1]}")
    print(f"DDPM-6 {args.split}: {len(labs6)} maps label_dim={labs6.shape[1]}")
    print(f" LHS tails (dims 2-5): min={lo_tail} max={hi_tail}")

    # Pick anchors evenly spread across the selected split. The two splits
    # may differ in length, so cap to the smaller.
    n_pool = min(len(labs2), len(labs6))
    if n_pool < 1:
        # Without this guard linspace(0, -1, ...) would produce a negative index.
        raise ValueError(
            f"Split {args.split!r} is empty for at least one dataset "
            f"(DDPM-2: {len(labs2)} maps, DDPM-6: {len(labs6)} maps)"
        )
    n_requested = max(2, min(args.n_anchors, n_pool))
    # np.unique removes the duplicate indices that integer truncation
    # produces for tiny pools (e.g. a single-map split would give [0, 0]).
    anchor_ix = np.unique(np.linspace(0, n_pool - 1, n_requested, dtype=int))
    n_anchors = len(anchor_ix)
    print(f"Anchor indices: {anchor_ix.tolist()}")
    print()

    # ── Models ───────────────────────────────────────────────────────────
    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
    aj2 = args.bundle_2param / "args.json"
    ck6 = args.bundle_6param / "best_model.pt"
    aj6 = args.bundle_6param / "args.json"

    print("Loading DDPM-2 ...")
    model2, cfg2 = load_model(aj2, ck2, device)
    norm2 = bool(cfg2.get("normalize_labels", True))
    print("Loading DDPM-6 ...")
    model6, cfg6 = load_model(aj6, ck6, device)
    norm6 = bool(cfg6.get("normalize_labels", True))
    print()

    # ── Calibrate sigma_pk and sigma_pdf per model ───────────────────────
    if args.sigma_pk is not None and args.sigma_pdf is not None:
        # Both scales overridden: skip the (expensive) calibration entirely.
        s2_pk = s6_pk = float(args.sigma_pk)
        s2_pdf = s6_pdf = float(args.sigma_pdf)
        print(f"sigma_pk overridden = {s2_pk:.4f}")
        print(f"sigma_pdf overridden = {s2_pdf:.4f}")
    else:
        print("Calibrating DDPM-2 noise scales ...")
        v2_imgs, v2_labs = ec.load_split(args.data_2param, "val")
        s2_pk, s2_pdf = calibrate_summary_sigmas(
            model2, v2_imgs, v2_labs, mean2, std2, norm2, device,
            ddim_steps=args.ddim_steps,
            n_pairs=args.n_calib_pairs, seed=args.seed,
        )
        del v2_imgs, v2_labs
        gc.collect()

        print("Calibrating DDPM-6 noise scales ...")
        v6_imgs, v6_labs = ec.load_split(args.data_6param, "val")
        s6_pk, s6_pdf = calibrate_summary_sigmas(
            model6, v6_imgs, v6_labs, mean6, std6, norm6, device,
            ddim_steps=args.ddim_steps,
            n_pairs=args.n_calib_pairs, seed=args.seed + 7,
        )
        del v6_imgs, v6_labs
        gc.collect()

        # A single override replaces only that scale, for both models.
        if args.sigma_pk is not None:
            s2_pk = s6_pk = float(args.sigma_pk)
        if args.sigma_pdf is not None:
            s2_pdf = s6_pdf = float(args.sigma_pdf)

    print()
    print(f"DDPM-2: sigma_pk={s2_pk:.4f} sigma_pdf={s2_pdf:.4f}")
    print(f"DDPM-6: sigma_pk={s6_pk:.4f} sigma_pdf={s6_pdf:.4f}")
    print()

    # ── Per-anchor inference ─────────────────────────────────────────────
    per_anchor: List[Dict] = []
    for k, ix in enumerate(anchor_ix):
        ix = int(ix)
        obs2 = imgs2[ix]
        obs6 = imgs6[ix]
        # Row ordering may differ between the two splits, so the truth is
        # read from each split separately.  Both pairs are kept: the DDPM-2
        # truth under true_om/true_s8 and the DDPM-6 truth under
        # true_om6/true_s86.
        true_om = float(labs2[ix, 0])
        true_s8 = float(labs2[ix, 1])
        true_om6 = float(labs6[ix, 0])
        true_s86 = float(labs6[ix, 1])

        print(
            f"[{k + 1}/{n_anchors}] ix={ix} "
            f"DDPM-2 truth (Om={true_om:.3f}, s8={true_s8:.3f}) "
            f"DDPM-6 truth (Om={true_om6:.3f}, s8={true_s86:.3f})"
        )

        print(" DDPM-2 posterior ...")
        w2, OM, S8, om_ax, s8_ax = posterior_ddpm2(
            obs2, labs2, mean2, std2, norm2, model2, device,
            args.grid, args.batch_size, args.ddim_steps,
            args.n_pk_samples, s2_pk, s2_pdf,
        )
        summ2 = posterior_summary(w2, OM, S8)
        print(
            f" Om={summ2['om_mean']:.3f}+/-{summ2['om_std']:.3f} "
            f"s8={summ2['s8_mean']:.3f}+/-{summ2['s8_std']:.3f} "
            f"S8={summ2['S8_mean']:.3f}+/-{summ2['S8_std']:.3f} "
            f"n_eff={summ2['n_eff']:.0f}"
        )

        print(" DDPM-6 posterior (MC marginalisation) ...")
        w6, OM6, S8_6, om_ax6, s8_ax6 = posterior_ddpm6(
            obs6, labs6, mean6, std6, norm6, model6, device,
            lo_tail, hi_tail,
            args.grid, args.batch_size, args.ddim_steps,
            args.n_pk_samples, args.n_marg_samples, s6_pk, s6_pdf,
            seed=args.seed + 1 + k,
        )
        summ6 = posterior_summary(w6, OM6, S8_6)
        print(
            f" Om={summ6['om_mean']:.3f}+/-{summ6['om_std']:.3f} "
            f"s8={summ6['s8_mean']:.3f}+/-{summ6['s8_std']:.3f} "
            f"S8={summ6['S8_mean']:.3f}+/-{summ6['S8_std']:.3f} "
            f"n_eff={summ6['n_eff']:.0f}"
        )

        per_anchor.append(dict(
            ix=ix,
            true_om=true_om, true_s8=true_s8,
            w2=w2, w6=w6, OM=OM, S8=S8,
            summ2=summ2, summ6=summ6,
            # DDPM-6 specific grid and truth (extra keys; w6 lives on this grid).
            OM6=OM6, S8_6=S8_6,
            true_om6=true_om6, true_s86=true_s86,
        ))
        np.savez_compressed(
            out_dir / f"posterior_ix{ix:03d}.npz",
            w_ddpm2=w2, w_ddpm6=w6,
            om_ax=om_ax, s8_ax=s8_ax,
            # Axes of the DDPM-6 grid, so w_ddpm6 can be interpreted without
            # assuming it shares the DDPM-2 axes.
            om_ax_ddpm6=om_ax6, s8_ax_ddpm6=s8_ax6,
            true_om_ddpm2=true_om, true_s8_ddpm2=true_s8,
            true_om_ddpm6=true_om6, true_s8_ddpm6=true_s86,
            sigma_pk_ddpm2=s2_pk, sigma_pdf_ddpm2=s2_pdf,
            sigma_pk_ddpm6=s6_pk, sigma_pdf_ddpm6=s6_pdf,
            n_pk_samples=args.n_pk_samples,
            n_marg_samples=args.n_marg_samples,
            ddim_steps=args.ddim_steps,
        )
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # ── Comparison figure ────────────────────────────────────────────────
    suptitle = (
        r"Corrected joint-likelihood posteriors $(\Omega_m,\,\sigma_8)$"
        " · "
        r"$P(k)$ + $\log\,N_{\rm HI}$ PDF"
        "\n"
        f"DDPM-2 $\\sigma_{{Pk}}$={s2_pk:.3f}, $\\sigma_{{PDF}}$={s2_pdf:.3f}"
        " · "
        f"DDPM-6 (MC marg., $N_{{\\rm marg}}$={args.n_marg_samples}) "
        f"$\\sigma_{{Pk}}$={s6_pk:.3f}, $\\sigma_{{PDF}}$={s6_pdf:.3f}"
        " · "
        f"grid {args.grid}², {args.n_pk_samples} DDPM draws/pt, "
        f"DDIM {args.ddim_steps} steps"
    )
    make_comparison_figure(
        per_anchor, suptitle,
        out_dir / "compare_posterior_ddpm2_vs_ddpm6.png",
    )

    print(f"\nDone. Artifacts in {out_dir}")
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
if __name__ == "__main__":
|
| 699 |
+
main()
|
cross_model/ddpm_posterior_corrected.py
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Corrected Surrogate P(k) Bayesian Posteriors on (Omega_m, sigma_8).
|
| 4 |
+
|
| 5 |
+
═══════════════════════════════════════════════════════════════════
|
| 6 |
+
THEORETICAL CORRECTIONS FROM REVIEW
|
| 7 |
+
═══════════════════════════════════════════════════════════════════
|
| 8 |
+
1. STOCHASTIC LIKELIHOOD AVERAGING
|
| 9 |
+
- Single DDPM sample per grid point → n_ddpm_samples averaged draws
|
| 10 |
+
- Reduces emulator variance by 1/N (noise std by 1/sqrt(N)), preventing spurious multimodality
|
| 11 |
+
|
| 12 |
+
2. CALIBRATED LIKELIHOOD NOISE SCALE
|
| 13 |
+
- Hard-coded sigma=0.25 → sigma_pk estimated from validation-set scatter
|
| 14 |
+
- Uses aleatoric uncertainty of the DDPM emulator at fixed theta
|
| 15 |
+
|
| 16 |
+
3. PROPER MC MARGINALISATION (DDPM-6)
|
| 17 |
+
- Fixed dims 2-5 to extremes → Monte Carlo integral over prior
|
| 18 |
+
- p(Om,s8|d) = integral L(d|Om,s8,theta_extra) pi(theta_extra) dtheta_extra
|
| 19 |
+
- Approximated by uniform draws from the LHS training range
|
| 20 |
+
|
| 21 |
+
4. HIGHER GRID RESOLUTION
|
| 22 |
+
- 14x14 → 30x30 (900 pts), with optional adaptive refinement
|
| 23 |
+
|
| 24 |
+
5. VISUALISATION OF PRIOR, LIKELIHOOD, AND POSTERIOR SEPARATELY
|
| 25 |
+
- Shows all three distributions per anchor per model
|
| 26 |
+
- Includes 68%/95% credible contours
|
| 27 |
+
- Includes S8 = sigma_8*(Om/0.3)^0.5 derived parameter
|
| 28 |
+
- Includes effective sample size and posterior predictive check
|
| 29 |
+
|
| 30 |
+
6. PRIOR-POSTERIOR COMPARISON
|
| 31 |
+
- Explicit uniform prior overlaid on posterior for each anchor
|
| 32 |
+
|
| 33 |
+
USAGE
|
| 34 |
+
─────
|
| 35 |
+
# Run both models with full corrections
|
| 36 |
+
python ddpm_posterior_corrected.py
|
| 37 |
+
|
| 38 |
+
# DDPM-2 only, faster run
|
| 39 |
+
python ddpm_posterior_corrected.py --ddpm2-only --grid 20 --n-ddpm-samples 4
|
| 40 |
+
|
| 41 |
+
# DDPM-6 only with full marginalisation
|
| 42 |
+
python ddpm_posterior_corrected.py --ddpm6-only --n-marg-samples 30 --n-ddpm-samples 8
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
from __future__ import annotations
|
| 46 |
+
|
| 47 |
+
import argparse
|
| 48 |
+
import gc
|
| 49 |
+
import sys
|
| 50 |
+
import warnings
|
| 51 |
+
from pathlib import Path
|
| 52 |
+
from typing import Dict, List, Optional, Tuple
|
| 53 |
+
|
| 54 |
+
import matplotlib
|
| 55 |
+
matplotlib.use("Agg")
|
| 56 |
+
import matplotlib.gridspec as gridspec
|
| 57 |
+
import matplotlib.pyplot as plt
|
| 58 |
+
import numpy as np
|
| 59 |
+
import torch
|
| 60 |
+
|
| 61 |
+
# Resolve code + data relative to Models/ (same directory as this file)
|
| 62 |
+
MODELS_ROOT = Path(__file__).resolve().parent
|
| 63 |
+
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
|
| 64 |
+
if str(CODE_6.resolve()) not in sys.path:
|
| 65 |
+
sys.path.insert(0, str(CODE_6))
|
| 66 |
+
|
| 67 |
+
import evaluate_conditional as ec # noqa: E402
|
| 68 |
+
import eval_model as em # noqa: E402
|
| 69 |
+
from figure9_posterior import log_pk_observed # noqa: E402
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 73 |
+
# 1. DATA UTILITIES
|
| 74 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 75 |
+
|
| 76 |
+
def _train_label_path(data_dir: Path) -> Path:
    """Locate the training-label file inside ``data_dir``.

    ``train_labels_LH.npy`` is preferred; ``train_labels_LH_2.npy`` is the
    fallback.

    Raises
    ------
    FileNotFoundError
        If neither file exists.
    """
    candidates = (
        data_dir / fname
        for fname in ("train_labels_LH.npy", "train_labels_LH_2.npy")
    )
    found = next((path for path in candidates if path.is_file()), None)
    if found is None:
        raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}")
    return found
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
    """
    Min/max of the LHS training distribution for label dims 2-5.
    These define the UNIFORM PRIOR for the astrophysical nuisance parameters.

    Returns
    -------
    (lo, hi) : two float32 arrays of shape (4,).

    Raises
    ------
    ValueError
        If the label array is not 2-D or has fewer than 6 columns.
    FileNotFoundError
        Propagated from ``_train_label_path`` when no label file exists.
    """
    labels = np.load(_train_label_path(data_dir))
    # Check the rank first: shape[1] on a 1-D array would raise IndexError
    # instead of the descriptive ValueError below.
    if labels.ndim != 2 or labels.shape[1] < 6:
        raise ValueError(f"Expected >= 6 label columns, got {labels.shape}")
    tail = labels[:, 2:6]
    lo = tail.min(axis=0).astype(np.float32)
    hi = tail.max(axis=0).astype(np.float32)
    return lo, hi
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def cosmo_prior_bounds(labels_split: np.ndarray) -> Tuple[float, float, float, float]:
    """Return (om_lo, om_hi, s8_lo, s8_hi): the extrema of label columns 0 and 1.

    Column 0 is treated as Omega_m and column 1 as sigma_8; the range is
    taken over whatever rows ``labels_split`` contains.
    """
    om_column = labels_split[:, 0]
    s8_column = labels_split[:, 1]
    return (
        float(om_column.min()),
        float(om_column.max()),
        float(s8_column.min()),
        float(s8_column.max()),
    )
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 107 |
+
# 2. LIKELIHOOD CALIBRATION
|
| 108 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 109 |
+
|
| 110 |
+
def calibrate_sigma_pk(
    data_dir: Path,
    model: torch.nn.Module,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    H: int,
    W: int,
    device: torch.device,
    ddim_steps: int,
    n_pairs: int = 60,
    rng_seed: int = 0,
) -> float:
    """
    Estimate sigma_pk = aleatoric std of log P(k) from DDPM stochasticity.

    For up to n_pairs randomly chosen validation-set labels, draw 2 DDPM
    samples at the same label and measure
    std(log Pk_a - log Pk_b) / sqrt(2).  The median over pairs gives a
    robust noise floor for the likelihood; the result is never below 0.05.

    This replaces the hard-coded sigma = 0.25 with a data-driven estimate.

    Raises
    ------
    ValueError
        If no validation label is available (empty split or n_pairs < 1).
    RuntimeError
        If no pair yields a finite estimate.
    """
    print(f" Calibrating sigma_pk from {n_pairs} validation-set pairs ...")
    # Only the labels are needed; the validation images are not used here.
    _, labels_val = ec.load_split(data_dir, "val")
    n_val = len(labels_val)
    n_use = min(n_pairs, n_val)
    if n_use < 1:
        raise ValueError(
            f"Cannot calibrate sigma_pk: n_pairs={n_pairs}, "
            f"validation labels available={n_val}"
        )
    rng = np.random.default_rng(rng_seed)
    idx = rng.choice(n_val, size=n_use, replace=False)

    sigmas = []
    for i in idx:
        # Same label twice -> two independent draws at fixed theta.
        lab = np.repeat(labels_val[i : i + 1], 2, axis=0).astype(np.float32)
        pair = em.sample_batch(
            model, lab, lab_mean, lab_std, normalize,
            H, W, device, ddim_steps, False,
        )
        _, pk_pair = em.per_map_power_spectra_log(pair, 25.0)
        pk_pair = np.asarray(pk_pair)
        # Keep only k-bins that are positive in BOTH draws.  Masking on the
        # first draw alone lets a zero bin of the second draw through, where
        # log(0 + 1e-30) ~ -69 would dominate the scatter.
        valid = np.all(pk_pair > 0, axis=0)
        if not valid.any():
            continue
        log_pk_pair = np.log(pk_pair[:, valid] + 1e-30)
        diff_std = float(np.std(log_pk_pair[0] - log_pk_pair[1])) / np.sqrt(2.0)
        if np.isfinite(diff_std):
            sigmas.append(diff_std)

    if not sigmas:
        # max(nan, 0.05) evaluates to nan, so an empty or non-finite sample
        # must be rejected explicitly rather than passed on.
        raise RuntimeError("sigma_pk calibration produced no finite estimate")

    sigma_cal = float(np.median(sigmas))
    sigma_cal = max(sigma_cal, 0.05)  # lower-bound: prevent degenerate likelihoods
    print(f" Calibrated sigma_pk = {sigma_cal:.4f} (was hard-coded 0.25)")
    return sigma_cal
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 158 |
+
# 3. GRID CONSTRUCTION
|
| 159 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 160 |
+
|
| 161 |
+
def build_cosmo_axes(
    labels_split: np.ndarray, grid: int, pad_frac: float = 0.02
) -> Tuple[np.ndarray, np.ndarray]:
    """Return (om_ax, s8_ax), each with `grid` equally spaced points.

    Each axis spans the label range given by ``cosmo_prior_bounds``, widened
    on both sides by ``pad_frac`` times that range.
    """
    om_min, om_max, s8_min, s8_max = cosmo_prior_bounds(labels_split)

    def _padded_axis(lower: float, upper: float) -> np.ndarray:
        # The 1e-12 keeps the margin non-zero for a degenerate range.
        margin = pad_frac * (upper - lower + 1e-12)
        return np.linspace(lower - margin, upper + margin, grid)

    return _padded_axis(om_min, om_max), _padded_axis(s8_min, s8_max)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def build_full_grid(
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    tail: Optional[np.ndarray] = None,
    lab_dim: int = 2,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Build a flattened (len(om_ax) * len(s8_ax), lab_dim) float32 label array
    for a sweep over (Om, s8).

    Column 0 holds Om and column 1 holds s8, flattened from an "ij"-indexed
    meshgrid (Om varies slowest).  If tail is provided (shape (4,)), dims 2-5
    are fixed to those values; any remaining columns stay zero.

    Returns (full_labels, OM_meshgrid, S8_meshgrid).

    Raises
    ------
    ValueError
        If lab_dim < 2, if tail does not have shape (4,), or if tail is
        given while lab_dim < 6 (there would be no columns 2-5 to fill).
    """
    if lab_dim < 2:
        raise ValueError(f"lab_dim must be >= 2, got {lab_dim}")

    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
    out = np.zeros((OM.size, lab_dim), dtype=np.float32)
    out[:, 0] = OM.ravel()
    out[:, 1] = S8.ravel()

    if tail is not None:
        tail = np.asarray(tail)
        # Explicit exceptions instead of `assert`: asserts vanish under -O.
        if tail.shape != (4,):
            raise ValueError(f"Expected tail shape (4,), got {tail.shape}")
        if lab_dim < 6:
            raise ValueError(
                f"tail requires lab_dim >= 6 to fill columns 2-5, got lab_dim={lab_dim}"
            )
        out[:, 2:6] = tail[np.newaxis, :]
    return out, OM, S8
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 196 |
+
# 4. LOG-LIKELIHOOD (AVERAGED OVER DDPM STOCHASTICITY)
|
| 197 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 198 |
+
|
| 199 |
+
def compute_log_likelihood(
    obs: np.ndarray,
    full: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module,
    H: int,
    W: int,
    device: torch.device,
    ddim_steps: int,
    batch_sz: int,
    n_ddpm_samples: int,
    sigma_pk: float,
) -> np.ndarray:
    """
    Return log-likelihood array of shape (ngrid,), where each entry is:

        ln L(d | theta_i) ≈ - (1 / (2*sigma^2)) * mean_k[ (log Pk_obs - mean_j[log Pk_gen_j])^2 ]

    The mean over j=1..n_ddpm_samples suppresses DDPM stochasticity.

    Parameters
    ----------
    obs : observed map; its last axis length is taken as the pixel count.
    full : (ngrid, lab_dim) labels, one row per grid point.
    batch_sz : number of grid points sampled per model call.
    sigma_pk : calibrated noise scale on log P(k) — NOT hard-coded.
    n_ddpm_samples : number of independent DDPM draws to average per grid pt.

    Raises
    ------
    ValueError
        If n_ddpm_samples < 1, batch_sz < 1 or sigma_pk <= 0.
    """
    if n_ddpm_samples < 1:
        # Averaging zero draws would give a NaN result of the wrong shape.
        raise ValueError(f"n_ddpm_samples must be >= 1, got {n_ddpm_samples}")
    if batch_sz < 1:
        raise ValueError(f"batch_sz must be >= 1, got {batch_sz}")
    if not sigma_pk > 0:
        raise ValueError(f"sigma_pk must be positive, got {sigma_pk}")

    ngrid = full.shape[0]

    npix = int(obs.shape[-1])
    dl = 25.0 / npix
    logf = em.images01_to_log_nhi(obs)
    dk, _ = ec.PowerSpectrum(logf, N=npix, dl=dl)
    valid = dk > 0
    # log_pk_observed is expected to return values only at dk > 0 (same
    # length as valid.sum()) — per the original author's note; it is defined
    # in figure9_posterior and not verifiable from here.
    log_pd = log_pk_observed(obs, 25.0, dk)

    # Running sum over draws: only one (ngrid, nk) array is alive at a time
    # instead of n_ddpm_samples of them.
    running_sum = None
    for _draw in range(n_ddpm_samples):
        per_chunk = []
        for j0 in range(0, ngrid, batch_sz):
            chunk = full[j0: j0 + batch_sz]
            imgs = em.sample_batch(
                model, chunk, lab_mean, lab_std, normalize,
                H, W, device, ddim_steps, False,
            )
            _, pkc = em.per_map_power_spectra_log(imgs, 25.0)
            per_chunk.append(np.log(pkc[:, valid] + 1e-30))
        draw_log_pk = np.concatenate(per_chunk, axis=0).astype(np.float64)  # (ngrid, nk)
        running_sum = draw_log_pk if running_sum is None else running_sum + draw_log_pk

    mean_log_pg = running_sum / n_ddpm_samples  # (ngrid, nk)

    mse = np.mean((log_pd[np.newaxis, :] - mean_log_pg) ** 2, axis=1)  # (ngrid,)
    log_like = -mse / (2.0 * sigma_pk ** 2)
    return log_like
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 257 |
+
# 5. MARGINALISATION OVER ASTROPHYSICAL PARAMETERS (DDPM-6)
|
| 258 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 259 |
+
|
| 260 |
+
def marginal_log_likelihood_ddpm6(
    obs: np.ndarray,
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    lo_tail: np.ndarray,
    hi_tail: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module,
    H: int,
    W: int,
    device: torch.device,
    ddim_steps: int,
    batch_sz: int,
    n_ddpm_samples: int,
    n_marg_samples: int,
    sigma_pk: float,
    rng_seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    2D marginal log-likelihood over (Om, s8) for the 6-param DDPM.

    Implements Monte Carlo integration:
        ln L_marg(d | Om, s8) = log[ (1/N) Σ_i L(d | Om, s8, θ_extra^i) ]
                              = log-sum-exp [ ln L(d | Om, s8, θ_extra^i) ] - ln N

    where θ_extra^i ~ Uniform(lo_tail, hi_tail) [the prior over dims 2-5].
    All N draws are generated up front from ``rng_seed``, so the set of
    nuisance-parameter samples is reproducible.

    Returns (log_like_marginal, OM, S8) with log_like_marginal shaped (ngrid,)
    and OM, S8 the "ij"-indexed meshgrids of the two axes.

    Raises
    ------
    ValueError
        If n_marg_samples < 1.
    """
    if n_marg_samples < 1:
        # With zero draws the accumulator stays -inf and subtracting
        # log(0) = -inf would yield NaN everywhere.
        raise ValueError(f"n_marg_samples must be >= 1, got {n_marg_samples}")

    rng = np.random.default_rng(rng_seed)
    theta_extras = rng.uniform(
        lo_tail, hi_tail, size=(n_marg_samples, 4)
    ).astype(np.float32)

    ngrid = len(om_ax) * len(s8_ax)
    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")

    # Start from log(0) so the first logaddexp simply adopts the first draw.
    log_like_accumulator = np.full(ngrid, -np.inf, dtype=np.float64)

    for mc_i, theta_extra in enumerate(theta_extras):
        print(f" MC marginalisation draw {mc_i+1}/{n_marg_samples} ...", end="\r", flush=True)
        full, _, _ = build_full_grid(om_ax, s8_ax, tail=theta_extra, lab_dim=6)
        lnL_i = compute_log_likelihood(
            obs, full, lab_mean, lab_std, normalize, model,
            H, W, device, ddim_steps, batch_sz, n_ddpm_samples, sigma_pk,
        )
        log_like_accumulator = np.logaddexp(log_like_accumulator, lnL_i)

    log_like_marginal = log_like_accumulator - np.log(n_marg_samples)
    print(flush=True)  # terminate the carriage-return progress line
    return log_like_marginal, OM, S8
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 316 |
+
# 6. POSTERIOR COMPUTATION & DIAGNOSTICS
|
| 317 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 318 |
+
|
| 319 |
+
def log_like_to_posterior(
    log_like: np.ndarray,
    grid: int,
) -> Tuple[np.ndarray, float]:
    """Turn gridded log-likelihoods into normalised posterior weights.

    A flat prior over the grid is assumed, so the posterior weight of each
    grid point is simply its likelihood divided by the sum over all points.

    Args:
        log_like: Log-likelihood values with ``grid * grid`` entries in total
            (any shape that can be reshaped to ``(grid, grid)``).
        grid: Number of grid points along each of the two axes.

    Returns:
        ``(weights, n_eff)`` where ``weights`` has shape ``(grid, grid)`` and
        sums to 1, and ``n_eff = 1 / sum(weights ** 2)`` is the effective
        number of grid points carrying posterior mass.

    Raises:
        ValueError: If no entry of ``log_like`` is finite at its maximum
            (all ``-inf``/NaN, or a ``+inf`` entry), because the weights
            would be undefined.
    """
    log_like = np.asarray(log_like)
    # A NaN anywhere would make max() NaN and poison every weight; treat such
    # grid points as having zero likelihood instead.
    log_like = np.where(np.isnan(log_like), -np.inf, log_like)

    peak = log_like.max()
    if not np.isfinite(peak):
        raise ValueError(
            "log_like has no finite maximum; cannot normalise posterior weights."
        )

    # Subtract the peak before exponentiating for numerical stability.
    weights = np.exp(log_like - peak).reshape(grid, grid)
    weights /= weights.sum()
    n_eff = 1.0 / float(np.sum(weights ** 2))
    return weights, n_eff
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def credible_contour_levels(
    weights: np.ndarray, credible_levels=(0.68, 0.95)
) -> List[float]:
    """Return density thresholds enclosing the requested fractions of mass.

    Cells are ranked from highest to lowest weight and accumulated; for each
    requested level the returned threshold is the weight of the first cell at
    which the running sum reaches ``level * total``. Thresholds come back in
    the same order as ``credible_levels``, so the defaults give the 68 %
    threshold first (the larger value) and the 95 % threshold second.
    """
    cells = weights.ravel()
    mass_total = float(cells.sum())
    descending = np.sort(cells)[::-1]
    running_mass = np.cumsum(descending)
    last_cell = len(descending) - 1

    def _threshold(level):
        # First rank whose accumulated mass reaches the target, clamped to
        # the final cell in case rounding leaves the target just out of reach.
        rank = int(np.searchsorted(running_mass, level * mass_total))
        return float(descending[min(rank, last_cell)])

    return [_threshold(level) for level in credible_levels]
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def posterior_summary(weights: np.ndarray, OM: np.ndarray, S8: np.ndarray) -> Dict:
    """Summarise a gridded posterior by weighted means and standard deviations.

    Statistics are reported for Omega_m (``OM``), sigma_8 (``S8``) and the
    derived quantity ``S8 * (OM / 0.3) ** 0.5``. ``weights`` is used as given
    (no renormalisation), so it is expected to sum to 1.

    Returns a dict with keys ``om_mean``, ``om_std``, ``s8_mean``, ``s8_std``,
    ``S8_mean`` and ``S8_std``.
    """

    def _mean_and_variance(field):
        # Weighted first moment and central second moment over the grid.
        mean = float((weights * field).sum())
        variance = float((weights * (field - mean) ** 2).sum())
        return mean, variance

    derived = S8 * (OM / 0.3) ** 0.5
    om_mu, om_var = _mean_and_variance(OM)
    s8_mu, s8_var = _mean_and_variance(S8)
    derived_mu, derived_var = _mean_and_variance(derived)

    return {
        "om_mean": om_mu,
        "om_std": np.sqrt(om_var),
        "s8_mean": s8_mu,
        "s8_std": np.sqrt(s8_var),
        "S8_mean": derived_mu,
        "S8_std": np.sqrt(derived_var),
    }
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 365 |
+
# 7. POSTERIOR PREDICTIVE CHECK
|
| 366 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 367 |
+
|
| 368 |
+
def posterior_predictive_check(
    obs: np.ndarray,
    weights: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
    model: torch.nn.Module,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    H: int,
    W: int,
    device: torch.device,
    ddim_steps: int,
    n_draws: int = 30,
    rng_seed: int = 7,
    lab_dim: int = 2,
    lo_tail: Optional[np.ndarray] = None,
    hi_tail: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Draw parameters from the posterior and generate maps.

    Grid points are sampled with replacement according to ``weights``; for
    each draw one map is generated with ``em.sample_batch`` and its power
    spectrum is recorded. When ``lab_dim != 2`` the remaining label
    dimensions are drawn uniformly between ``lo_tail`` and ``hi_tail``.

    Returns (k_valid, log_pk_obs_valid, ppc_lo68, ppc_hi68, ppc_lo95, ppc_hi95),
    all aligned to dk > 0 bins only. The 68 % band is the 16th-84th
    percentile and the 95 % band the 2.5th-97.5th percentile of the
    generated log spectra.

    Raises:
        ValueError: if ``lab_dim != 2`` and ``lo_tail`` or ``hi_tail`` is
            missing.
    """
    # Validate before any sampling or model work. This check was an `assert`
    # inside the loop, which is stripped under `python -O`.
    if lab_dim != 2 and (lo_tail is None or hi_tail is None):
        raise ValueError(
            "lo_tail and hi_tail must be provided when lab_dim != 2."
        )

    rng = np.random.default_rng(rng_seed)
    flat_w = weights.ravel()
    flat_om = OM.ravel()
    flat_s8 = S8.ravel()
    idx = rng.choice(len(flat_w), size=n_draws, replace=True, p=flat_w)

    # Power spectrum of the observed map; 25.0 is the box side length used
    # throughout this function (NOTE(review): presumably Mpc/h — confirm).
    npix = int(obs.shape[-1])
    dl = 25.0 / npix
    logf = em.images01_to_log_nhi(obs)
    dk, _ = ec.PowerSpectrum(logf, N=npix, dl=dl)
    valid = dk > 0
    # NOTE(review): log_pk_observed is assumed to return values already
    # restricted to the dk > 0 bins — confirm against its definition.
    log_pk_obs = log_pk_observed(obs, 25.0, dk)

    ppc_log_pks = []
    for i in idx:
        if lab_dim == 2:
            theta = np.array([[flat_om[i], flat_s8[i]]], dtype=np.float32)
        else:
            # Fresh uniform draw of the extra label dimensions per sample.
            te = rng.uniform(lo_tail, hi_tail).astype(np.float32)
            theta = np.array([[flat_om[i], flat_s8[i], *te]], dtype=np.float32)
        img = em.sample_batch(
            model, theta, lab_mean, lab_std, normalize,
            H, W, device, ddim_steps, False,
        )
        _, pkc = em.per_map_power_spectra_log(img, 25.0)
        # Small offset keeps the log finite if a spectrum bin is exactly zero.
        ppc_log_pks.append(np.log(pkc[0, valid] + 1e-30))

    ppc_arr = np.array(ppc_log_pks)  # (n_draws, n_valid)
    return (
        dk[valid],
        log_pk_obs,
        np.percentile(ppc_arr, 16, axis=0),
        np.percentile(ppc_arr, 84, axis=0),
        np.percentile(ppc_arr, 2.5, axis=0),
        np.percentile(ppc_arr, 97.5, axis=0),
    )
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 432 |
+
# 8. VISUALISATION
|
| 433 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 434 |
+
|
| 435 |
+
# Matplotlib colormap names for the prior, likelihood and posterior panels.
CMAP_PRIOR, CMAP_LIKE, CMAP_POST = "Greys", "YlOrRd", "Blues"
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def _uniform_prior(OM: np.ndarray, S8: np.ndarray) -> np.ndarray:
    """Return a flat prior with the shape of ``OM``, normalised to sum to 1.

    ``S8`` is accepted for signature symmetry but is not used.
    """
    ones = np.ones_like(OM)
    return np.divide(ones, ones.sum())
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def plot_prior_likelihood_posterior_panel(
    fig,
    gs_row,
    weights: np.ndarray,
    log_like: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
    true_om: float,
    true_s8: float,
    anchor_ix: int,
    summary: Dict,
    n_eff: float,
    title_suffix: str = "",
) -> None:
    """
    Plot three side-by-side panels for one anchor:
    [0] Uniform prior [1] Normalised likelihood [2] Posterior
    Each panel shows 68% / 95% credible contours where applicable.

    Args:
        fig: Figure the three axes are added to.
        gs_row: Indexable with 0, 1, 2 to obtain one subplot spec per panel.
        weights: Posterior weights, shape ``(grid, grid)``.
        log_like: Log-likelihood with ``grid * grid`` entries.
        OM, S8: Coordinate meshes matching ``weights``.
        true_om, true_s8: True parameters, marked with a red cross.
        anchor_ix: Anchor index, shown in the panel titles.
        summary: Dict as produced by ``posterior_summary``.
        n_eff: Effective sample size shown in the posterior panel.
        title_suffix: Text appended to every panel title.
    """
    grid = weights.shape[0]
    prior = _uniform_prior(OM, S8)

    # Peak-subtract before exponentiating, then normalise for display.
    like = np.exp(log_like - log_like.max()).reshape(grid, grid)
    like /= like.sum()

    panels = [
        (prior, CMAP_PRIOR, r"Uniform Prior $\pi(\Omega_m, \sigma_8)$"),
        (like, CMAP_LIKE, r"Normalised Likelihood $\mathcal{L}(\mathbf{d}|\theta)$"),
        (weights, CMAP_POST, r"Posterior $p(\theta|\mathbf{d})$"),
    ]

    # (colour, linewidth, linestyle) for the 68 % and 95 % contours, in the
    # order credible_contour_levels returns its thresholds.
    contour_styles = [("white", 1.0, "solid"), ("cyan", 0.6, "dashed")]

    for col, (Wmap, cmap, label) in enumerate(panels):
        ax = fig.add_subplot(gs_row[col])
        cf = ax.contourf(OM, S8, Wmap, levels=14, cmap=cmap)
        plt.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)

        if col > 0:
            lvls = credible_contour_levels(Wmap)
            # The thresholds arrive highest-first (68 % then 95 %), but
            # matplotlib rejects levels that are not strictly increasing.
            # Sort ascending, carry each level's style along, and collapse
            # coinciding thresholds (keeping the 68 % style).
            style_by_level = {}
            for lvl, style in zip(lvls, contour_styles):
                style_by_level.setdefault(lvl, style)
            ordered = sorted(style_by_level.items())
            try:
                ax.contour(OM, S8, Wmap,
                           levels=[lvl for lvl, _ in ordered],
                           colors=[style[0] for _, style in ordered],
                           linewidths=[style[1] for _, style in ordered],
                           linestyles=[style[2] for _, style in ordered])
            except Exception as exc:
                # The overlay is best-effort: keep the figure, but do not
                # hide the failure.
                warnings.warn(
                    f"Could not draw credible contours for ix={anchor_ix}: {exc}",
                    stacklevel=2,
                )

        ax.scatter(true_om, true_s8, s=70, c="red",
                   marker="x", zorder=8, linewidths=2.0, label="True")
        if col == 2:
            ax.scatter(summary["om_mean"], summary["s8_mean"],
                       s=80, c="black", marker="+",
                       zorder=8, linewidths=2.0, label="Post. mean")
        ax.legend(fontsize=7, loc="upper right")

        if col == 2:
            txt = (
                f"$\\Omega_m$: {summary['om_mean']:.3f} ± {summary['om_std']:.3f}\n"
                f"$\\sigma_8$: {summary['s8_mean']:.3f} ± {summary['s8_std']:.3f}\n"
                f"$S_8$: {summary['S8_mean']:.3f} ± {summary['S8_std']:.3f}\n"
                f"$n_\\mathrm{{eff}}$: {n_eff:.0f}"
            )
            ax.text(0.02, 0.02, txt, transform=ax.transAxes,
                    fontsize=6.5, va="bottom", color="#111",
                    bbox=dict(boxstyle="round,pad=0.3", fc="white", alpha=0.75))

        ax.set_xlabel(r"$\Omega_m$", fontsize=9)
        ax.set_ylabel(r"$\sigma_8$", fontsize=9)
        ax.set_title(
            f"ix={anchor_ix} | {label}{title_suffix}",
            fontsize=8, pad=4
        )
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def plot_ppc_panel(
    ax,
    k_valid: np.ndarray,
    log_pk_obs: np.ndarray,
    ppc_lo68: np.ndarray,
    ppc_hi68: np.ndarray,
    ppc_lo95: np.ndarray,
    ppc_hi95: np.ndarray,
    anchor_ix: int,
) -> None:
    """Draw the posterior-predictive log P(k) bands against the observation.

    The wider 95 % band is filled first and the 68 % band on top of it, then
    the observed spectrum is drawn as a black line on a log-scaled k axis.
    """
    # (lower, upper, opacity, legend label), drawn in this order.
    envelopes = (
        (ppc_lo95, ppc_hi95, 0.18, "95% PPC"),
        (ppc_lo68, ppc_hi68, 0.40, "68% PPC"),
    )
    for lower, upper, opacity, tag in envelopes:
        ax.fill_between(k_valid, lower, upper,
                        alpha=opacity, color="steelblue", label=tag)

    ax.plot(k_valid, log_pk_obs, "k-", lw=1.8, label="Observed")
    ax.set_xscale("log")

    small = 8
    ax.set_xlabel(r"$k$ [h/Mpc]", fontsize=small)
    ax.set_ylabel(r"$\log P_\mathrm{HI}(k)$", fontsize=small)
    ax.set_title(f"Posterior Predictive Check (ix={anchor_ix})", fontsize=small)
    ax.legend(fontsize=7)
    ax.grid(alpha=0.25)
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def plot_s8_marginal(
    ax, weights: np.ndarray, OM: np.ndarray, S8: np.ndarray,
    true_om: float, true_s8: float, anchor_ix: int,
) -> None:
    """1D marginal of the derived S8 = sigma_8*(Om/0.3)^0.5 parameter.

    The posterior weights are binned into 39 equal-width bins spanning the
    range of S8 on the grid, normalised to unit area and drawn as a bar
    chart; the true S8 is marked with a dashed vertical line.

    Args:
        ax: Axes to draw on.
        weights: Posterior weights on the (Omega_m, sigma_8) grid.
        OM, S8: Coordinate meshes matching ``weights``.
        true_om, true_s8: True parameters used for the reference line.
        anchor_ix: Anchor index shown in the title.
    """
    S8_map = S8 * (OM / 0.3) ** 0.5
    true_S8 = true_s8 * (true_om / 0.3) ** 0.5
    S8_flat = S8_map.ravel()
    w_flat = weights.ravel()
    s8_bins = np.linspace(S8_flat.min(), S8_flat.max(), 40)

    # np.histogram closes the last bin on the right, so the grid point at the
    # maximum S8 (which sits exactly on the final edge) keeps its weight.
    hist, _ = np.histogram(S8_flat, bins=s8_bins, weights=w_flat)
    bin_width = s8_bins[1] - s8_bins[0]
    area = hist.sum() * bin_width
    if area > 0:
        # Convert binned mass to a probability density; skipped when there is
        # no mass or the S8 range is degenerate, to avoid dividing by zero.
        hist = hist / area

    centers = 0.5 * (s8_bins[:-1] + s8_bins[1:])
    ax.bar(centers, hist, width=bin_width,
           color="steelblue", alpha=0.7, label="Posterior")
    ax.axvline(true_S8, color="red", lw=1.5, ls="--", label=f"True $S_8$={true_S8:.3f}")
    ax.set_xlabel(r"$S_8 = \sigma_8(\Omega_m/0.3)^{0.5}$", fontsize=8)
    ax.set_ylabel("Prob. density", fontsize=8)
    ax.set_title(f"$S_8$ marginal (ix={anchor_ix})", fontsize=8)
    ax.legend(fontsize=7)
    ax.grid(alpha=0.25)
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 571 |
+
# 9. PER-MODEL RUNNER
|
| 572 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 573 |
+
|
| 574 |
+
def run_model(
    out_dir: Path,
    *,
    model_name: str,
    model: torch.nn.Module,
    cfg: Dict,
    images: np.ndarray,
    labels: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    anchor_ix: np.ndarray,
    grid: int,
    ddim_steps: int,
    batch_sz: int,
    n_ddpm_samples: int,
    n_marg_samples: int,
    sigma_pk: float,
    lo_tail: Optional[np.ndarray] = None,
    hi_tail: Optional[np.ndarray] = None,
    do_ppc: bool = True,
) -> None:
    """Compute and plot posteriors for every anchor of one model.

    For each anchor the log-likelihood is evaluated on a ``grid x grid``
    mesh (Monte-Carlo marginalised over the extra label dimensions when
    ``model_name == "DDPM-6"``), converted to posterior weights, summarised
    on stdout and plotted. Two PNG files are written to ``out_dir``:
    ``<name>_prior_likelihood_posterior.png`` and ``<name>_ppc_s8.png``,
    where ``<name>`` is ``model_name`` with hyphens removed.

    Args:
        out_dir: Directory the figures are saved to.
        model_name: ``"DDPM-6"`` selects the 6-label path; any other value
            is treated as a 2-label model.
        model: Network used for sampling; its device is taken from its first
            parameter.
        cfg: Training config; only ``normalize_labels`` is read here.
        images, labels: Dataset arrays indexed by ``anchor_ix``.
        lab_mean, lab_std: Label statistics forwarded to the sampler.
        anchor_ix: Dataset indices of the anchors to analyse.
        grid: Grid points per axis.
        ddim_steps, batch_sz, n_ddpm_samples, n_marg_samples, sigma_pk:
            Forwarded to the likelihood routines.
        lo_tail, hi_tail: Bounds of the extra label dimensions; required
            for ``"DDPM-6"``.
        do_ppc: Whether to run and plot the posterior predictive check.

    Raises:
        ValueError: If ``model_name == "DDPM-6"`` and ``lo_tail`` or
            ``hi_tail`` is missing.
    """
    normalize = bool(cfg.get("normalize_labels", True))
    H, W = int(images.shape[-2]), int(images.shape[-1])
    lab_dim = 6 if model_name == "DDPM-6" else 2
    is_6param = lab_dim == 6
    # Fail fast with a clear message instead of deep inside the sampler.
    if is_6param and (lo_tail is None or hi_tail is None):
        raise ValueError("lo_tail and hi_tail are required for the DDPM-6 model.")

    anchors = anchor_ix.ravel()
    n_anchors = len(anchors)
    device = next(model.parameters()).device

    fig1 = plt.figure(figsize=(17, 4.8 * n_anchors))
    fig2 = None
    try:
        outer_gs = gridspec.GridSpec(n_anchors, 1, figure=fig1, hspace=0.55)

        fig2, axes2 = plt.subplots(
            n_anchors, 2, figsize=(12, 4.5 * n_anchors), squeeze=False
        )

        om_ax, s8_ax = build_cosmo_axes(labels, grid)

        for k, ix in enumerate(anchors):
            obs = images[ix]
            lab_t = labels[ix].astype(np.float32)
            true_om, true_s8 = float(lab_t[0]), float(lab_t[1])

            print(f"\n[{model_name}] Anchor ix={ix} "
                  f"(Ωm={true_om:.3f}, σ8={true_s8:.3f})")

            if is_6param:
                print(" MC marginalisation over dims 2-5 ...")
                log_like, OM, S8 = marginal_log_likelihood_ddpm6(
                    obs, om_ax, s8_ax, lo_tail, hi_tail,
                    lab_mean, lab_std, normalize, model,
                    H, W, device=device,
                    ddim_steps=ddim_steps, batch_sz=batch_sz,
                    n_ddpm_samples=n_ddpm_samples,
                    n_marg_samples=n_marg_samples,
                    sigma_pk=sigma_pk,
                )
            else:
                print(f" Computing log-likelihood ({n_ddpm_samples} DDPM draws per point) ...")
                full, OM, S8 = build_full_grid(om_ax, s8_ax, tail=None, lab_dim=2)
                log_like = compute_log_likelihood(
                    obs, full, lab_mean, lab_std, normalize, model,
                    H, W, device, ddim_steps, batch_sz, n_ddpm_samples, sigma_pk,
                )

            weights, n_eff = log_like_to_posterior(log_like, grid)
            summary = posterior_summary(weights, OM, S8)

            print(f" n_eff = {n_eff:.1f} / {grid**2} grid points")
            print(f" Ωm posterior: {summary['om_mean']:.3f} ± {summary['om_std']:.3f} "
                  f"(true: {true_om:.3f})")
            print(f" σ8 posterior: {summary['s8_mean']:.3f} ± {summary['s8_std']:.3f} "
                  f"(true: {true_s8:.3f})")
            print(f" S8 posterior: {summary['S8_mean']:.3f} ± {summary['S8_std']:.3f}")

            # Very few effective grid points means the posterior has
            # collapsed onto a handful of cells.
            if n_eff < 20:
                warnings.warn(
                    f"n_eff={n_eff:.1f} is very low for ix={ix}. "
                    "Increase n_ddpm_samples or grid resolution.",
                    stacklevel=2,
                )

            inner_gs = gridspec.GridSpecFromSubplotSpec(
                1, 3, subplot_spec=outer_gs[k], wspace=0.38
            )
            marg_note = " (marginalised)" if is_6param else ""
            plot_prior_likelihood_posterior_panel(
                fig1, inner_gs, weights, log_like.reshape(grid, grid),
                OM, S8, true_om, true_s8, ix, summary, n_eff,
                title_suffix=marg_note,
            )

            if do_ppc:
                dk_v, log_pk_obs, plo68, phi68, plo95, phi95 = posterior_predictive_check(
                    obs, weights, OM, S8, model,
                    lab_mean, lab_std, normalize,
                    H, W, device, ddim_steps,
                    n_draws=20, lab_dim=lab_dim,
                    lo_tail=lo_tail, hi_tail=hi_tail,
                )
                plot_ppc_panel(
                    axes2[k, 0], dk_v, log_pk_obs, plo68, phi68, plo95, phi95, ix,
                )
            else:
                axes2[k, 0].axis("off")
                axes2[k, 0].text(
                    0.5, 0.5, "PPC disabled",
                    ha="center", va="center", transform=axes2[k, 0].transAxes,
                )

            plot_s8_marginal(axes2[k, 1], weights, OM, S8, true_om, true_s8, ix)

        fig1.suptitle(
            f"{model_name} — Prior · Likelihood · Posterior on "
            r"$(\Omega_m,\,\sigma_8)$ — six CAMELS anchors"
            f"\n[sigma_pk={sigma_pk:.3f}, n_ddpm={n_ddpm_samples}, grid={grid}×{grid}"
            + (f", n_marg={n_marg_samples}]" if is_6param else "]"),
            fontsize=12, y=1.001,
        )
        p1 = out_dir / f"{model_name.replace('-', '')}_prior_likelihood_posterior.png"
        fig1.savefig(p1, dpi=160, bbox_inches="tight")
        plt.close(fig1)
        print(f"\nSaved → {p1}")

        fig2.suptitle(
            f"{model_name} — Posterior Predictive Checks & $S_8$ Marginals",
            fontsize=12, y=1.002,
        )
        fig2.tight_layout()
        p2 = out_dir / f"{model_name.replace('-', '')}_ppc_s8.png"
        fig2.savefig(p2, dpi=160, bbox_inches="tight")
        plt.close(fig2)
        print(f"Saved → {p2}")
    finally:
        # Safety net: if anything above raised, the figures would otherwise
        # stay registered with pyplot. Closing an already closed figure is
        # harmless.
        plt.close(fig1)
        if fig2 is not None:
            plt.close(fig2)
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
def load_model(args_json: Path, ckpt: Path, device: torch.device):
    """Build a model from its training config, load weights, set eval mode.

    Returns ``(model, cfg)`` where ``cfg`` is whatever
    ``ec.load_training_config`` produced for ``args_json``.
    """
    config = ec.load_training_config(str(args_json))
    network = ec.build_model(config, device)
    ec.load_checkpoint(network, str(ckpt), device)
    network.eval()
    return network, config
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of this script."""
    parser = argparse.ArgumentParser(
        description=(
            "Corrected Bayesian posteriors on (Ωm, σ8): averaged DDPM likelihood, "
            "calibrated sigma_pk, DDPM-6 MC marginalisation, prior/likelihood/posterior plots."
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Filesystem locations (all parsed as Path), in help-output order.
    path_defaults = (
        ("--output-dir", MODELS_ROOT / "ddpm_posterior_corrected_fullviz_out"),
        ("--data-2param", Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2")),
        ("--data-6param", Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6")),
        ("--bundle-2param", MODELS_ROOT / "notebook_model_weights" / "2param_epoch200"),
        ("--bundle-6param", MODELS_ROOT / "notebook_model_weights" / "6param_best"),
    )
    for flag, default in path_defaults:
        parser.add_argument(flag, type=Path, default=default)

    parser.add_argument("--split", type=str, default="test",
                        choices=["train", "val", "test"])

    # Integer knobs controlling grid size and sampling effort.
    int_defaults = (
        ("--grid", 30),
        ("--ddim-steps", 50),
        ("--batch-size", 8),
        ("--n-ddpm-samples", 8),
        ("--n-marg-samples", 20),
    )
    for flag, default in int_defaults:
        parser.add_argument(flag, type=int, default=default)

    # None means sigma_pk is calibrated by main() instead of user-supplied.
    parser.add_argument("--sigma-pk", type=float, default=None)
    parser.add_argument("--n-calib-pairs", type=int, default=60)

    for switch in ("--no-ppc", "--ddpm2-only", "--ddpm6-only"):
        parser.add_argument(switch, action="store_true")

    return parser.parse_args()
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def main() -> None:
    """Script entry point: load data, then run the requested model(s).

    Both datasets and both sets of label statistics are always loaded, since
    the anchor indices depend on the shorter of the two label arrays. The
    DDPM-2 and DDPM-6 runs then follow the same steps: load the checkpoint,
    take or calibrate ``sigma_pk``, call ``run_model`` and free the model.
    """
    args = parse_args()

    if args.ddpm2_only and args.ddpm6_only:
        raise SystemExit("Use at most one of --ddpm2-only / --ddpm6-only.")

    out_dir = Path(args.output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"Output directory: {out_dir}")

    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)

    imgs2, lab2 = ec.load_split(data2, args.split)
    imgs6, lab6 = ec.load_split(data6, args.split)

    # Six anchors spread evenly over the indices valid in both datasets.
    n = min(len(lab2), len(lab6))
    anchor_ix = np.linspace(0, n - 1, num=6, dtype=int)
    print(f"Anchor indices: {anchor_ix.tolist()}")

    lo_tail, hi_tail = tail_lhs_bounds(data6)
    print(f"LHS tails (dims 2-5): min={lo_tail} max={hi_tail}")

    mean2, std2 = ec.load_label_stats(data2)
    mean6, std6 = ec.load_label_stats(data6)

    # One entry per requested model:
    # (name, banner, bundle dir, checkpoint file, data dir, images, labels,
    #  label mean, label std, extra run_model keyword arguments)
    jobs = []
    if not args.ddpm6_only:
        jobs.append((
            "DDPM-2",
            ">>> DDPM-2 (corrected posteriors, six anchors)",
            args.bundle_2param, "checkpoint_epoch_200.pt",
            data2, imgs2, lab2, mean2, std2,
            {},
        ))
    if not args.ddpm2_only:
        jobs.append((
            "DDPM-6",
            ">>> DDPM-6 (corrected posteriors + MC marginalisation, six anchors)",
            args.bundle_6param, "best_model.pt",
            data6, imgs6, lab6, mean6, std6,
            {"lo_tail": lo_tail, "hi_tail": hi_tail},
        ))

    for (name, banner, bundle, ckpt_name, data_dir,
         imgs, labs, lab_mean, lab_std, extra) in jobs:
        print("\n" + "═" * 60)
        print(banner)
        print("═" * 60)
        ckpt_path = bundle / ckpt_name
        args_json = bundle / "args.json"
        model, cfg = load_model(args_json, ckpt_path, device)

        if args.sigma_pk is not None:
            sigma = args.sigma_pk
            print(f" Using user-supplied sigma_pk = {sigma:.4f}")
        else:
            sigma = calibrate_sigma_pk(
                data_dir, model, lab_mean, lab_std,
                bool(cfg.get("normalize_labels", True)),
                int(imgs.shape[-2]), int(imgs.shape[-1]),
                device, args.ddim_steps, args.n_calib_pairs,
            )

        run_model(
            out_dir,
            model_name=name,
            model=model,
            cfg=cfg,
            images=imgs,
            labels=labs,
            lab_mean=lab_mean,
            lab_std=lab_std,
            anchor_ix=anchor_ix,
            grid=args.grid,
            ddim_steps=args.ddim_steps,
            batch_sz=args.batch_size,
            n_ddpm_samples=args.n_ddpm_samples,
            n_marg_samples=args.n_marg_samples,
            sigma_pk=sigma,
            do_ppc=not args.no_ppc,
            **extra,
        )

        # Drop the model and release cached GPU memory before the next run.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"\nAll outputs saved to: {out_dir}")


if __name__ == "__main__":
    main()
|
cross_model/poster.py
ADDED
|
@@ -0,0 +1,1112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
ddpm_posterior_six_anchors_corrected.py
|
| 4 |
+
========================================
|
| 5 |
+
Corrected surrogate P(k) likelihood posteriors on (Omega_m, sigma_8)
|
| 6 |
+
for six CAMELS test anchors.
|
| 7 |
+
|
| 8 |
+
CORRECTIONS OVER THE ORIGINAL SCRIPT
|
| 9 |
+
--------------------------------------
|
| 10 |
+
(1) STOCHASTIC EMULATOR NOISE [was: 1 DDPM sample/grid point → fragmented posteriors]
|
| 11 |
+
Now: average log P(k) over `--n-pk-samples` (default 8) DDPM draws per grid
|
| 12 |
+
point, suppressing emulator scatter (standard deviation) by ~1/sqrt(N_s).
|
| 13 |
+
|
| 14 |
+
(2) CALIBRATED LIKELIHOOD NOISE SCALE [was: hard-coded sigma=0.25]
|
| 15 |
+
Now: sigma_pk is estimated from the scatter of log P(k) across repeated DDPM
|
| 16 |
+
draws at a sample of validation labels — making the noise scale physically
|
| 17 |
+
meaningful and data-driven.
|
| 18 |
+
|
| 19 |
+
(3) PROPER MARGINALIZATION OVER ASTROPHYSICAL PARAMETERS [was: fix to LHS min/max]
|
| 20 |
+
For DDPM-6, dims 2–5 are now integrated out via Monte Carlo:
|
| 21 |
+
p(Om, s8 | d) ∝ L_marg(d | Om, s8) ≈ (1/N) Σ_i L(d | Om, s8, θ_extra^i), θ_extra^i ~ Uniform(LHS) [flat prior on (Om, s8)]
|
| 22 |
+
replacing the incorrect conditional likelihoods p(d | Om, s8, θ_extra = fixed).
|
| 23 |
+
|
| 24 |
+
(4) GRID RESOLUTION [was: 14×14 = 196 points]
|
| 25 |
+
Now: 30×30 = 900 points (configurable via --grid).
|
| 26 |
+
|
| 27 |
+
(5) EFFECTIVE SAMPLE SIZE [was: none]
|
| 28 |
+
n_eff = 1 / Σ w_i^2 is printed for every panel. Values ≪ 30 flag collapse.
|
| 29 |
+
|
| 30 |
+
(6) CREDIBLE CONTOURS [was: raw contourf only]
|
| 31 |
+
Now: 68 % and 95 % posterior mass contours drawn explicitly on each panel.
|
| 32 |
+
|
| 33 |
+
(7) S8 DERIVED PARAMETER [was: absent]
|
| 34 |
+
S8 = sigma_8 * (Omega_m / 0.3)^0.5 reported for the posterior mean.
|
| 35 |
+
|
| 36 |
+
(8) POSTERIOR PREDICTIVE CHECK [was: absent]
|
| 37 |
+
A separate figure shows the 68/95 % posterior-predictive P(k) envelope
|
| 38 |
+
versus the observed P(k) for each anchor — a standard emulator
|
| 39 |
+
validation step.
|
| 40 |
+
|
| 41 |
+
USAGE
|
| 42 |
+
-----
|
| 43 |
+
# Both models, all corrections:
|
| 44 |
+
python ddpm_posterior_six_anchors_corrected.py
|
| 45 |
+
|
| 46 |
+
# DDPM-2 only, fast debug run:
|
| 47 |
+
python ddpm_posterior_six_anchors_corrected.py --ddpm2-only --grid 14 --n-pk-samples 4 --n-marg-samples 1
|
| 48 |
+
|
| 49 |
+
# DDPM-6 only, full quality:
|
| 50 |
+
python ddpm_posterior_six_anchors_corrected.py --ddpm6-only --grid 30 --n-pk-samples 12 --n-marg-samples 30
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
from __future__ import annotations
|
| 54 |
+
|
| 55 |
+
import argparse
|
| 56 |
+
import gc
|
| 57 |
+
import sys
|
| 58 |
+
from pathlib import Path
|
| 59 |
+
from typing import Dict, List, Optional, Tuple
|
| 60 |
+
|
| 61 |
+
import matplotlib
|
| 62 |
+
matplotlib.use("Agg")
|
| 63 |
+
import matplotlib.pyplot as plt
|
| 64 |
+
import matplotlib.ticker as mticker
|
| 65 |
+
import numpy as np
|
| 66 |
+
import torch
|
| 67 |
+
|
| 68 |
+
# ── Path setup ────────────────────────────────────────────────────────────────
|
| 69 |
+
# Directory containing this script; default bundles and outputs hang off it.
MODELS_ROOT = Path(__file__).resolve().parent
# Source tree that is put on sys.path so the project modules imported below
# can be found (presumably evaluate_conditional / eval_model live there).
CODE_6 = MODELS_ROOT.joinpath("6param_ddpm_hi_lh6")
if not any(entry == str(CODE_6.resolve()) for entry in sys.path):
    sys.path.insert(0, str(CODE_6))
|
| 73 |
+
|
| 74 |
+
import evaluate_conditional as ec # noqa: E402
|
| 75 |
+
import eval_model as em # noqa: E402
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 79 |
+
# § 1 GRID CONSTRUCTION
|
| 80 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 81 |
+
|
| 82 |
+
def build_cosmo_grid(
    grid: int,
    om_lo: float, om_hi: float,
    s8_lo: float, s8_hi: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Build a regular (grid × grid) mesh over (Omega_m, sigma_8).

    Parameters
    ----------
    grid         : number of points per axis
    om_lo, om_hi : inclusive Omega_m range
    s8_lo, s8_hi : inclusive sigma_8 range

    Returns
    -------
    om_ax : 1-D float32 array, shape (grid,)
    s8_ax : 1-D float32 array, shape (grid,)
    grid2 : 2-D float32 array, shape (grid^2, 2); column 0 is Omega_m and
            column 1 is sigma_8.  Rows are the C-order flattening of the
            "ij" mesh, so **sigma_8 varies fastest**: row ``i * grid + j``
            holds ``(om_ax[i], s8_ax[j])``.  A per-row quantity reshaped to
            (grid, grid) is therefore indexed ``[i_om, j_s8]``, matching
            ``np.meshgrid(om_ax, s8_ax, indexing="ij")``.
    """
    om_ax = np.linspace(om_lo, om_hi, grid, dtype=np.float32)
    s8_ax = np.linspace(s8_lo, s8_hi, grid, dtype=np.float32)
    om_mesh, s8_mesh = np.meshgrid(om_ax, s8_ax, indexing="ij")
    # ravel() is C-order: the last ("j" = sigma_8) index changes fastest.
    grid2 = np.stack([om_mesh.ravel(), s8_mesh.ravel()], axis=1).astype(np.float32)
    return om_ax, s8_ax, grid2
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def build_full_grid(
    labels_ref: np.ndarray,
    grid: int,
    tail: Optional[np.ndarray],
    lab_dim: int,
    pad_frac: float = 0.02,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Build the full label matrix for the posterior grid.

    Parameters
    ----------
    labels_ref : reference labels from which (Om, s8) range is inferred
                 (columns 0 and 1 are read)
    grid       : grid points per axis
    tail       : fixed values for dims 2–5, shape (4,) (None leaves any
                 columns beyond the first two at zero)
    lab_dim    : total label dimension (2 or 6)
    pad_frac   : fractional padding beyond data range

    Returns
    -------
    full  : (grid^2, lab_dim) float32
    om_ax : (grid,) float32
    s8_ax : (grid,) float32

    Raises
    ------
    ValueError : if lab_dim < 2, if `tail` is not shape (4,), or if `tail`
                 is given while lab_dim < 6 (there would be no columns 2–5
                 to write it into).
    """
    if lab_dim < 2:
        raise ValueError(f"lab_dim must be at least 2, got {lab_dim}")

    lo0, hi0 = float(labels_ref[:, 0].min()), float(labels_ref[:, 0].max())
    lo1, hi1 = float(labels_ref[:, 1].min()), float(labels_ref[:, 1].max())
    # The 1e-12 keeps the padding non-zero when a column is constant.
    p0 = pad_frac * (hi0 - lo0 + 1e-12)
    p1 = pad_frac * (hi1 - lo1 + 1e-12)

    om_ax, s8_ax, grid2 = build_cosmo_grid(grid, lo0 - p0, hi0 + p0,
                                           lo1 - p1, hi1 + p1)

    full = np.zeros((grid2.shape[0], lab_dim), dtype=np.float32)
    full[:, :2] = grid2[:, :2]

    if tail is not None:
        # Explicit checks instead of `assert`, which is stripped under -O.
        tail = np.asarray(tail, dtype=np.float32)
        if tail.shape != (4,):
            raise ValueError(f"tail must be shape (4,), got {tail.shape}")
        if lab_dim < 6:
            raise ValueError(
                f"tail requires lab_dim >= 6 (columns 2-5), got lab_dim={lab_dim}"
            )
        full[:, 2:6] = tail[np.newaxis, :]

    return full, om_ax, s8_ax
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 147 |
+
# § 2 LHS BOUNDS
|
| 148 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 149 |
+
|
| 150 |
+
def _train_label_path(data_dir: Path) -> Path:
    """
    Return the training-label file found directly under `data_dir`.

    ``train_labels_LH.npy`` is preferred; ``train_labels_LH_2.npy`` is the
    fallback.  Raises FileNotFoundError when neither exists as a file.
    """
    candidates = (
        data_dir / filename
        for filename in ("train_labels_LH.npy", "train_labels_LH_2.npy")
    )
    found = next((path for path in candidates if path.is_file()), None)
    if found is None:
        raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}")
    return found
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
    """
    Min/max of LHS training labels for dims 2–5.

    Loads the label file located by `_train_label_path(data_dir)` and
    returns ``(lo, hi)``, each a float32 array of shape (4,) holding the
    per-column minimum / maximum of columns 2..5.

    Raises
    ------
    FileNotFoundError : no label file under `data_dir`
    ValueError        : the loaded array is not 2-D or has fewer than 6 columns
    """
    labels = np.load(_train_label_path(data_dir))
    # Check ndim first: a 1-D array has no shape[1] and would otherwise
    # surface as an IndexError instead of this explicit error.
    if labels.ndim != 2 or labels.shape[1] < 6:
        raise ValueError(f"Expected ≥6 label columns, got shape {labels.shape}")
    tail = labels[:, 2:6]
    lo = tail.min(axis=0).astype(np.float32)
    hi = tail.max(axis=0).astype(np.float32)
    return lo, hi
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 169 |
+
# § 3 OBSERVED LOG P(k)
|
| 170 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 171 |
+
|
| 172 |
+
def log_pk_observed(
    obs_image: np.ndarray,
    box_size: float = 25.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the natural-log power spectrum of the *observed* HI map.

    The map is first passed through ``em.images01_to_log_nhi`` (described
    in this file as converting the [0,1] pixel scale to log10(N_HI)), then
    ``ec.PowerSpectrum`` is evaluated on the result.  Note that the
    returned spectrum uses ``np.log`` (natural log), not log10.

    Parameters
    ----------
    obs_image : 2-D map; its last axis length is used as the pixel count N
    box_size  : box side length; the pixel size passed on is box_size / N

    Returns
    -------
    dk     : k-mode array (n_bins,), *unmasked*
    log_pd : ln P(k) of the observed map, restricted to the valid modes,
             shape (valid.sum(),)
    valid  : boolean mask (n_bins,) selecting modes with dk > 0
    """
    # A leading axis is added so the helper receives a (1, H, W) stack.
    log_nhi = em.images01_to_log_nhi(obs_image[np.newaxis])
    npix = obs_image.shape[-1]
    dl = box_size / npix
    dk, pk = ec.PowerSpectrum(log_nhi[0], N=npix, dl=dl)
    valid = dk > 0
    # 1e-30 guards against log(0) on empty bins.
    log_pd = np.log(pk[valid] + 1e-30)
    return dk, log_pd, valid
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 197 |
+
# § 4 SIGMA_PK CALIBRATION (Correction #2)
|
| 198 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 199 |
+
|
| 200 |
+
def calibrate_sigma_pk(
    model: torch.nn.Module,
    images_val: np.ndarray,
    labels_val: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    device: torch.device,
    box_size: float = 25.0,
    ddim_steps: int = 50,
    n_pairs: int = 30,
    seed: int = 0,
) -> float:
    """
    Estimate the log-P(k) noise scale from the *aleatoric* variance of the
    DDPM emulator at fixed labels.

    For up to `n_pairs` validation labels (chosen without replacement) two
    DDPM samples are drawn at the same label; for each pair we compute
    std(log Pk_a - log Pk_b) / sqrt(2) across the valid k-bins, and the
    median over pairs is taken.

    This gives a data-driven sigma_pk that replaces the hard-coded 0.25.

    Parameters
    ----------
    images_val : only its last two axes are read, to get the map size (H, W)
    labels_val : (n, lab_dim) labels to draw calibration points from
    seed       : seeds the *label selection* only; it is not passed to the
                 sampler

    Returns
    -------
    The median per-draw sigma, floored at 0.01.

    Raises
    ------
    ValueError   : no calibration label is available (n_pairs < 1 or empty
                   `labels_val`)
    RuntimeError : the calibrated sigma is not finite
    """
    rng = np.random.default_rng(seed)
    n_val = min(n_pairs, len(labels_val))
    if n_val < 1:
        # Otherwise the median of an empty list is NaN, and max(NaN, 0.01)
        # returns NaN, silently poisoning every downstream likelihood.
        raise ValueError(
            "calibrate_sigma_pk needs at least one validation label "
            f"(n_pairs={n_pairs}, len(labels_val)={len(labels_val)})"
        )
    idx = rng.choice(len(labels_val), size=n_val, replace=False)
    labs = labels_val[idx].astype(np.float32)          # (n_val, lab_dim)

    H, W = int(images_val.shape[-2]), int(images_val.shape[-1])

    sigmas = []
    for lab in labs:
        # Same label twice -> two emulator draws at identical parameters.
        pair = np.repeat(lab[np.newaxis, :], 2, axis=0)  # (2, lab_dim)

        imgs = em.sample_batch(
            model, pair, lab_mean, lab_std, normalize,
            H, W, device, ddim_steps, False,
        )

        _, log_pk_a, _ = log_pk_observed(imgs[0], box_size)
        _, log_pk_b, _ = log_pk_observed(imgs[1], box_size)
        diff = log_pk_a - log_pk_b
        # sigma of a single draw = std(diff) / sqrt(2)
        sigmas.append(float(np.std(diff) / np.sqrt(2.0)))

    sigma_cal = float(np.median(sigmas))
    print(
        f" [calibrate_sigma_pk] n_pairs={n_val} "
        f"median σ_pk={sigma_cal:.4f} "
        f"(was hard-coded 0.25)"
    )
    if not np.isfinite(sigma_cal):
        raise RuntimeError(
            f"calibrated sigma_pk is not finite ({sigma_cal}); "
            "check the sampled maps / power spectra"
        )
    return max(sigma_cal, 0.01)   # safety floor
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 255 |
+
# § 5 AVERAGED LOG-LIKELIHOOD (Correction #1)
|
| 256 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 257 |
+
|
| 258 |
+
def averaged_log_likelihood(
    obs_image: np.ndarray,
    full: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module,
    device: torch.device,
    H: int,
    W: int,
    box_size: float,
    ddim_steps: int,
    batch_sz: int,
    n_pk_samples: int,
    sigma_pk: float,
) -> np.ndarray:
    """
    Compute a Gaussian-style log-likelihood for every grid point in `full`,
    averaging log P(k) over `n_pk_samples` DDPM draws to suppress emulator
    stochasticity.

    Parameters
    ----------
    obs_image    : observed map, passed to `log_pk_observed`
    full         : (ngrid, lab_dim) array of grid labels
    H, W         : map size requested from the sampler
    batch_sz     : number of grid points sampled per sampler call (≥1)
    n_pk_samples : number of DDPM draws to average (≥1; ≥8 recommended)
    sigma_pk     : calibrated log-P(k) noise scale

    Returns
    -------
    log_w : (ngrid,) float64 unnormalised log-posterior weights,
            ``-mean_k[(ln Pd - <ln Pg>)^2] / (2 sigma_pk^2)``

    Raises
    ------
    ValueError : n_pk_samples < 1 or batch_sz < 1

    Notes
    -----
    NOTE(review): the squared misfit is *averaged* over k-bins, not summed.
    Relative to an independent-bin Gaussian likelihood with per-bin scatter
    sigma_pk, the exponent is therefore divided by the number of valid
    bins — confirm this tempering is intended.
    """
    if n_pk_samples < 1:
        # 0 draws would give 0/0 = NaN weights instead of an error.
        raise ValueError(f"n_pk_samples must be >= 1, got {n_pk_samples}")
    if batch_sz < 1:
        raise ValueError(f"batch_sz must be >= 1, got {batch_sz}")

    _, log_pd, valid = log_pk_observed(obs_image, box_size)
    ngrid = full.shape[0]

    # Running sum of ln P(k) over draws, accumulated chunk by chunk so no
    # full (ngrid, n_valid) copy is rebuilt per draw.
    sum_log_pg = np.zeros((ngrid, int(valid.sum())), dtype=np.float64)

    for _ in range(n_pk_samples):
        for j0 in range(0, ngrid, batch_sz):
            chunk = full[j0: j0 + batch_sz]
            imgs = em.sample_batch(
                model, chunk, lab_mean, lab_std, normalize,
                H, W, device, ddim_steps, False,
            )
            _, pks = em.per_map_power_spectra_log(imgs, box_size)
            # pks is indexed as (chunk_sz, n_bins); keep the valid bins only.
            sum_log_pg[j0: j0 + len(chunk)] += np.log(pks[:, valid] + 1e-30)

    mean_log_pg = sum_log_pg / n_pk_samples             # (ngrid, n_valid)

    # Mean-squared log-spectrum misfit per grid point, scaled by 2 sigma^2.
    mse = np.mean((log_pd[np.newaxis, :] - mean_log_pg) ** 2, axis=1)
    log_w = -mse / (2.0 * sigma_pk ** 2)

    return log_w.astype(np.float64)
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 320 |
+
# § 6 POSTERIOR WEIGHT COMPUTATION
|
| 321 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 322 |
+
|
| 323 |
+
def posterior_weights_ddpm2(
    obs_image: np.ndarray,
    labels_ref: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module,
    device: torch.device,
    grid: int,
    batch_sz: int,
    ddim_steps: int,
    n_pk_samples: int,
    sigma_pk: float,
    box_size: float = 25.0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Evaluate the DDPM-2 surrogate posterior over (Omega_m, sigma_8).

    A two-column label grid spanning the range of `labels_ref` is scored
    with `averaged_log_likelihood`; the scores are exponentiated and
    normalised to unit sum.

    Returns
    -------
    (Wmap, OM, S8) : the weight map and the two "ij" coordinate meshes,
                     each shaped (grid, grid)
    """
    n_rows, n_cols = int(obs_image.shape[-2]), int(obs_image.shape[-1])
    label_grid, om_ax, s8_ax = build_full_grid(
        labels_ref, grid, tail=None, lab_dim=2,
    )

    log_like = averaged_log_likelihood(
        obs_image, label_grid, lab_mean, lab_std, normalize, model, device,
        n_rows, n_cols, box_size, ddim_steps, batch_sz, n_pk_samples, sigma_pk,
    )

    # Shift by the maximum before exponentiating (numerical stability).
    shifted = log_like - log_like.max()
    weights = np.exp(shifted).reshape(grid, grid)
    weights = weights / weights.sum()

    om_mesh, s8_mesh = np.meshgrid(om_ax, s8_ax, indexing="ij")
    return weights, om_mesh, s8_mesh
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def posterior_weights_ddpm6_marginalised(
    obs_image: np.ndarray,
    labels_ref: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module,
    device: torch.device,
    lo_tail: np.ndarray,
    hi_tail: np.ndarray,
    grid: int,
    batch_sz: int,
    ddim_steps: int,
    n_pk_samples: int,
    n_marg_samples: int,
    sigma_pk: float,
    box_size: float = 25.0,
    seed: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Compute the DDPM-6 *marginal* posterior on (Omega_m, sigma_8) by
    Monte Carlo integration over the astrophysical nuisance parameters:

        p(Om, s8 | d) ∝ ∫ L(d | Om, s8, θ_extra) π(θ_extra) dθ_extra
                      ≈ (1/N) Σ_i L(d | Om, s8, θ_extra^i)

    where θ_extra^i ~ Uniform(lo_tail, hi_tail) for dims 2-5.

    This replaces fixing dims 2-5 to their LHS extrema, which yields a
    *conditional* likelihood, not a marginal.

    Parameters
    ----------
    lo_tail, hi_tail : per-dimension bounds (shape (4,)) of the uniform
                       prior on dims 2-5
    n_marg_samples   : number of MC draws for astrophysical parameter
                       integration (≥1; ≥20 recommended; more = smoother
                       but slower)
    seed             : seeds the nuisance-parameter draws

    Returns
    -------
    (w, OM, S8) : normalised weight map and "ij" coordinate meshes, each
                  shaped (grid, grid)

    Raises
    ------
    ValueError : n_marg_samples < 1
    """
    if n_marg_samples < 1:
        # With zero draws the accumulator stays at -inf and the final
        # normalisation evaluates (-inf) - (-inf) = NaN.
        raise ValueError(f"n_marg_samples must be >= 1, got {n_marg_samples}")

    rng = np.random.default_rng(seed)
    H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1])

    # Draw astrophysical parameter samples from their uniform prior over LHS
    theta_extra_draws = rng.uniform(
        lo_tail, hi_tail,
        size=(n_marg_samples, 4),
    ).astype(np.float32)

    # One call provides both the 2-column cosmology grid and its axes.
    full_cosmo, om_ax, s8_ax = build_full_grid(
        labels_ref, grid, tail=None, lab_dim=2,
    )
    ngrid = full_cosmo.shape[0]

    # log-sum-exp accumulator over marginalisation samples
    log_w_accum = np.full(ngrid, -np.inf, dtype=np.float64)

    for m_idx, theta_extra in enumerate(theta_extra_draws):
        # Assemble 6D label grid with this draw of astrophysical params
        full_6d = np.zeros((ngrid, 6), dtype=np.float32)
        full_6d[:, :2] = full_cosmo[:, :2]
        full_6d[:, 2:6] = theta_extra[np.newaxis, :]

        log_w_m = averaged_log_likelihood(
            obs_image, full_6d, lab_mean, lab_std, normalize, model, device,
            H, W, box_size, ddim_steps, batch_sz, n_pk_samples, sigma_pk,
        )

        # Accumulate log Σ_i L_i without leaving log space.
        log_w_accum = np.logaddexp(log_w_accum, log_w_m)

        if (m_idx + 1) % 5 == 0 or (m_idx + 1) == n_marg_samples:
            print(f" marginalisation sample {m_idx+1}/{n_marg_samples} done")

    # Subtract log(N_marg) to convert sum → mean, then normalise
    log_w_accum -= np.log(n_marg_samples)
    log_w_accum -= log_w_accum.max()
    w = np.exp(log_w_accum).reshape(grid, grid)
    w /= w.sum()

    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
    return w, OM, S8
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 438 |
+
# § 7 POSTERIOR DIAGNOSTICS
|
| 439 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 440 |
+
|
| 441 |
+
def effective_sample_size(w: np.ndarray) -> float:
    """
    Kish effective sample size, n_eff = 1 / Σ p_i^2 with p = w / Σ w.

    `w` may have any shape and need not be normalised.  Small values
    (< 30) indicate posterior collapse.
    """
    probs = w.ravel() / w.sum()
    sum_of_squares = np.square(probs).sum()
    return float(1.0 / sum_of_squares)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def credible_levels(
    w: np.ndarray,
    levels: Tuple[float, ...] = (0.68, 0.95),
) -> List[float]:
    """
    Highest-density thresholds of a weight map.

    For each requested `level`, cells are accumulated from the largest
    weight downwards and the threshold c is the weight of the first cell
    at which the running mass reaches ``level × total``; the region
    {w ≥ c} therefore holds at least that fraction of the mass.

    Returns one threshold per entry of `levels`, in the same order (so
    thresholds are non-increasing when `levels` is increasing).
    """
    flat = w.ravel()
    descending = np.sort(flat)[::-1]
    running_mass = np.cumsum(descending)
    total_mass = flat.sum()
    # Clamp: rounding can leave the last cumulative value just under target.
    last = len(descending) - 1
    return [
        float(descending[min(np.searchsorted(running_mass, lvl * total_mass), last)])
        for lvl in levels
    ]
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def posterior_summary(
    w: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
) -> Dict:
    """
    Summarise a gridded posterior.

    Returns a dict with the weighted mean and standard deviation of
    Omega_m (`om_mean`, `om_std`) and sigma_8 (`s8_mean`, `s8_std`), the
    derived S8 = sigma_8 * (Omega_m / 0.3)^0.5 evaluated *at the posterior
    mean* (`S8_mean`), and the effective sample size (`n_eff`).
    """
    probs = w / w.sum()

    def _expect(field: np.ndarray) -> float:
        # Probability-weighted sum over the grid.
        return float((probs * field).sum())

    om_mean = _expect(OM)
    s8_mean = _expect(S8)
    om_std = _expect((OM - om_mean) ** 2) ** 0.5
    s8_std = _expect((S8 - s8_mean) ** 2) ** 0.5

    return {
        "om_mean": om_mean,
        "om_std": om_std,
        "s8_mean": s8_mean,
        "s8_std": s8_std,
        "S8_mean": s8_mean * (om_mean / 0.3) ** 0.5,
        "n_eff": effective_sample_size(probs),
    }
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 488 |
+
# § 8 PLOTTING
|
| 489 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 490 |
+
|
| 491 |
+
def plot_posterior_panel(
    ax: plt.Axes,
    w: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
    true_om: float,
    true_s8: float,
    title: str,
    summary: Optional[Dict] = None,
) -> None:
    """
    Plot one posterior panel with:
      • filled colour map of posterior weights
      • 68 % and 95 % credible contours
      • true parameter location (red ×)
      • posterior mean (black +), when `summary` is given
      • n_eff and posterior-mean S8 as text annotation, when `summary` is given

    Parameters
    ----------
    ax       : target axes; the colour bar is attached to its own figure
    w        : (grid, grid) posterior weights on the (OM, S8) mesh
    OM, S8   : coordinate meshes matching `w`
    summary  : dict as returned by `posterior_summary` (keys om_mean,
               om_std, s8_mean, s8_std, S8_mean, n_eff), or None
    """
    import warnings

    from matplotlib.lines import Line2D

    # ── colour map ────────────────────────────────────────────────────────────
    cf = ax.contourf(OM, S8, w, levels=14, cmap="Blues")
    # Use the axes' own figure rather than pyplot's "current" figure.
    ax.figure.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)

    # ── credible contours ─────────────────────────────────────────────────────
    thresh_68, thresh_95 = credible_levels(w, levels=(0.68, 0.95))
    if thresh_95 < thresh_68:
        contour_kw = dict(
            levels=[thresh_95, thresh_68],
            colors=["#e07b39", "#c0392b"],
            linewidths=[1.2, 1.8], linestyles=["--", "-"],
        )
    else:
        # Collapsed posterior: both masses share one threshold.  Matplotlib
        # rejects non-increasing levels, so draw the single level instead.
        contour_kw = dict(
            levels=[thresh_68],
            colors=["#c0392b"], linewidths=[1.8], linestyles=["-"],
        )
    try:
        ax.contour(OM, S8, w, **contour_kw)
    except ValueError as err:
        # Best effort: keep the panel, but say why the contours are missing.
        warnings.warn(f"credible contours skipped for panel {title!r}: {err}")

    # Proxy artists for legend (built regardless of contour success)
    ax.legend(
        handles=[
            Line2D([], [], color="#c0392b", lw=1.8, label="68 % CR"),
            Line2D([], [], color="#e07b39", lw=1.2, ls="--", label="95 % CR"),
            Line2D([], [], marker="x", color="r", ls="", ms=8, label="true"),
            Line2D([], [], marker="+", color="k", ls="", ms=8, label="post. mean"),
        ],
        fontsize=6.5, loc="upper right",
    )

    # ── markers ───────────────────────────────────────────────────────────────
    if summary:
        ax.scatter(summary["om_mean"], summary["s8_mean"],
                   s=60, c="k", marker="+", zorder=7)
    ax.scatter(true_om, true_s8, s=60, c="r", marker="x", zorder=7)

    # ── S8 degeneracy line (for visual reference) ─────────────────────────────
    om_arr = np.linspace(float(OM.min()), float(OM.max()), 200)
    if summary:
        S8_val = summary["s8_mean"] * (summary["om_mean"] / 0.3) ** 0.5
        # Curve of constant S8 through the posterior mean.
        s8_degen = S8_val / (om_arr / 0.3) ** 0.5
        # Only draw the part that falls inside the plotted sigma_8 range.
        mask = (s8_degen >= float(S8.min())) & (s8_degen <= float(S8.max()))
        if mask.any():
            ax.plot(om_arr[mask], s8_degen[mask], "k:", lw=0.8, alpha=0.5,
                    label=f"$S_8$={S8_val:.3f}")

    # ── labels and annotation ─────────────────────────────────────────────────
    ax.set_xlabel(r"$\Omega_m$", fontsize=9)
    ax.set_ylabel(r"$\sigma_8$", fontsize=9)
    ax.set_title(title, fontsize=8)

    if summary:
        info = (
            f"$n_\\mathrm{{eff}}$={summary['n_eff']:.0f}\n"
            f"$S_8$={summary['S8_mean']:.3f}\n"
            f"$\\Omega_m$={summary['om_mean']:.3f}±{summary['om_std']:.3f}\n"
            f"$\\sigma_8$={summary['s8_mean']:.3f}±{summary['s8_std']:.3f}"
        )
        ax.text(0.02, 0.98, info, transform=ax.transAxes,
                fontsize=6.5, va="top", color="#222",
                bbox=dict(fc="white", ec="none", alpha=0.7, pad=1.5))
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def make_posterior_figure(
    panels: List[Dict],
    suptitle: str,
    out_path: Path,
) -> None:
    """
    Draw the posterior panels on a 3-column grid and save to `out_path`.

    Up to six panels use the original 2×3 layout; additional panels add
    rows instead of raising IndexError.  Axes without a panel are hidden.

    Each element of `panels` must be a dict with keys:
        w, OM, S8, true_om, true_s8, title
    and optionally `summary`.
    """
    ncols = 3
    nrows = max(2, -(-len(panels) // ncols))       # ceil division, ≥ 2 rows
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(15, 4.75 * nrows), squeeze=False,
    )
    flat_axes = axes.ravel()                        # row-major panel order

    for ax, p in zip(flat_axes, panels):
        plot_posterior_panel(
            ax,
            p["w"], p["OM"], p["S8"],
            p["true_om"], p["true_s8"],
            p["title"], p.get("summary"),
        )
    for ax in flat_axes[len(panels):]:
        ax.set_visible(False)

    # Act on this figure explicitly instead of pyplot's current figure.
    fig.suptitle(suptitle, fontsize=11, y=0.998)
    fig.tight_layout(rect=(0, 0, 1, 0.97))
    fig.savefig(out_path, dpi=170, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved → {out_path}")
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 594 |
+
# § 9 POSTERIOR PREDICTIVE CHECK (Correction #8)
|
| 595 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 596 |
+
|
| 597 |
+
def posterior_predictive_check(
    obs_image: np.ndarray,
    w: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
    model: torch.nn.Module,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    device: torch.device,
    ddim_steps: int,
    box_size: float = 25.0,
    n_draws: int = 40,
    seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Draw `n_draws` (Omega_m, sigma_8) samples from the gridded posterior,
    generate one DDPM map per sample, and return their log P(k) alongside
    the observed spectrum for envelope plotting.

    The labels handed to the sampler have exactly two columns
    (Omega_m, sigma_8); all `n_draws` maps are requested in one call.

    Parameters
    ----------
    w      : posterior weights on the (OM, S8) mesh (need not be normalised)
    OM, S8 : coordinate meshes with the same shape as `w`
    seed   : seeds the choice of grid cells only

    Returns
    -------
    dk_valid : k-modes with dk > 0
    log_pd   : ln P(k) of the observed map on those modes
    log_pks  : (n_draws, n_valid) ln P(k) of the generated maps
    """
    rng = np.random.default_rng(seed)
    probs = w.ravel() / w.sum()
    # Sample grid cells in proportion to their posterior weight.
    idx = rng.choice(len(probs), size=n_draws, replace=True, p=probs)

    labs = np.stack(
        [OM.ravel()[idx], S8.ravel()[idx]], axis=1,
    ).astype(np.float32)

    H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1])
    imgs = em.sample_batch(
        model, labs, lab_mean, lab_std, normalize,
        H, W, device, ddim_steps, False,
    )

    _, pks = em.per_map_power_spectra_log(imgs, box_size)
    log_pks = np.log(pks + 1e-30)                   # guard against log(0)

    # Observed spectrum, on the same valid-mode mask.
    dk, log_pd, valid = log_pk_observed(obs_image, box_size)
    return dk[valid], log_pd, log_pks[:, valid]
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def plot_ppc_panel(
    ax: plt.Axes,
    dk_valid: np.ndarray,
    log_pd: np.ndarray,
    log_pks: np.ndarray,
    title: str,
) -> None:
    """
    Draw one posterior-predictive panel on `ax`.

    `log_pks` holds one generated log-spectrum per row; its per-bin
    2.5–97.5 and 16–84 percentile bands and its median are drawn against
    the observed `log_pd`, all as functions of `dk_valid`.
    """
    band = "steelblue"
    # (lower percentile, upper percentile, opacity, legend label)
    envelopes = (
        (2.5, 97.5, 0.20, "95 % PPC"),
        (16.0, 84.0, 0.40, "68 % PPC"),
    )
    for lower_q, upper_q, opacity, label in envelopes:
        lower = np.percentile(log_pks, lower_q, axis=0)
        upper = np.percentile(log_pks, upper_q, axis=0)
        ax.fill_between(dk_valid, lower, upper,
                        alpha=opacity, color=band, label=label)

    ax.plot(dk_valid, np.median(log_pks, axis=0), "b-", lw=1.4,
            label="PPC median")
    ax.plot(dk_valid, log_pd, "r-", lw=1.6, label="Observed")

    ax.set_xlabel(r"$k$ [h/Mpc]", fontsize=8)
    ax.set_ylabel(r"$\log\,P_\mathrm{HI}(k)$", fontsize=8)
    ax.set_title(title, fontsize=8)
    ax.legend(fontsize=6.5)
    ax.grid(alpha=0.3, lw=0.5)
|
| 663 |
+
|
| 664 |
+
|
| 665 |
+
def make_ppc_figure(
    ppc_data: List[Dict],
    suptitle: str,
    out_path: Path,
) -> None:
    """
    Draw the posterior-predictive panels on a 3-column grid and save to
    `out_path`.

    Up to six panels use the original 2×3 layout; additional panels add
    rows instead of raising IndexError.  Axes without a panel are hidden.

    Each element of `ppc_data` must be a dict with keys:
        dk, log_pd, log_pks, title
    """
    ncols = 3
    nrows = max(2, -(-len(ppc_data) // ncols))     # ceil division, ≥ 2 rows
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(15, 4.0 * nrows), squeeze=False,
    )
    flat_axes = axes.ravel()                        # row-major panel order

    for ax, d in zip(flat_axes, ppc_data):
        plot_ppc_panel(ax, d["dk"], d["log_pd"],
                       d["log_pks"], d["title"])
    for ax in flat_axes[len(ppc_data):]:
        ax.set_visible(False)

    # Act on this figure explicitly instead of pyplot's current figure.
    fig.suptitle(suptitle, fontsize=11, y=0.998)
    fig.tight_layout(rect=(0, 0, 1, 0.97))
    fig.savefig(out_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" Saved → {out_path}")
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 683 |
+
# § 10 MODEL LOADING
|
| 684 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 685 |
+
|
| 686 |
+
def load_model(
    args_json: Path,
    ckpt: Path,
    device: torch.device,
) -> Tuple[torch.nn.Module, Dict]:
    """
    Build a model from a saved training config and load its checkpoint.

    The config is read from `args_json`, the network is constructed on
    `device`, the weights in `ckpt` are loaded into it, and it is switched
    to eval mode.

    Returns
    -------
    (model, cfg) : the ready-to-sample network and the config it was
                   built from
    """
    config = ec.load_training_config(str(args_json))
    network = ec.build_model(config, device)
    ec.load_checkpoint(network, str(ckpt), device)
    network.eval()
    return network, config
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 699 |
+
# § 11 HIGH-LEVEL RUNNERS
|
| 700 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 701 |
+
|
| 702 |
+
def run_ddpm2(
    out_dir: Path,
    imgs: np.ndarray,
    labs: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    cfg: Dict,
    model: torch.nn.Module,
    device: torch.device,
    anchor_ix: np.ndarray,
    grid: int,
    ddim_steps: int,
    batch_sz: int,
    n_pk_samples: int,
    sigma_pk: float,
    do_ppc: bool = True,
) -> None:
    """
    Run the DDPM-2 posterior pipeline for every anchor and save the figures.

    For each index in `anchor_ix` the map ``imgs[ix]`` is treated as the
    observation: its (Omega_m, sigma_8) posterior is computed with
    `posterior_weights_ddpm2`, summarised and printed, and (optionally) a
    posterior-predictive check is run.  Writes
    ``posterior_six_anchors_ddpm2_corrected.png`` and, when `do_ppc` is
    set, ``ppc_six_anchors_ddpm2.png`` into `out_dir`.

    Parameters
    ----------
    imgs, labs : maps and their labels; `labs` also defines the grid range
    cfg        : training config; ``normalize_labels`` (default True) is read
    anchor_ix  : indices into `imgs` / `labs` (any array-like of ints)
    do_ppc     : whether to produce the posterior-predictive figure
    """
    normalize = bool(cfg.get("normalize_labels", True))
    anchors = [int(ix) for ix in np.asarray(anchor_ix).ravel()]
    n_anchors = len(anchors)
    panels = []
    ppc_data = []

    for k, ix in enumerate(anchors, start=1):
        obs = imgs[ix]
        lab_t = labs[ix].astype(np.float32)
        tom, ts8 = float(lab_t[0]), float(lab_t[1])

        # Progress counter reflects the actual number of anchors.
        print(f" [DDPM-2] anchor {k}/{n_anchors} ix={ix} "
              f"Ωm={tom:.3f} σ8={ts8:.3f}")

        w, OM, S8 = posterior_weights_ddpm2(
            obs, labs, lab_mean, lab_std, normalize, model, device,
            grid, batch_sz, ddim_steps, n_pk_samples, sigma_pk,
        )
        summ = posterior_summary(w, OM, S8)
        print(f" n_eff={summ['n_eff']:.0f} "
              f"Ωm_post={summ['om_mean']:.3f}±{summ['om_std']:.3f} "
              f"σ8_post={summ['s8_mean']:.3f}±{summ['s8_std']:.3f} "
              f"S8={summ['S8_mean']:.3f}")

        panels.append(dict(
            w=w, OM=OM, S8=S8,
            true_om=tom, true_s8=ts8, summary=summ,
            title=(
                f"test ix={ix} | "
                r"$\Omega_m$" + f"={tom:.3f}, "
                r"$\sigma_8$" + f"={ts8:.3f}"
            ),
        ))

        if do_ppc:
            dk_v, log_pd, log_pks = posterior_predictive_check(
                obs, w, OM, S8, model, lab_mean, lab_std, normalize,
                device, ddim_steps,
            )
            ppc_data.append(dict(
                dk=dk_v, log_pd=log_pd, log_pks=log_pks,
                title=f"PPC test ix={ix}",
            ))

    # ── posterior figure ──────────────────────────────────────────────────────
    make_posterior_figure(
        panels,
        suptitle=(
            r"DDPM-2 surrogate posterior on $(\Omega_m,\,\sigma_8)$ — "
            r"six CAMELS anchors "
            f"[{n_pk_samples} DDPM draws/point, σ_pk={sigma_pk:.3f}]"
        ),
        out_path=out_dir / "posterior_six_anchors_ddpm2_corrected.png",
    )

    # ── PPC figure ────────────────────────────────────────────────────────────
    if do_ppc and ppc_data:
        make_ppc_figure(
            ppc_data,
            suptitle="DDPM-2 Posterior Predictive Check — P(k) envelope vs. observed",
            out_path=out_dir / "ppc_six_anchors_ddpm2.png",
        )
|
| 780 |
+
|
| 781 |
+
|
| 782 |
+
def _posterior_predictive_check_6d(
    obs_image: np.ndarray,
    w: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
    theta_extra: np.ndarray,
    model: torch.nn.Module,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    device: torch.device,
    ddim_steps: int,
    batch_sz: int,
    box_size: float = 25.0,
    n_draws: int = 40,
    seed: int = 42,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Posterior-predictive log P(k) draws using full (Om, s8, θ_extra) labels.

    (Omega_m, sigma_8) are sampled from the gridded posterior `w`; the
    remaining label columns are filled with the fixed vector `theta_extra`.
    Maps are generated `batch_sz` at a time.  Returns
    (dk_valid, log_pd, log_pks) like `posterior_predictive_check`.
    """
    rng = np.random.default_rng(seed)
    probs = w.ravel() / w.sum()
    idx = rng.choice(len(probs), size=n_draws, replace=True, p=probs)

    labs = np.empty((n_draws, 2 + theta_extra.shape[0]), dtype=np.float32)
    labs[:, 0] = OM.ravel()[idx]
    labs[:, 1] = S8.ravel()[idx]
    labs[:, 2:] = theta_extra[np.newaxis, :]

    H, W = int(obs_image.shape[-2]), int(obs_image.shape[-1])
    step = max(1, int(batch_sz))
    pk_chunks = []
    for j0 in range(0, n_draws, step):
        imgs = em.sample_batch(
            model, labs[j0: j0 + step], lab_mean, lab_std, normalize,
            H, W, device, ddim_steps, False,
        )
        _, pks = em.per_map_power_spectra_log(imgs, box_size)
        pk_chunks.append(pks)

    dk, log_pd, valid = log_pk_observed(obs_image, box_size)
    if pk_chunks:
        log_pks = np.log(np.concatenate(pk_chunks, axis=0) + 1e-30)[:, valid]
    else:
        log_pks = np.empty((0, int(valid.sum())), dtype=np.float64)
    return dk[valid], log_pd, log_pks


def run_ddpm6(
    out_dir: Path,
    imgs: np.ndarray,
    labs: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    cfg: Dict,
    model: torch.nn.Module,
    device: torch.device,
    lo_tail: np.ndarray,
    hi_tail: np.ndarray,
    anchor_ix: np.ndarray,
    grid: int,
    ddim_steps: int,
    batch_sz: int,
    n_pk_samples: int,
    n_marg_samples: int,
    sigma_pk: float,
    do_ppc: bool = True,
) -> None:
    """
    Run the DDPM-6 marginal-posterior pipeline for every anchor and save
    the figures.

    For each index in `anchor_ix` the map ``imgs[ix]`` is treated as the
    observation: its (Omega_m, sigma_8) posterior, marginalised over dims
    2–5, is computed with `posterior_weights_ddpm6_marginalised`,
    summarised and printed.  Writes
    ``posterior_six_anchors_ddpm6_marginalised_corrected.png`` and, when
    `do_ppc` is set, ``ppc_six_anchors_ddpm6.png`` into `out_dir`.

    The posterior-predictive check builds full-width labels: cosmology is
    drawn from the 2-D posterior and dims 2–5 are set to one uniform draw
    from [lo_tail, hi_tail] (seeded by the anchor index).  Previously that
    draw was computed but never used and 2-column labels were passed on.

    Parameters
    ----------
    imgs, labs       : maps and their labels; `labs` also defines the grid range
    cfg              : training config; ``normalize_labels`` (default True) is read
    lo_tail, hi_tail : bounds of the uniform prior on dims 2–5, shape (4,)
    anchor_ix        : indices into `imgs` / `labs` (any array-like of ints)
    do_ppc           : whether to produce the posterior-predictive figure
    """
    normalize = bool(cfg.get("normalize_labels", True))
    anchors = [int(ix) for ix in np.asarray(anchor_ix).ravel()]
    n_anchors = len(anchors)
    panels = []
    ppc_data = []

    for k, ix in enumerate(anchors, start=1):
        obs = imgs[ix]
        lab_t = labs[ix].astype(np.float32)
        tom, ts8 = float(lab_t[0]), float(lab_t[1])

        # Progress counter reflects the actual number of anchors.
        print(f" [DDPM-6] anchor {k}/{n_anchors} ix={ix} "
              f"Ωm={tom:.3f} σ8={ts8:.3f}")

        w, OM, S8 = posterior_weights_ddpm6_marginalised(
            obs, labs, lab_mean, lab_std, normalize, model, device,
            lo_tail, hi_tail,
            grid, batch_sz, ddim_steps,
            n_pk_samples, n_marg_samples, sigma_pk,
        )
        summ = posterior_summary(w, OM, S8)
        print(f" n_eff={summ['n_eff']:.0f} "
              f"Ωm_post={summ['om_mean']:.3f}±{summ['om_std']:.3f} "
              f"σ8_post={summ['s8_mean']:.3f}±{summ['s8_std']:.3f} "
              f"S8={summ['S8_mean']:.3f}")

        panels.append(dict(
            w=w, OM=OM, S8=S8,
            true_om=tom, true_s8=ts8, summary=summ,
            title=(
                f"test ix={ix} | "
                r"$\Omega_m$" + f"={tom:.3f}, "
                r"$\sigma_8$" + f"={ts8:.3f}"
                f"\n[MC marg., N_marg={n_marg_samples}]"
            ),
        ))

        if do_ppc:
            # One draw from the astrophysical prior completes the labels.
            rng = np.random.default_rng(ix)
            te = rng.uniform(lo_tail, hi_tail).astype(np.float32)
            dk_v, log_pd, log_pks = _posterior_predictive_check_6d(
                obs, w, OM, S8, te, model, lab_mean, lab_std, normalize,
                device, ddim_steps, batch_sz,
            )
            ppc_data.append(dict(
                dk=dk_v, log_pd=log_pd, log_pks=log_pks,
                title=f"PPC test ix={ix}",
            ))

    # ── posterior figure ──────────────────────────────────────────────────────
    make_posterior_figure(
        panels,
        suptitle=(
            r"DDPM-6 marginal posterior on $(\Omega_m,\,\sigma_8)$ — "
            r"six CAMELS anchors "
            f"[MC marginalisation, N_marg={n_marg_samples}, "
            f"{n_pk_samples} DDPM draws/point, σ_pk={sigma_pk:.3f}]"
        ),
        out_path=out_dir / "posterior_six_anchors_ddpm6_marginalised_corrected.png",
    )

    # ── PPC figure ────────────────────────────────────────────────────────────
    if do_ppc and ppc_data:
        make_ppc_figure(
            ppc_data,
            suptitle="DDPM-6 Posterior Predictive Check — P(k) envelope vs. observed",
            out_path=out_dir / "ppc_six_anchors_ddpm6.png",
        )
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 875 |
+
# § 12 CLI
|
| 876 |
+
# ════════════════════════════════════════════════���════════════════════════════
|
| 877 |
+
|
| 878 |
+
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and return the parsed arguments.

    Options are declared in a table (flag, ``add_argument`` keyword arguments)
    and registered in order, so ``--help`` lists them in the same sequence as
    the table below.
    """
    option_table = (
        ("--output-dir", dict(
            type=Path,
            default=MODELS_ROOT / "ddpm_posterior_corrected_out",
        )),
        ("--data-2param", dict(
            type=Path,
            default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"),
        )),
        ("--data-6param", dict(
            type=Path,
            default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"),
        )),
        ("--bundle-2param", dict(
            type=Path,
            default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
        )),
        ("--bundle-6param", dict(
            type=Path,
            default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
        )),
        ("--split", dict(default="test", choices=["train", "val", "test"])),
        # ── grid ──────────────────────────────────────────────────────────────
        ("--grid", dict(
            type=int,
            default=30,
            help="Grid points per Ωm–σ8 axis (30×30=900 default, was 14×14=196).",
        )),
        # ── sampling ──────────────────────────────────────────────────────────
        ("--ddim-steps", dict(
            type=int,
            default=50,
            help="DDIM denoising steps per sample.",
        )),
        ("--batch-size", dict(
            type=int,
            default=8,
            help="Grid-point batch size for DDPM forward passes.",
        )),
        ("--n-pk-samples", dict(
            type=int,
            default=8,
            help=(
                "DDPM draws to average per grid point. "
                "Variance ∝ 1/n_pk_samples. "
                "≥8 recommended; use 4 for a fast debug run."
            ),
        )),
        ("--n-marg-samples", dict(
            type=int,
            default=20,
            help=(
                "MC draws for DDPM-6 astrophysical marginalisation. "
                "≥20 recommended; use 5 for a fast debug run."
            ),
        )),
        # ── sigma calibration ─────────────────────────────────────────────────
        ("--n-calib-pairs", dict(
            type=int,
            default=30,
            help="Number of image pairs used to calibrate sigma_pk.",
        )),
        ("--sigma-pk", dict(
            type=float,
            default=None,
            help=(
                "Override calibrated sigma_pk with a fixed value. "
                "Leave unset to use automatic calibration (recommended)."
            ),
        )),
        # ── scope ─────────────────────────────────────────────────────────────
        ("--ddpm2-only", dict(
            action="store_true",
            help="Only run DDPM-2 (skip loading DDPM-6).",
        )),
        ("--ddpm6-only", dict(
            action="store_true",
            help="Only run DDPM-6 (skip loading DDPM-2).",
        )),
        ("--no-ppc", dict(
            action="store_true",
            help="Skip posterior predictive check figures.",
        )),
    )

    parser = argparse.ArgumentParser(
        description=(
            "Corrected six-anchor surrogate posteriors: DDPM-2 and DDPM-6.\n"
            "See module docstring for a full list of corrections applied."
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
|
| 964 |
+
|
| 965 |
+
|
| 966 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 967 |
+
# § 13 MAIN
|
| 968 |
+
# ═════════════════════════════════════════════════════════════════════════════
|
| 969 |
+
|
| 970 |
+
def _resolve_sigma_pk(
    args: argparse.Namespace,
    model,
    cfg,
    data_dir: Path,
    lab_mean,
    lab_std,
    device: torch.device,
) -> float:
    """Return the ``sigma_pk`` value to use for one model.

    A value supplied through ``--sigma-pk`` takes precedence.  Otherwise the
    "val" split under *data_dir* is loaded and handed to
    ``calibrate_sigma_pk`` together with the label statistics.
    """
    if args.sigma_pk is not None:
        print(f" sigma_pk overridden to {args.sigma_pk:.4f}")
        return args.sigma_pk

    print(" Calibrating sigma_pk from validation set …")
    imgs_val, labs_val = ec.load_split(data_dir, "val")
    return calibrate_sigma_pk(
        model, imgs_val, labs_val,
        lab_mean, lab_std,
        normalize=bool(cfg.get("normalize_labels", True)),
        device=device,
        ddim_steps=args.ddim_steps,
        n_pairs=args.n_calib_pairs,
    )


def main() -> None:
    """Run the DDPM-2 and/or DDPM-6 six-anchor posterior pipelines.

    Steps: parse the CLI, load the requested split for each selected model,
    choose six evenly spaced anchor indices, then for each model load the
    checkpoint, obtain ``sigma_pk`` and call the matching ``run_ddpm*``
    function.  Each model is released before the next one is loaded.

    Raises:
        SystemExit: if both ``--ddpm2-only`` and ``--ddpm6-only`` are given,
            or if a selected split contains no maps.
    """
    args = parse_args()

    if args.ddpm2_only and args.ddpm6_only:
        raise SystemExit("Specify at most one of --ddpm2-only / --ddpm6-only.")

    use_ddpm2 = not args.ddpm6_only
    use_ddpm6 = not args.ddpm2_only

    out_dir = Path(args.output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device : {device}")
    print(f"Output : {out_dir}")
    print()

    # ── load data ─────────────────────────────────────────────────────────────
    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)
    split_sizes = []

    if use_ddpm2:
        imgs2, labs2 = ec.load_split(data2, args.split)
        mean2, std2 = ec.load_label_stats(data2)
        split_sizes.append(len(labs2))
        print(f"DDPM-2 {args.split} set : {len(labs2)} maps "
              f"label_dim={labs2.shape[1]}")

    if use_ddpm6:
        imgs6, labs6 = ec.load_split(data6, args.split)
        mean6, std6 = ec.load_label_stats(data6)
        lo_tail, hi_tail = tail_lhs_bounds(data6)
        split_sizes.append(len(labs6))
        print(f"DDPM-6 {args.split} set : {len(labs6)} maps "
              f"label_dim={labs6.shape[1]}")
        print(f" LHS tails (dims 2-5): min={lo_tail} max={hi_tail}")

    # ── six anchors ───────────────────────────────────────────────────────────
    # The same anchor indices are handed to both runs, so they must be valid
    # for every loaded split: use the smaller split length.  When both splits
    # have the same length this matches the previous behaviour exactly.
    n_ref = min(split_sizes)
    if n_ref < 1:
        raise SystemExit(
            f"Split '{args.split}' contains no maps; cannot choose anchors."
        )
    anchor_ix = np.linspace(0, n_ref - 1, num=6, dtype=int)
    print(f"\nAnchor indices: {anchor_ix.tolist()}\n")

    # ── checkpoints ───────────────────────────────────────────────────────────
    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
    args_j2 = args.bundle_2param / "args.json"
    ck6 = args.bundle_6param / "best_model.pt"
    args_j6 = args.bundle_6param / "args.json"

    # ══════════════════════════════════════════════════════════════════════════
    # DDPM-2 BLOCK
    # ══════════════════════════════════════════════════════════════════════════
    if use_ddpm2:
        print("=" * 70)
        print(">>> DDPM-2 (six anchors)")
        print("=" * 70)

        model2, cfg2 = load_model(args_j2, ck2, device)
        sigma2 = _resolve_sigma_pk(args, model2, cfg2, data2, mean2, std2, device)

        run_ddpm2(
            out_dir=out_dir,
            imgs=imgs2, labs=labs2,
            lab_mean=mean2, lab_std=std2,
            cfg=cfg2, model=model2, device=device,
            anchor_ix=anchor_ix,
            grid=args.grid,
            ddim_steps=args.ddim_steps,
            batch_sz=args.batch_size,
            n_pk_samples=args.n_pk_samples,
            sigma_pk=sigma2,
            do_ppc=not args.no_ppc,
        )

        # Drop the model before the next one is loaded to keep peak memory down.
        del model2
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print()

    # ══════════════════════════════════════════════════════════════════════════
    # DDPM-6 BLOCK
    # ══════════════════════════════════════════════════════════════════════════
    if use_ddpm6:
        print("=" * 70)
        print(">>> DDPM-6 (six anchors, MC marginalisation over dims 2-5)")
        print("=" * 70)

        model6, cfg6 = load_model(args_j6, ck6, device)
        sigma6 = _resolve_sigma_pk(args, model6, cfg6, data6, mean6, std6, device)

        run_ddpm6(
            out_dir=out_dir,
            imgs=imgs6, labs=labs6,
            lab_mean=mean6, lab_std=std6,
            cfg=cfg6, model=model6, device=device,
            lo_tail=lo_tail, hi_tail=hi_tail,
            anchor_ix=anchor_ix,
            grid=args.grid,
            ddim_steps=args.ddim_steps,
            batch_sz=args.batch_size,
            n_pk_samples=args.n_pk_samples,
            n_marg_samples=args.n_marg_samples,
            sigma_pk=sigma6,
            do_ppc=not args.no_ppc,
        )

        del model6
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"\nAll done. Results in {out_dir}")


if __name__ == "__main__":
    main()
|
cross_model/run_compare_posterior.sh
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --account=l40sfree
#SBATCH --partition=l40s
#SBATCH --nodes=1
#SBATCH --ntasks=8
#SBATCH --gres=gpu:l40s:1
#SBATCH --time=06:00:00
#SBATCH --job-name=cmp_post
#SBATCH --mail-user=mrpcol001@myuct.ac.za
#SBATCH --output=slurm-cmp-post-%j.out
#SBATCH --error=slurm-cmp-post-%j.err

# NOTE(review): --mail-user is set without --mail-type; confirm whether
# e-mail notifications are actually expected from this job.

# DDPM-2 vs DDPM-6 corrected posteriors with JOINT P(k) + PDF likelihood.
#
# Runs compare_posterior_inference.py from the Models root and mirrors its
# combined stdout/stderr into a run log inside the output directory.
# Any arguments given to this script are forwarded to the Python program.
#
# Defaults: 30x30 grid, 4 anchors, 8 DDPM draws / pt, 20 marg draws.
# Submit:
#   sbatch /scratch/mrpcol001/Diffusion_job/Models/run_compare_posterior.sh
#
# Smoke test (much faster):
#   sbatch run_compare_posterior.sh --grid 16 --n-pk-samples 4 \
#       --n-marg-samples 5 --n-anchors 2
#
# Environment overrides: OUTPUT_DIR (output directory), CUSTOM_LOG (log file).

set -euo pipefail

ROOT="/scratch/mrpcol001/Diffusion_job/Models"
cd "$ROOT"

module load python/miniconda3-py3.12-usr

PY="${ROOT}/compare_posterior_inference.py"
OUT="${OUTPUT_DIR:-${ROOT}/ddpm_posterior_compare_pk_pdf_out}"

mkdir -p "${OUT}"
RUN_LOG="${CUSTOM_LOG:-${OUT}/run_log.txt}"

RULE="==============================================="

# Job banner: identifies the allocation and the resolved paths.
cat <<EOF
${RULE}
Job ID: ${SLURM_JOB_ID:-local}
Node: ${SLURM_NODELIST:-$(hostname)}
GPU: ${CUDA_VISIBLE_DEVICES:-n/a}
Started: $(date)
Script: ${PY}
Output: ${OUT}
Log: ${RUN_LOG}
${RULE}
EOF

# pipefail makes a Python failure visible through the tee pipeline.
set -o pipefail
python -u "${PY}" --output-dir "${OUT}" "$@" 2>&1 | tee -a "${RUN_LOG}"

cat <<EOF
${RULE}
Finished: $(date)
Artifacts: ${OUT}
${RULE}
EOF
|
cross_model/run_vlb_inference_1000grid.sh
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --account=l40sfree
#SBATCH --partition=l40s
#SBATCH --nodes=1
#SBATCH --ntasks=8
#SBATCH --gres=gpu:l40s:1
#SBATCH --time=48:00:00
#SBATCH --job-name=vlb_infer_1000
#SBATCH --mail-user=mrpcol001@myuct.ac.za
#SBATCH --output=slurm-vlb-infer-1000-%j.out
#SBATCH --error=slurm-vlb-infer-1000-%j.err

# NOTE(review): --mail-user is set without --mail-type; confirm whether
# e-mail notifications are actually expected from this job.

# VLB / Mudur-style posterior_inference.py — 1000×1000 grid (high-resolution posteriors).
#
# High-resolution parameter grids with 9 fields, generates posterior_L0_mosaic_3x3.png
# plus field0X_combined.png (contours + L_0 posterior side-by-side).
#
# Submit:
#   sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_vlb_inference_1000grid.sh
#
# Default: grid_size=1000, n_fields=9, span=0.10, generates mosaic + combined figures.
# Est. runtime: ~18-24 hours on L40S.
#
# Override via --export or command-line args (extra arguments are appended
# after the fixed options below):
#   sbatch --export=OUTPUT_DIR=/path/custom_output scripts/run_vlb_inference_1000grid.sh
#   sbatch scripts/run_vlb_inference_1000grid.sh --grid_size 500 --n_fields 4 --batch_size 16
#
# For grid_size < 300, omit --allow_huge_grid in args below.
#
# Environment overrides: OUTPUT_DIR, CHECKPOINT, TRAINING_ARGS, DATA_DIR, CUSTOM_LOG.
# Logs: Slurm .out/.err plus OUTPUT_DIR/run_log.txt.

set -euo pipefail

ROOT="/scratch/mrpcol001/Diffusion_job/Models"
cd "$ROOT"

module load python/miniconda3-py3.12-usr

PY="${ROOT}/6param_ddpm_hi_lh6/posterior_inference.py"
OUT="${OUTPUT_DIR:-${ROOT}/vlb_inference_outputs_1000grid}"

CHK="${CHECKPOINT:-${ROOT}/notebook_model_weights/6param_best/best_model.pt}"
ARGS="${TRAINING_ARGS:-${ROOT}/notebook_model_weights/6param_best/args.json}"
DATA="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"

mkdir -p "${OUT}"
RUN_LOG="${CUSTOM_LOG:-${OUT}/run_log.txt}"

RULE="==============================================="

# Job banner: identifies the allocation and the resolved paths.
cat <<EOF
${RULE}
Job ID: ${SLURM_JOB_ID:-local}
Node: ${SLURM_NODELIST:-$(hostname)}
GPU: ${CUDA_VISIBLE_DEVICES:-n/a}
Started: $(date)
Python: ${PY}
checkpoint: ${CHK}
training_args: ${ARGS}
data_dir: ${DATA}
output_dir: ${OUT}
Progress log: ${RUN_LOG}
${RULE}
EOF

# Fixed options for this preset; caller-supplied arguments follow them.
PY_OPTS=(
  --checkpoint "${CHK}"
  --training_args "${ARGS}"
  --data_dir "${DATA}"
  --output_dir "${OUT}"
  --grid_size 1000
  --allow_huge_grid
  --n_fields 9
  --span 0.10
  --t_subset 0 1 2 5 8 10 15 20
  --n_seeds 4
  --batch_size 32
  --seed 42
)

# pipefail makes a Python failure visible through the tee pipeline.
set -o pipefail
python -u "${PY}" "${PY_OPTS[@]}" "$@" 2>&1 | tee -a "${RUN_LOG}"

cat <<EOF
${RULE}
Finished: $(date)
Artifacts → ${OUT}
${RULE}
EOF
|
cross_model/run_vlb_inference_200grid.sh
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --account=l40sfree
#SBATCH --partition=l40s
#SBATCH --nodes=1
#SBATCH --ntasks=8
#SBATCH --gres=gpu:l40s:1
#SBATCH --time=12:00:00
#SBATCH --job-name=vlb_infer_200
#SBATCH --mail-user=mrpcol001@myuct.ac.za
#SBATCH --output=slurm-vlb-infer-200-%j.out
#SBATCH --error=slurm-vlb-infer-200-%j.err

# NOTE(review): --mail-user is set without --mail-type; confirm whether
# e-mail notifications are actually expected from this job.

# VLB / Mudur-style posterior_inference.py — 200×200 grid (balanced speed/quality).
#
# Medium-resolution parameter grids with 9 fields, generates posterior_L0_mosaic_3x3.png
# plus field0X_combined.png (contours + L_0 posterior side-by-side).
#
# Grid: 200×200 = 40K points per timestep (vs 1000×1000 = 1M points)
# Est. runtime: ~2-3 hours on L40S.
#
# Submit:
#   sbatch /scratch/mrpcol001/Diffusion_job/Models/run_vlb_inference_200grid.sh
#
# Override via --export or command-line args (extra arguments are appended
# after the fixed options below):
#   sbatch --export=OUTPUT_DIR=/path/custom_output run_vlb_inference_200grid.sh
#   sbatch run_vlb_inference_200grid.sh --grid_size 150 --n_fields 4
#
# Environment overrides: OUTPUT_DIR, CHECKPOINT, TRAINING_ARGS, DATA_DIR, CUSTOM_LOG.
# Logs: Slurm .out/.err plus OUTPUT_DIR/run_log.txt.

set -euo pipefail

ROOT="/scratch/mrpcol001/Diffusion_job/Models"
cd "$ROOT"

module load python/miniconda3-py3.12-usr

PY="${ROOT}/6param_ddpm_hi_lh6/posterior_inference.py"
OUT="${OUTPUT_DIR:-${ROOT}/vlb_inference_outputs_200grid}"

CHK="${CHECKPOINT:-${ROOT}/notebook_model_weights/6param_best/best_model.pt}"
ARGS="${TRAINING_ARGS:-${ROOT}/notebook_model_weights/6param_best/args.json}"
DATA="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"

mkdir -p "${OUT}"
RUN_LOG="${CUSTOM_LOG:-${OUT}/run_log.txt}"

RULE="==============================================="

# Job banner: identifies the allocation and the resolved paths.
cat <<EOF
${RULE}
Job ID: ${SLURM_JOB_ID:-local}
Node: ${SLURM_NODELIST:-$(hostname)}
GPU: ${CUDA_VISIBLE_DEVICES:-n/a}
Started: $(date)
Python: ${PY}
checkpoint: ${CHK}
training_args: ${ARGS}
data_dir: ${DATA}
output_dir: ${OUT}
Progress log: ${RUN_LOG}
${RULE}
EOF

# Fixed options for this preset; caller-supplied arguments follow them.
PY_OPTS=(
  --checkpoint "${CHK}"
  --training_args "${ARGS}"
  --data_dir "${DATA}"
  --output_dir "${OUT}"
  --grid_size 200
  --n_fields 9
  --span 0.10
  --t_subset 0 1 2 5 8 10 15 20
  --n_seeds 4
  --batch_size 32
  --seed 42
)

# pipefail makes a Python failure visible through the tee pipeline.
set -o pipefail
python -u "${PY}" "${PY_OPTS[@]}" "$@" 2>&1 | tee -a "${RUN_LOG}"

cat <<EOF
${RULE}
Finished: $(date)
Artifacts → ${OUT}
${RULE}
EOF
|
cross_model/scripts/compare_ddpm_models.py
ADDED
|
@@ -0,0 +1,855 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Compare 2-parameter and 6-parameter conditional DDPMs (CAMELS LH) side-by-side:
|
| 4 |
+
• Random-draw vs test-conditioned triplets (CAMELS | DDPM-2 | DDPM-6)
|
| 5 |
+
• Six anchor cosmologies: P(k) and PDF diagnostics (triple curves per panel where applicable)
|
| 6 |
+
• LHS R² cosmology plots (LHS-50 × 15 maps — expensive)
|
| 7 |
+
• MLP P(k) → label recovery ( sklearn MLP, two models + shared CAMELS calibration )
|
| 8 |
+
• Surrogate posterior on (Ωm, σ8) for a fixed test index
|
| 9 |
+
• Training / validation loss on one axis (Slurm .out for DDPM-6; DDPM-2 defaults to bundled JSON)
|
| 10 |
+
|
| 11 |
+
Outputs under --output-dir (default: Models/ddpm_comparison_out/).
|
| 12 |
+
|
| 13 |
+
GPU: both models are resident while generating comparison panels; use a single GPU with
|
| 14 |
+
sufficient memory, or run heavier steps separately with refactors.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import gc
|
| 21 |
+
import sys
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Dict, Sequence, Tuple
|
| 24 |
+
|
| 25 |
+
import matplotlib
|
| 26 |
+
|
| 27 |
+
matplotlib.use("Agg")
|
| 28 |
+
import matplotlib.pyplot as plt
|
| 29 |
+
import numpy as np
|
| 30 |
+
import torch
|
| 31 |
+
|
| 32 |
+
# --- Repo paths ---
|
| 33 |
+
MODELS_ROOT = Path(__file__).resolve().parents[1]
|
| 34 |
+
CODE_6 = (MODELS_ROOT / "6param_ddpm_hi_lh6").resolve()
|
| 35 |
+
if str(CODE_6) not in sys.path:
|
| 36 |
+
sys.path.insert(0, str(CODE_6))
|
| 37 |
+
|
| 38 |
+
import evaluate_conditional as ec # noqa: E402
|
| 39 |
+
import eval_model as em # noqa: E402
|
| 40 |
+
from figure9_posterior import build_cosmo_grid, log_pk_observed # noqa: E402
|
| 41 |
+
|
| 42 |
+
from plot_r2_cosmology_lhs import compute_lhs_r2, plot_r2_cosmology_figure # noqa: E402
|
| 43 |
+
|
| 44 |
+
from compare_ddpm_training_curves import ( # noqa: E402
|
| 45 |
+
load_train_val_series,
|
| 46 |
+
parse_slurm_training_log,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
DEFAULT_SLURM_6 = Path(
|
| 50 |
+
"/scratch/mrpcol001/Diffusion_job/april_26/ddpm_hi_lh6/scripts/shell/slurm-698243.out"
|
| 51 |
+
)
|
| 52 |
+
# Bundled train/val (no 2-param Slurm log in-repo); see ``ddpm_2param_training_loss.json``.
|
| 53 |
+
DEFAULT_DDPM2_TRAINING = (Path(__file__).resolve().parent / "ddpm_2param_training_loss.json")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _fmt_title(lab: np.ndarray) -> str:
|
| 57 |
+
t = np.asarray(lab, dtype=float).ravel()
|
| 58 |
+
if t.size <= 2:
|
| 59 |
+
return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f}"
|
| 60 |
+
tail = ", ".join(f"{float(v):.3g}" for v in t[2:])
|
| 61 |
+
return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f} | " + tail
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _latin_hypercube(n: int, lo: np.ndarray, hi: np.ndarray, rng: np.random.Generator) -> np.ndarray:
|
| 65 |
+
"""Classic LHS (same as notebook)."""
|
| 66 |
+
d = int(lo.shape[0])
|
| 67 |
+
u = rng.random((n, d))
|
| 68 |
+
cut = np.linspace(0.0, 1.0, n + 1)
|
| 69 |
+
a, b = cut[:-1], cut[1:]
|
| 70 |
+
width = (b - a)[:, np.newaxis]
|
| 71 |
+
rd = a[:, np.newaxis] + u * width
|
| 72 |
+
for j in range(d):
|
| 73 |
+
rng.shuffle(rd[:, j])
|
| 74 |
+
span = (hi - lo).astype(np.float64)
|
| 75 |
+
return (lo + rd * span).astype(np.float32)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
@torch.no_grad()
def generate_maps(
    model: torch.nn.Module,
    labels_np: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    H: int,
    W: int,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
) -> np.ndarray:
    """Sample one map per label row with DDIM, in batches of ``batch_size``.

    Each label chunk is prepared with ``ec.prepare_labels_for_model``, passed
    to ``model.sample`` and converted back with ``ec.from_model_output``.
    The per-batch results are concatenated along axis 0.
    """
    # Sampler settings that do not change from batch to batch.
    sampler_kwargs = dict(
        channels=1,
        height=H,
        width=W,
        device=device,
        progress=False,
        use_ddim=True,
        ddim_steps=ddim_steps,
    )
    total = labels_np.shape[0]
    batches = []
    for start in range(0, total, batch_size):
        label_chunk = labels_np[start : start + batch_size].astype(np.float32)
        cond = ec.prepare_labels_for_model(label_chunk, label_mean, label_std).to(device)
        sampled = model.sample(labels=cond, **sampler_kwargs)
        batches.append(ec.from_model_output(sampled))
    return np.concatenate(batches, axis=0)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def load_model(bundle_args: Path, ckpt: Path, device: torch.device):
    """Build a model from a training config, load its weights, and return it.

    Returns the ``(model, config)`` pair, with the model switched to eval mode.
    """
    config = ec.load_training_config(str(bundle_args))
    net = ec.build_model(config, device)
    ec.load_checkpoint(net, str(ckpt), device)
    net.eval()
    return net, config
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def free_torch():
    """Run the garbage collector, then empty the CUDA cache if CUDA is available."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def plot_training_overlay(
    out_dir: Path,
    slurm6: Path | None,
    slurm2: Path | None,
) -> None:
    """Plot DDPM-6 and DDPM-2 train/val loss curves on one logarithmic axis.

    Args:
        out_dir: Directory that receives
            ``comparison_training_train_val_overlay.png``.
        slurm6: Slurm log for the 6-parameter run, read with
            ``parse_slurm_training_log``.  Skipped with a warning when it is
            missing or ``None``.
        slurm2: Training series for the 2-parameter run, read with
            ``load_train_val_series``.  Skipped silently when ``None`` and
            with a warning when the path does not exist.

    When neither source can be read, the figure carries a placeholder note
    instead of curves.
    """
    fig, ax = plt.subplots(figsize=(9, 5))
    plotted = False
    if slurm6 and Path(slurm6).is_file():
        ep, tr, va = parse_slurm_training_log(slurm6)
        ax.plot(ep, tr, lw=1.4, ls="-", label="DDPM-6 train", color="#1f77b4", alpha=0.85)
        ax.plot(ep, va, lw=1.8, ls="--", label="DDPM-6 val", color="#174a75", alpha=0.95)
        plotted = True
    else:
        print("Warning: 6-param Slurm log not found; skipped overlay for DDPM-6.")

    if slurm2 and Path(slurm2).is_file():
        ep, tr, va = load_train_val_series(slurm2)
        ax.plot(ep, tr, lw=1.4, ls="-", label="DDPM-2 train", color="#ff7f0e", alpha=0.85)
        ax.plot(ep, va, lw=1.8, ls="--", label="DDPM-2 val", color="#994d00", alpha=0.95)
        plotted = True
    elif slurm2 is not None:
        print(f"Warning: 2-param training series not found ({slurm2}); use --slurm-2param or restore bundled JSON.")

    if not plotted:
        print("No Slurm logs parsed — writing placeholder note instead of curves.")
        ax.text(
            0.5,
            0.5,
            "Pass --slurm-6param; DDPM-2 uses bundled JSON by default (--slurm-2param).",
            ha="center",
            va="center",
            transform=ax.transAxes,
        )
    else:
        ax.set_yscale("log")
        # Only draw a legend when there are labelled curves; calling it on an
        # empty axes makes Matplotlib warn that no labelled artists exist.
        ax.legend(loc="upper right", fontsize=8)
    ax.grid(True, alpha=0.3)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("MSE diffusion loss")
    ax.set_title("Training / validation curves (combined)")
    outp = out_dir / "comparison_training_train_val_overlay.png"
    fig.savefig(outp, dpi=170, bbox_inches="tight")
    plt.close(fig)
    print("Saved", outp)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def run_random_theta_triplets(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2,
    model6,
    device: torch.device,
    ddim_steps: int,
    seed: int,
    n_pairs: int,
    batch_size: int,
):
    """Plot CAMELS / DDPM-2 / DDPM-6 triplets at random Latin-hypercube targets.

    Targets are drawn with ``_latin_hypercube`` inside the per-column min/max
    box of ``lab6``.  For each target the figure shows, side by side, the real
    map whose label row is closest to the target, one DDPM-2 sample conditioned
    on the first two label columns, and one DDPM-6 sample conditioned on the
    full label row.

    Args:
        out_dir: Directory receiving
            ``comparison_random_lhs_triplets_camels_ddpm2_ddpm6.png``.
        imgs6, lab6: Real maps and their label rows (same leading length).
        mean6, std6, mean2, std2: Label statistics forwarded to
            ``generate_maps`` for the respective model.
        model2, model6: The two conditional models.
        device: Torch device forwarded to ``generate_maps``.
        ddim_steps: Sampler step count forwarded to ``generate_maps``.
        seed: Seed for the target draw.
        n_pairs: Requested number of triplets; at most 32 are drawn.
        batch_size: Forwarded to ``generate_maps``.
    """
    rng = np.random.default_rng(seed)
    lo, hi = lab6.min(0), lab6.max(0)
    targets = _latin_hypercube(min(n_pairs, 32), lo, hi, rng)[:n_pairs]

    # The draw is capped at 32 rows, so fewer than ``n_pairs`` targets can come
    # back.  Size the loop and the subplot grid from what was actually drawn;
    # indexing ``targets`` up to ``n_pairs`` raised IndexError for n_pairs > 32.
    n_show = len(targets)
    n_cols = n_show * 3

    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
    tg2 = targets[:, :2].astype(np.float32)
    titles = ("CAMELS (NN)", "DDPM-2", "DDPM-6")

    fig = plt.figure(figsize=(3.8 * max(3, n_cols), 4.1))
    for i in range(n_show):
        theta6 = targets[i].astype(np.float32)
        theta2 = tg2[i]

        # NOTE(review): the neighbour search uses the Euclidean distance in raw
        # label units, so columns with a wider numeric range dominate the
        # match -- confirm this is intended.
        dist = np.linalg.norm(lab6 - theta6[None, :], axis=1).astype(np.float64)
        nn_img = imgs6[int(np.argmin(dist))]

        gen2 = generate_maps(
            model2, theta2[np.newaxis, :], mean2, std2, H, W, device, ddim_steps, batch_size
        )
        gen6 = generate_maps(
            model6, theta6[np.newaxis, :], mean6, std6, H, W, device, ddim_steps, batch_size
        )

        for j, img in enumerate((nn_img, gen2[0], gen6[0])):
            ax = fig.add_subplot(1, n_cols, i * 3 + j + 1)
            ax.imshow(img, vmin=0, vmax=1, origin="lower", cmap="inferno")
            ax.axis("off")
            ax.set_title(titles[j], fontsize=8)

    fig.suptitle(
        "Random LHS cosmologies — CAMELS = nearest-neighbour truth | gens conditioned on LHS labels",
        fontsize=10,
        y=1.02,
    )
    p = out_dir / "comparison_random_lhs_triplets_camels_ddpm2_ddpm6.png"
    fig.tight_layout()
    fig.savefig(p, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print("Saved", p)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def run_conditioned_test_triplets(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2,
    model6,
    device: torch.device,
    ddim_steps: int,
    seed: int,
    n_pairs: int,
    batch_size: int,
):
    """Plot real / DDPM-2 / DDPM-6 triplets for randomly chosen rows of the split.

    Rows are sampled without replacement (RNG seeded with ``seed + 1``).  Each
    triplet shows the real map, one DDPM-2 sample conditioned on the first two
    label columns and one DDPM-6 sample conditioned on the full label row of
    that same row.

    Args:
        out_dir: Directory receiving
            ``comparison_test_conditioned_camels_ddpm2_ddpm6.png``.
        imgs6, lab6: Real maps and their label rows.
        mean6, std6, mean2, std2: Label statistics forwarded to
            ``generate_maps`` for the respective model.
        model2, model6: The two conditional models.
        device: Torch device forwarded to ``generate_maps``.
        ddim_steps: Sampler step count forwarded to ``generate_maps``.
        seed: Base seed for the row selection.
        n_pairs: Requested number of triplets; capped at ``len(imgs6)``.
        batch_size: Forwarded to ``generate_maps``.
    """
    rng = np.random.default_rng(seed + 1)
    idx = rng.choice(len(imgs6), size=min(n_pairs, len(imgs6)), replace=False)

    # Build the grid from the rows actually drawn: sizing it from ``n_pairs``
    # left empty framed panels when the split was smaller than requested, and
    # a zero-column grid makes ``plt.subplots`` raise.
    n_show = len(idx)
    if n_show == 0:
        print("No test rows selected; skipping conditioned triplet figure.")
        return

    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
    n_cols = n_show * 3
    fig, axes = plt.subplots(1, n_cols, figsize=(2.9 * n_cols, 3.8), squeeze=False)
    column_names = ("CAMELS", "DDPM-2", "DDPM-6")

    for ii, ix in enumerate(idx):
        tg6 = lab6[ix].astype(np.float32)
        tg2 = tg6[:2]
        rm = imgs6[ix]
        g2 = generate_maps(model2, tg2[np.newaxis, :], mean2, std2, H, W, device, ddim_steps, batch_size)[0]
        g6 = generate_maps(model6, tg6[np.newaxis, :], mean6, std6, H, W, device, ddim_steps, batch_size)[0]
        for j, img in enumerate((rm, g2, g6)):
            ax = axes[0, ii * 3 + j]
            ax.imshow(img, vmin=0, vmax=1, origin="lower", cmap="inferno")
            ax.axis("off")
            if ii == 0:
                ax.set_title(column_names[j], fontsize=8)

        # ``axis("off")`` suppresses axis labels, so an xlabel set on these
        # panels is never drawn.  Put the parameter caption under the first
        # panel of the triplet as plain text in axes coordinates instead.
        axes[0, ii * 3].text(
            0.5,
            -0.03,
            _fmt_title(tg6),
            transform=axes[0, ii * 3].transAxes,
            ha="center",
            va="top",
            fontsize=7,
        )

    fig.suptitle(f"Random test ix (conditioned on truth labels), n={len(idx)}", fontsize=10, y=1.06)
    p = out_dir / "comparison_test_conditioned_camels_ddpm2_ddpm6.png"
    fig.savefig(p, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print("Saved", p)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def pk_pdf_six_sets(
    out_dir: Path,
    name: str,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    model,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
    n_per_set: int,
):
    """Compare P(k) and pixel PDFs of real vs generated maps at six anchor rows.

    Six anchor rows are taken at evenly spaced indices of the split.  For each
    anchor, the ``n_per_set`` real maps whose labels are closest to the anchor
    (the anchor row itself is pushed to the end of the ranking) are compared
    with ``n_per_set`` maps generated for the anchor label.

    Two figures are written to ``out_dir``:

    * ``six_anchor_pk_{name}.png`` -- mean +/- std power spectrum per anchor.
    * ``six_anchor_pdf_mu_sigma_{name}.png`` -- per-bin mean (left column) and
      standard deviation (right column) of the per-map histograms of
      ``14 + 8 * clip(pixel, 0, 1)``, the quantity the x axis labels
      ``log N_HI``.

    Args:
        out_dir: Output directory.
        name: Tag used in the figure titles and file names.
        images_split, labels_split: Real maps and their label rows.
        label_mean, label_std: Label statistics forwarded to ``generate_maps``.
        model: Conditional model forwarded to ``generate_maps``.
        device: Torch device forwarded to ``generate_maps``.
        ddim_steps: Sampler step count forwarded to ``generate_maps``.
        batch_size: Forwarded to ``generate_maps``.
        n_per_set: Number of real neighbours and of generated maps per anchor.
    """
    H, W = int(images_split.shape[-2]), int(images_split.shape[-1])
    ldim = int(labels_split.shape[1])

    idx = np.linspace(0, len(labels_split) - 1, num=6, dtype=int)
    targets = labels_split[idx].copy()

    box = 25.0

    rng_pdf_bins = np.linspace(14.0, 22.0, 101)
    bin_pdf = 0.5 * (rng_pdf_bins[:-1] + rng_pdf_bins[1:])

    def _pdf_rows(batch) -> np.ndarray:
        """Return one density-normalised histogram row per map in ``batch``."""
        rows = []
        # Iterate over the maps that actually exist: the real-neighbour batch
        # is shorter than ``n_per_set`` when the split is small, and indexing
        # it up to ``n_per_set`` raised IndexError.
        for i in range(len(batch)):
            ims = np.clip(batch[i].ravel(), 0.0, 1.0)
            logn = 14.0 + (22.0 - 14.0) * ims
            hst, _ = np.histogram(logn, bins=rng_pdf_bins, density=True)
            rows.append(hst)
        return np.asarray(rows)

    fig_pk, axes_pk = plt.subplots(2, 3, figsize=(14, 9), sharex=True, sharey=True)
    axes_pk = axes_pk.ravel()

    fig_pdf, axes_pdf = plt.subplots(6, 2, figsize=(12, 4.8 * 2), squeeze=False)

    for si, target_l in enumerate(targets):
        dist = np.linalg.norm(labels_split - target_l[None, :], axis=1).astype(np.float64)
        # Rank the anchor itself last so it is not counted as its own neighbour.
        dist[idx[si]] = np.inf
        nn_idx = np.argsort(dist)[:n_per_set]
        real_batch = images_split[nn_idx]
        rep = np.tile(target_l[None, :], (n_per_set, 1))
        gen = generate_maps(model, rep, label_mean, label_std, H, W, device, ddim_steps, batch_size)

        dk_r, mr, sr = ec.calculate_power_spectrum_batch(real_batch, box_size=box)
        _, mg, sg = ec.calculate_power_spectrum_batch(gen, box_size=box)
        # Element 0 of every spectrum array is dropped before plotting.
        x = dk_r[1:]
        axpk = axes_pk[si]
        axpk.plot(x, mr[1:], lw=2, label="CAMELS NN", color="#333")
        axpk.fill_between(x, mr[1:] - sr[1:], mr[1:] + sr[1:], alpha=0.08, color="#333")
        axpk.plot(x, mg[1:], lw=2, label=f"Generated (ldim={ldim})", color="#d95f02")
        axpk.fill_between(x, mg[1:] - sg[1:], mg[1:] + sg[1:], alpha=0.08, color="#d95f02")
        axpk.set_yscale("log")
        axpk.grid(alpha=0.25)
        axpk.set_title(_fmt_title(target_l), fontsize=8)

        # PDF µ/σ
        real_pdf = _pdf_rows(real_batch)
        gen_pdf = _pdf_rows(gen)

        axes_pdf[si, 0].plot(bin_pdf, real_pdf.mean(axis=0), lw=2, label="CAMELS NN", color="#333")
        axes_pdf[si, 0].plot(bin_pdf, gen_pdf.mean(axis=0), lw=2, label="Generated", color="#d95f02")
        axes_pdf[si, 1].plot(bin_pdf, real_pdf.std(axis=0), lw=2, ls="-", label="CAMELS σ", color="#333")
        axes_pdf[si, 1].plot(bin_pdf, gen_pdf.std(axis=0), lw=2, ls="--", label="Gen σ", color="#d95f02")

    axes_pk[0].legend(fontsize=7, loc="lower left")
    fig_pk.suptitle(f"$P(k)$ mean±std — six anchors — {name}", fontsize=10)
    fig_pk.tight_layout()
    p_pk = out_dir / f"six_anchor_pk_{name}.png"
    fig_pk.savefig(p_pk, dpi=160)
    plt.close(fig_pk)

    # The PDF series carry labels; draw the legend once on the top row so the
    # labels are actually visible.
    axes_pdf[0, 0].legend(fontsize=7)
    axes_pdf[0, 1].legend(fontsize=7)
    axes_pdf[-1, 0].set_xlabel(r"$\log N_{\mathrm{HI}}$")
    axes_pdf[-1, 1].set_xlabel(r"$\log N_{\mathrm{HI}}$")
    fig_pdf.suptitle(f"PDF mean & σ — six anchors × {n_per_set} — {name}")
    fig_pdf.tight_layout()
    p_pdf = out_dir / f"six_anchor_pdf_mu_sigma_{name}.png"
    fig_pdf.savefig(p_pdf, dpi=160)
    plt.close(fig_pdf)
    print("Saved", p_pk)
    print("Saved", p_pdf)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def pk_six_triplet_combined(
    out_dir: Path,
    imgs6: np.ndarray,
    lab6: np.ndarray,
    mean6: np.ndarray,
    std6: np.ndarray,
    mean2: np.ndarray,
    std2: np.ndarray,
    model2: torch.nn.Module,
    model6: torch.nn.Module,
    device: torch.device,
    ddim_steps: int,
    batch_size: int,
    n_per_set: int,
) -> None:
    """Overlay CAMELS, DDPM-2 and DDPM-6 P(k) and pixel PDFs at six anchor rows.

    Six anchor rows are taken at evenly spaced indices of ``lab6``.  For each
    anchor the ``n_per_set`` real maps with the closest labels (the anchor row
    itself is ranked last) are compared with ``n_per_set`` DDPM-2 samples
    (conditioned on the first two label columns) and ``n_per_set`` DDPM-6
    samples (conditioned on the full label row).

    Writes ``six_anchor_pk_overlay_camels_ddpm2_ddpm6.png`` and
    ``six_anchor_pdf_overlay_camels_ddpm2_ddpm6.png`` into ``out_dir``.

    Args:
        out_dir: Output directory.
        imgs6, lab6: Real maps and their label rows.
        mean6, std6, mean2, std2: Label statistics forwarded to
            ``generate_maps`` for the respective model.
        model2, model6: The two conditional models.
        device: Torch device forwarded to ``generate_maps``.
        ddim_steps: Sampler step count forwarded to ``generate_maps``.
        batch_size: Forwarded to ``generate_maps``.
        n_per_set: Number of real neighbours / generated maps per anchor.
    """
    H, W = int(imgs6.shape[-2]), int(imgs6.shape[-1])
    idx = np.linspace(0, len(lab6) - 1, num=6, dtype=int)
    targets = lab6[idx].copy()
    box = 25.0

    fig_pk, axes_pk = plt.subplots(2, 3, figsize=(14, 9), sharex=True, sharey=True)
    axes_pk = axes_pk.ravel()
    rng_pdf_bins = np.linspace(14.0, 22.0, 101)
    bin_pdf = 0.5 * (rng_pdf_bins[:-1] + rng_pdf_bins[1:])
    fig_pdf, axes_pdf = plt.subplots(6, 2, figsize=(12, 10.5))

    def _pdf_rows(batch) -> np.ndarray:
        """Return one density-normalised histogram row per map in ``batch``."""
        rows = []
        for i in range(min(n_per_set, len(batch))):
            px = np.clip(batch[i].ravel(), 0.0, 1.0)
            ln = 14.0 + (22.0 - 14.0) * px
            rows.append(np.histogram(ln, bins=rng_pdf_bins, density=True)[0])
        return np.asarray(rows)

    for si, target_l in enumerate(targets):
        # ``dist`` is a fresh array and ``idx`` only holds valid row indices,
        # so the anchor can be masked in place without a guard or a copy.
        dist = np.linalg.norm(lab6 - target_l[None, :], axis=1).astype(np.float64)
        dist[int(idx[si])] = np.inf
        nn_idx = np.argsort(dist)[:n_per_set]
        real_batch = imgs6[nn_idx]

        tg2 = np.tile(target_l[:2][None, :], (n_per_set, 1)).astype(np.float32)
        tg6 = np.tile(target_l[None, :], (n_per_set, 1)).astype(np.float32)

        gen2 = generate_maps(model2, tg2, mean2, std2, H, W, device, ddim_steps, batch_size)
        gen6 = generate_maps(model6, tg6, mean6, std6, H, W, device, ddim_steps, batch_size)

        dk_r, mr, sr = ec.calculate_power_spectrum_batch(real_batch, box_size=box)
        _, m2, s2 = ec.calculate_power_spectrum_batch(gen2, box_size=box)
        _, mG, sG = ec.calculate_power_spectrum_batch(gen6, box_size=box)
        # Element 0 of every spectrum array is dropped before plotting.
        x = dk_r[1:]

        axpk = axes_pk[si]
        axpk.plot(x, mr[1:], lw=2, label="CAMELS NN", color="#222")
        axpk.fill_between(x, mr[1:] - sr[1:], mr[1:] + sr[1:], alpha=0.06, color="#222")
        axpk.plot(x, m2[1:], lw=2, label="DDPM-2 μ", color="#ff7f0e")
        axpk.fill_between(x, m2[1:] - s2[1:], m2[1:] + s2[1:], alpha=0.06, color="#ff7f0e")
        axpk.plot(x, mG[1:], lw=2, label="DDPM-6 μ", color="#1f77b4")
        axpk.fill_between(x, mG[1:] - sG[1:], mG[1:] + sG[1:], alpha=0.06, color="#1f77b4")
        axpk.set_yscale("log")
        axpk.grid(alpha=0.25)
        axpk.set_title(_fmt_title(target_l), fontsize=8)
        if si == 0:
            axpk.legend(fontsize=6.2, loc="lower left")

        cam_pdf = _pdf_rows(real_batch)
        d2_pdf = _pdf_rows(gen2)
        d6_pdf = _pdf_rows(gen6)

        axes_pdf[si, 0].plot(bin_pdf, cam_pdf.mean(axis=0), lw=2, color="#222", label="CAMELS μ")
        axes_pdf[si, 0].plot(bin_pdf, d2_pdf.mean(axis=0), lw=2, color="#ff7f0e", label="DDPM-2 μ")
        axes_pdf[si, 0].plot(bin_pdf, d6_pdf.mean(axis=0), lw=2, color="#1f77b4", label="DDPM-6 μ")
        axes_pdf[si, 1].plot(bin_pdf, cam_pdf.std(axis=0), lw=2, color="#222", label="CAMELS σ")
        axes_pdf[si, 1].plot(bin_pdf, d2_pdf.std(axis=0), lw=2, ls="--", color="#ff7f0e", label="DDPM-2 σ")
        axes_pdf[si, 1].plot(bin_pdf, d6_pdf.std(axis=0), lw=2, ls="--", color="#1f77b4", label="DDPM-6 σ")

    fig_pk.suptitle("$P(k)$ CAMELS vs DDPM-2 vs DDPM-6 — six Ωm–σ8 anchors", fontsize=11)
    fig_pk.tight_layout()
    p_pk = out_dir / "six_anchor_pk_overlay_camels_ddpm2_ddpm6.png"
    fig_pk.savefig(p_pk, dpi=160)
    plt.close(fig_pk)

    # The PDF curves were labelled but no legend was ever drawn, which left the
    # three series identifiable by colour only; show it once on the top row.
    axes_pdf[0, 0].legend(fontsize=7)
    axes_pdf[0, 1].legend(fontsize=7)
    axes_pdf[-1, 0].set_xlabel(r"$\log N_{\mathrm{HI}}$")
    axes_pdf[-1, 1].set_xlabel(r"$\log N_{\mathrm{HI}}$")
    fig_pdf.suptitle(r"PDF mean ($\mu$) and std ($\sigma$) overlays", fontsize=10)
    fig_pdf.tight_layout()
    p_pdf = out_dir / "six_anchor_pdf_overlay_camels_ddpm2_ddpm6.png"
    fig_pdf.savefig(p_pdf, dpi=160)
    plt.close(fig_pdf)
    print("Saved", p_pk)
    print("Saved", p_pdf)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def mlp_recovery_dual(
    out_dir: Path,
    data_train: Path,
    imgs_te: np.ndarray,
    lab_te: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    model_ddpm: torch.nn.Module,
    tag: str,
    device: torch.device,
    ddim_steps: int,
    seed: int,
) -> None:
    """Check whether an MLP trained on real P(k) recovers labels from generated P(k).

    An ``MLPRegressor`` is fitted on power spectra of (at most 2000 randomly
    chosen) training maps.  It is then applied to the spectra of the first (up
    to 40) evaluation maps and to the spectra of maps generated for those same
    labels.  The scatter plot ``mlp_pk_parameter_recovery_{tag}.png`` shows
    truth vs prediction per label column: real spectra on the top row,
    generated spectra on the bottom row, with the per-column RMSE in each title.

    Args:
        out_dir: Output directory.
        data_train: Dataset location passed to ``ec.load_split(..., "train")``.
        imgs_te, lab_te: Evaluation maps and their label rows.
        mean, std: Label statistics forwarded to ``generate_maps``.
        model_ddpm: Conditional model forwarded to ``generate_maps``.
        tag: Tag used in the bottom-row axis label and the file name.
        device: Torch device forwarded to ``generate_maps``.
        ddim_steps: Sampler step count forwarded to ``generate_maps``.
        seed: Seed for the training subsample and the MLP.
    """
    from sklearn.metrics import mean_squared_error
    from sklearn.neural_network import MLPRegressor

    ldim = lab_te.shape[1]
    Npix = imgs_te.shape[-1]
    dl = 25.0 / Npix

    def pk_row(im):
        # Spectrum of a single map with element 0 dropped, as float32 features.
        _dk, pk = ec.PowerSpectrum(np.asarray(im, dtype=np.float64), N=Npix, dl=dl)
        return pk[1:].astype(np.float32)

    # --- training set: real spectra -> labels --------------------------------
    train_imgs, train_labs = ec.load_split(data_train, "train")
    if len(train_imgs) > 2000:
        picker = np.random.default_rng(seed)
        keep = picker.choice(len(train_imgs), 2000, replace=False)
        train_imgs, train_labs = train_imgs[keep], train_labs[keep]
    X_train = np.stack([pk_row(im) for im in train_imgs], axis=0)
    y_train = train_labs.astype(np.float32)

    regressor = MLPRegressor(
        hidden_layer_sizes=(64, 64),
        alpha=1e-4,
        random_state=seed,
        max_iter=250,
        early_stopping=True,
        validation_fraction=0.1,
    )
    regressor.fit(X_train, y_train)

    # --- evaluation on real spectra -------------------------------------------
    n_ev = min(40, len(imgs_te))
    y_true = lab_te[:n_ev]
    X_real = np.stack([pk_row(imgs_te[i]) for i in range(n_ev)], axis=0)
    preds_real = regressor.predict(X_real)

    # --- evaluation on generated spectra, sampled in chunks of eight labels ----
    H, W = int(imgs_te.shape[-2]), int(imgs_te.shape[-1])
    gen_rows = []
    for start in range(0, n_ev, 8):
        chunk_len = min(8, n_ev - start)
        chunk_labels = y_true[start : start + chunk_len]
        maps = generate_maps(model_ddpm, chunk_labels, mean, std, H, W, device, ddim_steps, chunk_len)
        for j in range(len(maps)):
            gen_rows.append(pk_row(maps[j]))
    X_gen = np.stack(gen_rows, axis=0)
    preds_gen = regressor.predict(X_gen)

    rmse_real = np.sqrt(mean_squared_error(y_true, preds_real, multioutput="raw_values"))
    rmse_gen = np.sqrt(mean_squared_error(y_true, preds_gen, multioutput="raw_values"))

    # ``squeeze=False`` already guarantees a (2, ldim) axes array.
    fig, axes = plt.subplots(2, ldim, figsize=(max(9.0, 2.8 * max(ldim, 2)), 4.9), squeeze=False)
    panel_rows = (
        (0, preds_real, rmse_real, "CAMELS P(k) predictions"),
        (1, preds_gen, rmse_gen, f"{tag}: generated P(k)"),
    )
    for k in range(ldim):
        truth_k = y_true[:, k]
        lo = float(truth_k.min())
        hi = float(truth_k.max())
        pad = 0.03 * (hi - lo + 1e-12)
        diagonal = [lo - pad, hi + pad]
        for row, preds, rmv, ylab in panel_rows:
            ax = axes[row, k]
            ax.scatter(truth_k, preds[:, k], s=14, alpha=0.72, edgecolors="none", c="#333")
            ax.plot(diagonal, diagonal, color="crimson", lw=1.0)
            ax.grid(True, alpha=0.28)
            ax.set_title(f"dim {k} RMSE={float(rmv[k]):.4f}", fontsize=8)
            if k == 0:
                ax.set_ylabel(ylab, fontsize=8)

    fig.suptitle(
        "MLP: train on CAMELS train P(k), test on CAMELS vs DDPM-drawn spectra",
        fontsize=10,
        y=1.02,
    )
    fig.tight_layout()
    p = out_dir / f"mlp_pk_parameter_recovery_{tag}.png"
    fig.savefig(p, dpi=165, bbox_inches="tight")
    plt.close(fig)
    print("Saved", p)
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def posterior_one_index(
    out_dir: Path,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    model,
    cfg: Dict,
    device,
    ix: int,
    tag: str,
    ddim_steps: int,
    grid: int,
    batch_sz: int,
):
    """Plot a surrogate posterior over the first two label columns for one map.

    A ``grid`` x ``grid`` lattice spanning the (2 % padded) range of label
    columns 0 and 1 is built with ``build_cosmo_grid``; every other label
    column is held at the value of row ``ix``.  One map is sampled per lattice
    point, its log power spectrum is compared with that of the observed map,
    and each point is weighted by ``exp(-MSE / (2 * 0.25**2))``.  The weights
    are normalised to sum to one and drawn as filled contours together with
    the true point and the weighted mean, in
    ``posterior_surrogate_test_ix_{ix}_{tag}.png``.

    Args:
        out_dir: Output directory.
        images_split, labels_split: Maps and their label rows.
        lab_mean, lab_std: Label statistics forwarded to ``em.sample_batch``.
        model: Conditional model forwarded to ``em.sample_batch``.
        cfg: Model configuration; only ``normalize_labels`` is read
            (default ``True``).
        device: Torch device forwarded to ``em.sample_batch``.
        ix: Row of the split used as the observation.
        tag: Tag used in the output file name.
        ddim_steps: Sampler step count forwarded to ``em.sample_batch``.
        grid: Lattice points per axis.
        batch_sz: Number of lattice points sampled per call.
    """
    normalize = bool(cfg.get("normalize_labels", True))
    lab_dim = labels_split.shape[1]
    H, W = int(images_split.shape[-2]), int(images_split.shape[-1])

    obs = images_split[ix]
    label_anchor_full = labels_split[ix].astype(np.float32)

    # Padded bounds of the two columns that are scanned.
    first_col, second_col = labels_split[:, 0], labels_split[:, 1]
    lo0, hi0 = float(first_col.min()), float(first_col.max())
    lo1, hi1 = float(second_col.min()), float(second_col.max())
    pad0 = 0.02 * (hi0 - lo0 + 1e-12)
    pad1 = 0.02 * (hi1 - lo1 + 1e-12)
    om_ax, s8_ax, _, _, grid2 = build_cosmo_grid(
        grid, lo0 - pad0, hi0 + pad0, lo1 - pad1, hi1 + pad1
    )
    ngrid = grid2.shape[0]

    # Observed spectrum; only bins with dk > 0 enter the comparison.
    npix = int(obs.shape[-1])
    dk, _ = ec.PowerSpectrum(em.images01_to_log_nhi(obs), N=npix, dl=25.0 / npix)
    valid = dk > 0
    log_pd = log_pk_observed(obs, 25.0, dk)

    # NOTE(review): the reshape of the weights below assumes the rows of
    # ``grid2`` are ordered like an "ij" meshgrid of (om_ax, s8_ax) -- confirm
    # against ``build_cosmo_grid``.
    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")

    # Conditioning rows: anchor labels everywhere, lattice values in columns 0/1.
    full = np.tile(label_anchor_full[np.newaxis, :], (ngrid, 1))
    full[:, 0] = grid2[:, 0].astype(np.float32)
    full[:, 1] = grid2[:, 1].astype(np.float32)

    denom = 2.0 * 0.25**2
    chunk_scores = []
    for start in range(0, ngrid, batch_sz):
        sampled = em.sample_batch(
            model,
            full[start : start + batch_sz],
            lab_mean,
            lab_std,
            normalize,
            H,
            W,
            device,
            ddim_steps,
            False,
        )
        _, pkc = em.per_map_power_spectra_log(sampled, 25.0)
        log_pg = np.log(pkc[:, valid] + 1e-30)
        mse = np.mean((log_pd[np.newaxis, :] - log_pg) ** 2, axis=1)
        chunk_scores.append(-mse / denom)

    sc = np.concatenate(chunk_scores)
    sc -= sc.max()  # shift so the largest score is 0 before exponentiating
    Wmap = np.exp(sc).reshape(grid, grid)
    Wmap /= Wmap.sum()

    tom, ts8 = float(label_anchor_full[0]), float(label_anchor_full[1])
    mom = float((Wmap * OM).sum())
    ms8 = float((Wmap * S8).sum())

    fig, ax = plt.subplots(figsize=(5.2, 4.6))
    cf = ax.contourf(OM, S8, Wmap, levels=12, cmap="Blues")
    plt.colorbar(cf, ax=ax, fraction=0.046, pad=0.04)
    ax.scatter(tom, ts8, s=55, c="r", marker="x", zorder=6, label="true")
    ax.scatter(mom, ms8, s=60, c="k", marker="+", zorder=6, label="post. mean")
    ax.set_xlabel(r"$\Omega_m$")
    ax.set_ylabel(r"$\sigma_8$")
    ax.legend(fontsize=8)
    ax.set_title(f"Surrogate posterior (test ix={ix}, ldim={lab_dim})", fontsize=10)
    p = out_dir / f"posterior_surrogate_test_ix_{ix}_{tag}.png"
    fig.savefig(p, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print("Saved", p)
|
| 629 |
+
|
| 630 |
+
|
| 631 |
+
def main(argv: Sequence[str] | None = None) -> None:
    """Run the DDPM-2 vs DDPM-6 comparison suite and write every figure to ``--output-dir``.

    Stages, in order: training-curve overlay, random-LHS and test-conditioned
    triplet grids, six-anchor P(k)/PDF plots (combined and per model),
    optional LHS R² figures, MLP parameter recovery and surrogate posteriors.

    Every stage after the training overlay is best effort: a failing step is
    reported with its traceback and the remaining steps still run, including
    steps of the same stage.  The names of failed steps are summarised at the
    end.

    Args:
        argv: Command-line arguments; ``None`` lets ``argparse`` read
            ``sys.argv``.
    """
    import traceback  # only needed by this entry point

    p = argparse.ArgumentParser(description="DDPM-2 vs DDPM-6 comparison suite.")
    p.add_argument(
        "--output-dir",
        type=Path,
        default=MODELS_ROOT / "ddpm_comparison_out",
    )
    p.add_argument("--data-2param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"))
    p.add_argument("--data-6param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"))
    p.add_argument(
        "--bundle-2param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
    )
    p.add_argument(
        "--bundle-6param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
    )
    p.add_argument("--posterior-index", type=int, default=56)
    p.add_argument("--lhs-n", type=int, default=50)
    p.add_argument("--six-n-per-anchor", type=int, default=15)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument("--slurm-6param", type=Path, default=DEFAULT_SLURM_6)
    p.add_argument(
        "--slurm-2param",
        type=Path,
        default=DEFAULT_DDPM2_TRAINING,
        help="DDPM-2 train/val series: Slurm .out (parsed) or bundled ddpm_2param_training_loss.json.",
    )
    p.add_argument("--skip-lhs-r2", action="store_true", help="LHS R² plots are expensive; skip if set.")
    p.add_argument("--n-random-triplets", type=int, default=4)
    args = p.parse_args(list(argv) if argv is not None else None)

    out_dir = Path(args.output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
    args2 = args.bundle_2param / "args.json"
    ck6 = args.bundle_6param / "best_model.pt"
    args6 = args.bundle_6param / "args.json"

    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)

    imgs6, lab6 = ec.load_split(data6, "test")
    mean6, std6 = ec.load_label_stats(data6)
    mean2, std2 = ec.load_label_stats(data2)

    plot_training_overlay(out_dir, args.slurm_6param, args.slurm_2param)

    imgs2, lab2 = ec.load_split(data2, "test")

    print(">>> Loading DDPM-2...")
    model2, cfg2 = load_model(args2, ck2, device)
    print(">>> Loading DDPM-6...")
    model6, cfg6 = load_model(args6, ck6, device)

    failed = []

    def attempt(failure_prefix, step, *step_args, **step_kwargs):
        """Run one step; on failure report it with a traceback and carry on."""
        try:
            step(*step_args, **step_kwargs)
        except Exception as exc:  # deliberate best-effort boundary of the suite
            print(failure_prefix, exc)
            traceback.print_exc()
            failed.append(getattr(step, "__name__", repr(step)))

    print(">>> Random LHS + conditioned triplets...")
    triplet_kwargs = dict(
        device=device,
        ddim_steps=args.ddim_steps,
        seed=args.seed,
        n_pairs=args.n_random_triplets,
        batch_size=args.batch_size,
    )
    for triplet_step in (run_random_theta_triplets, run_conditioned_test_triplets):
        attempt(
            "Triplet grids failed:",
            triplet_step,
            out_dir,
            imgs6,
            lab6,
            mean6,
            std6,
            mean2,
            std2,
            model2,
            model6,
            **triplet_kwargs,
        )

    print(">>> Six-anchor overlays (combined + per-model)...")
    attempt(
        "P(k)/PDF six-anchor plots failed:",
        pk_six_triplet_combined,
        out_dir,
        imgs6,
        lab6,
        mean6,
        std6,
        mean2,
        std2,
        model2,
        model6,
        device=device,
        ddim_steps=args.ddim_steps,
        batch_size=args.batch_size,
        n_per_set=args.six_n_per_anchor,
    )
    attempt(
        "P(k)/PDF six-anchor plots failed:",
        pk_pdf_six_sets,
        out_dir,
        "ddpm6_only",
        imgs6,
        lab6,
        mean6,
        std6,
        model6,
        device,
        args.ddim_steps,
        args.batch_size,
        args.six_n_per_anchor,
    )
    attempt(
        "P(k)/PDF six-anchor plots failed:",
        pk_pdf_six_sets,
        out_dir,
        "ddpm2_only",
        imgs2,
        lab2,
        mean2,
        std2,
        model2,
        device,
        args.ddim_steps,
        args.batch_size,
        args.six_n_per_anchor,
    )

    if not args.skip_lhs_r2:
        lhs_draws = 15  # value passed to compute_lhs_r2 after the LHS point count
        # Banner and file tags follow --lhs-n instead of a hard-coded 50; with
        # the default value the output names are unchanged.
        print(f">>> LHS R² (LHS-{args.lhs_n} × {lhs_draws} DDIM each — long)...")

        def lhs_r2_step(label, imgs, labs, mn, sd, mdl):
            """Compute, plot and save the LHS R² data for one model."""
            lhs_pts, r2_mu, r2_sig, lo_b, hi_b = compute_lhs_r2(
                mdl,
                imgs,
                labs,
                mn,
                sd,
                device,
                args.lhs_n,
                lhs_draws,
                args.batch_size,
                25.0,
                args.ddim_steps,
                args.seed,
            )
            outp = out_dir / f"r2_cosmology_lhs{args.lhs_n}_{label}.png"
            plot_r2_cosmology_figure(lhs_pts, r2_mu, r2_sig, lo_b, hi_b, outp, dpi=160)
            print("Saved", outp)
            np.savez(
                out_dir / f"r2_lhs_data_{label}.npz",
                lhs_pts=lhs_pts,
                r2_mu_arr=r2_mu,
                r2_sig_arr=r2_sig,
                lo_b=lo_b,
                hi_b=hi_b,
            )

        # Called explicitly rather than from a loop so that no loop variable
        # keeps a model alive after the ``del`` below.
        attempt("LHS R² skipped:", lhs_r2_step, f"ddpm2_lhs{args.lhs_n}", imgs2, lab2, mean2, std2, model2)
        attempt("LHS R² skipped:", lhs_r2_step, f"ddpm6_lhs{args.lhs_n}", imgs6, lab6, mean6, std6, model6)
    else:
        print("(Skipping LHS R².)")

    print(">>> MLP P(k) parameter recovery...")
    attempt(
        "MLP recovery skipped:",
        mlp_recovery_dual,
        out_dir, data2, imgs2[:40], lab2[:40], mean2, std2, model2, "ddpm2param", device, args.ddim_steps, args.seed,
    )
    attempt(
        "MLP recovery skipped:",
        mlp_recovery_dual,
        out_dir, data6, imgs6[:40], lab6[:40], mean6, std6, model6, "ddpm6param", device, args.ddim_steps, args.seed,
    )

    print(f">>> Surrogate posteriors (test index {args.posterior_index})...")
    ix = int(args.posterior_index)
    attempt(
        "Posterior panels skipped:",
        posterior_one_index,
        out_dir, imgs6, lab6, mean6, std6, model6, cfg6, device, ix, "ddpm6", args.ddim_steps, 14, args.batch_size,
    )
    attempt(
        "Posterior panels skipped:",
        posterior_one_index,
        out_dir, imgs2, lab2, mean2, std2, model2, cfg2, device, ix, "ddpm2", args.ddim_steps, 14, args.batch_size,
    )

    del model2, model6
    free_torch()
    if failed:
        print(f"{len(failed)} step(s) failed: {', '.join(failed)}")
    print(f"Done. Outputs in {out_dir}")


if __name__ == "__main__":
    main()
|
cross_model/scripts/compare_ddpm_training_curves.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Parse DDPM Slurm stdout or bundled JSON for Train/Val loss series."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import re
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Tuple
|
| 10 |
+
|
| 11 |
+
# One loss value: an ordinary decimal / scientific-notation number, or a
# non-finite value as printed by a diverged run (nan, inf, infinity; any case).
# The strict form keeps stray punctuation such as a trailing "." out of the
# capture, so ``float()`` cannot fail on a matched group.
_NUM = r"[-+]?(?:(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|(?i:nan|inf(?:inity)?))"

_ROW = re.compile(
    rf"Epoch\s+(?P<ep>\d+)/\d+\s+\|\s+Train:\s+(?P<tr>{_NUM})\s+\|\s+Val:\s+(?P<va>{_NUM})",
)


def parse_slurm_training_log(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
    """Return ``(epochs, train_losses, val_losses)`` parsed from Slurm ``*.out`` stdout.

    Every occurrence of ``Epoch <n>/<total> | Train: <x> | Val: <y>`` in the
    file contributes one entry, in file order; all other text is ignored.
    Undecodable bytes are replaced rather than raising.  Rows whose loss is
    ``nan`` or ``inf`` are kept (as the corresponding float) so that a
    diverged epoch is not silently dropped from the series.

    Args:
        path: Location of the log file.

    Returns:
        Three equally long lists: epoch numbers, training losses and
        validation losses.  All three are empty when no row matches.

    Raises:
        OSError: If the file cannot be read.
    """
    text = Path(path).read_text(encoding="utf-8", errors="replace")
    epochs: list[int] = []
    trains: list[float] = []
    vals: list[float] = []
    for m in _ROW.finditer(text):
        epochs.append(int(m.group("ep")))
        trains.append(float(m.group("tr")))
        vals.append(float(m.group("va")))
    return epochs, trains, vals
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_training_loss_json(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
    """Read ``(epochs, train_losses, val_losses)`` from a JSON export.

    The file must hold an object with the keys ``epochs``, ``train`` and
    ``val``; entries are coerced to ``int`` (epochs) and ``float`` (losses).

    Raises:
        KeyError: If one of the three keys is missing.
        ValueError: If the three series differ in length, or an entry cannot
            be converted.
    """
    p = Path(path)
    payload = json.loads(p.read_text(encoding="utf-8"))

    epochs = list(map(int, payload["epochs"]))
    trains = list(map(float, payload["train"]))
    vals = list(map(float, payload["val"]))

    # All three series must describe the same epochs.
    if len({len(epochs), len(trains), len(vals)}) != 1:
        raise ValueError(f"{p}: mismatched lengths in epochs/train/val")
    return epochs, trains, vals
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load_train_val_series(path: str | Path) -> Tuple[list[int], list[float], list[float]]:
    """Load ``(epochs, train_losses, val_losses)`` from a JSON export or a Slurm log.

    A path whose suffix is ``.json`` (case-insensitive) is read with
    ``load_training_loss_json``; anything else is treated as Slurm stdout and
    handed to ``parse_slurm_training_log``.
    """
    source = Path(path)
    is_json = source.suffix.lower() == ".json"
    reader = load_training_loss_json if is_json else parse_slurm_training_log
    return reader(source)
|
cross_model/scripts/ddpm_figure6_integration.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Figure 6 style (arXiv:2409.09101) helpers for DDPM surrogate posteriors — use with ddpm_posterior_six_anchors / run_ddpm_figure6_suite.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import numpy as np
|
| 11 |
+
from matplotlib import gridspec
|
| 12 |
+
|
| 13 |
+
from figure6_2409_style import (
|
| 14 |
+
create_comparison_marginal_vs_profile,
|
| 15 |
+
create_figure6_style_plot,
|
| 16 |
+
)
|
| 17 |
+
from sigma_contour_utils import compute_sigma_levels
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def integrate_figure6_with_ddpm2(
    Wmap: np.ndarray,
    om_grid: np.ndarray,
    s8_grid: np.ndarray,
    true_om: float,
    true_s8: float,
    test_index: int,
    output_dir: Path,
    model_name: str = "DDPM-2",
) -> None:
    """Write the three Figure-6-style PNGs for one test map.

    Two figures come from ``create_figure6_style_plot`` (once with
    ``show_profile=False``, once with ``show_profile=True``) and a third from
    ``create_comparison_marginal_vs_profile``.  All three are saved under
    ``output_dir`` (created if missing) and closed afterwards; each saved path
    is printed.  ``model_name`` is used verbatim in titles and, lower-cased
    with spaces replaced by dashes, in file names.
    """
    out_root = Path(output_dir)
    out_root.mkdir(parents=True, exist_ok=True)
    slug = model_name.replace(" ", "-").lower()

    # The marginal and profile figures differ only in the flag, the title
    # suffix and the file-name suffix.
    for variant, use_profile in (("Marginal", False), ("Profile", True)):
        figure = create_figure6_style_plot(
            Wmap,
            om_grid,
            s8_grid,
            true_param1=true_om,
            true_param2=true_s8,
            param1_label=r"$\Omega_m$",
            param2_label=r"$\sigma_8$",
            title=f"{model_name} — Test ix={test_index} ({variant})",
            show_profile=use_profile,
            figsize=(10, 10),
        )
        target = out_root / f"fig6_style_{slug}_ix{test_index}_{variant.lower()}.png"
        figure.savefig(target, dpi=200, bbox_inches="tight")
        plt.close(figure)
        print(f"Saved: {target}")

    comparison = create_comparison_marginal_vs_profile(
        Wmap,
        om_grid,
        s8_grid,
        true_param1=true_om,
        true_param2=true_s8,
        title=f"{model_name} marginal vs profile — ix={test_index}",
        figsize=(11, 4.2),
    )
    comparison_target = out_root / f"fig6_marg_vs_prof_{slug}_ix{test_index}.png"
    comparison.savefig(comparison_target, dpi=185, bbox_inches="tight")
    plt.close(comparison)
    print(f"Saved: {comparison_target}")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def integrate_figure6_with_multi_anchor(
    posteriors_list: list[np.ndarray],
    om_grid: np.ndarray,
    s8_grid: np.ndarray,
    true_values_list: list[tuple[float, float]],
    test_indices: list[int],
    output_dir: Path,
    model_name: str = "DDPM-2",
) -> None:
    """Plot up to six posteriors on a 2×3 grid of Figure-6-style panels.

    Each grid cell holds a filled 2D posterior with credible-level contours,
    a top panel with the marginal over ``om_grid`` and a right panel with the
    marginal over ``s8_grid``.  Posteriors are indexed ``[i_om, i_s8]`` (the
    mesh is built with ``indexing="ij"``).  The figure is saved as
    ``fig6_style_<model>_all_anchors.png`` in ``output_dir``, which is created
    if it does not exist.

    Args:
        posteriors_list: One 2D array per anchor; each is normalized to unit sum.
        om_grid: Axis values for the first posterior dimension.
        s8_grid: Axis values for the second posterior dimension.
        true_values_list: ``(true_om, true_s8)`` per anchor.
        test_indices: Test-set index per anchor (used in panel titles).
        output_dir: Destination directory.
        model_name: Used in the figure title and the output file name.

    Raises:
        ValueError: If the three per-anchor lists differ in length, if more
            than six anchors are given (the layout is fixed at 2×3), or if a
            posterior has non-positive / non-finite total mass.
    """
    n_anchors = len(posteriors_list)
    if len(true_values_list) != n_anchors or len(test_indices) != n_anchors:
        raise ValueError(
            "posteriors_list, true_values_list and test_indices must have the same "
            f"length, got {n_anchors}, {len(true_values_list)}, {len(test_indices)}"
        )
    if n_anchors > 6:
        raise ValueError(f"The 2x3 layout holds at most 6 anchors, got {n_anchors}")

    # Normalize up front so a bad posterior fails before any figure is created.
    normalized = []
    for test_ix, posterior in zip(test_indices, posteriors_list):
        posterior = np.asarray(posterior)
        total = float(posterior.sum())
        if not np.isfinite(total) or total <= 0.0:
            raise ValueError(
                f"Posterior for test ix={test_ix} has total mass {total}; cannot normalize."
            )
        normalized.append(posterior / posterior.sum())

    output_dir = Path(output_dir)
    # Consistent with integrate_figure6_with_ddpm2: savefig needs the directory.
    output_dir.mkdir(parents=True, exist_ok=True)
    nm = model_name.replace(" ", "-").lower()
    fig = plt.figure(figsize=(20, 14))
    gs_o = gridspec.GridSpec(2, 3, figure=fig, hspace=0.33, wspace=0.32)
    # The mesh depends only on the axes, so build it once for all anchors.
    P1, P2 = np.meshgrid(om_grid, s8_grid, indexing="ij")
    contour_colors = ["darkblue", "steelblue"]
    contour_widths = [2.0, 1.5]

    for idx, (posterior_norm, true_vals, test_ix) in enumerate(
        zip(normalized, true_values_list, test_indices)
    ):
        true_om, true_s8 = true_vals
        row, col = divmod(idx, 3)

        sigma_levels = compute_sigma_levels(posterior_norm, [0.683, 0.954])
        # contour() requires strictly increasing levels; sorting, de-duplicating
        # and dropping non-finite values keeps sharply peaked posteriors
        # (where both levels coincide) from raising.
        levels = np.unique(np.asarray(sigma_levels, dtype=float))
        levels = levels[np.isfinite(levels)]

        gs_sub = gridspec.GridSpecFromSubplotSpec(
            2,
            2,
            subplot_spec=gs_o[row, col],
            width_ratios=[4, 1],
            height_ratios=[1, 4],
            hspace=0.06,
            wspace=0.06,
        )
        ax_main = fig.add_subplot(gs_sub[1, 0])
        ax_top = fig.add_subplot(gs_sub[0, 0], sharex=ax_main)

        ax_main.contourf(P1, P2, posterior_norm, levels=20, cmap="Blues", alpha=0.85)
        if levels.size:
            ax_main.contour(
                P1,
                P2,
                posterior_norm,
                levels=levels,
                colors=contour_colors[: levels.size],
                linewidths=contour_widths[: levels.size],
            )

        ax_main.scatter(true_om, true_s8, s=100, c="red", marker="x", linewidths=2.5, zorder=10)
        ax_main.set_xlim(om_grid[0], om_grid[-1])
        ax_main.set_ylim(s8_grid[0], s8_grid[-1])
        # Axis labels only on the bottom row / left column to reduce clutter.
        ax_main.set_xlabel(r"$\Omega_m$" if row == 1 else "", fontsize=11)
        ax_main.set_ylabel(r"$\sigma_8$" if col == 0 else "", fontsize=11)
        ax_main.set_title(f"Test ix={test_ix}", fontsize=11, pad=5)
        ax_main.grid(True, alpha=0.2)

        # Marginal over the second axis -> distribution along om_grid.
        marginal_om = posterior_norm.sum(axis=1)
        marginal_om /= marginal_om.sum() + 1e-30
        ax_top.fill_between(
            om_grid,
            0.0,
            marginal_om,
            alpha=0.6,
            color="steelblue",
            edgecolor="steelblue",
        )
        ax_top.axvline(true_om, color="red", linestyle="--", linewidth=2)
        ax_top.set_xlim(om_grid[0], om_grid[-1])
        ax_top.set_ylim(0, marginal_om.max() * 1.1)
        ax_top.tick_params(labelbottom=False, labelsize=9)
        ax_top.set_ylabel("$P(\\Omega_m)$", fontsize=9)
        ax_top.grid(True, alpha=0.2)

        # Marginal over the first axis -> distribution along s8_grid.
        ax_side = fig.add_subplot(gs_sub[1, 1], sharey=ax_main)
        marginal_s8 = posterior_norm.sum(axis=0)
        marginal_s8 /= marginal_s8.sum() + 1e-30
        ax_side.fill_betweenx(
            s8_grid, 0.0, marginal_s8, alpha=0.6, color="steelblue", edgecolor="steelblue"
        )
        ax_side.axhline(true_s8, color="red", linestyle="--", linewidth=2)
        ax_side.set_ylim(s8_grid[0], s8_grid[-1])
        ax_side.set_xlim(0, marginal_s8.max() * 1.15)
        ax_side.tick_params(labelleft=False)

    fig.suptitle(
        f"{model_name} — Figure 6 Style: Six Test Anchors",
        fontsize=15,
        y=0.995,
        fontweight="bold",
    )

    save_path = output_dir / f"fig6_style_{nm}_all_anchors.png"
    fig.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved multi-anchor grid: {save_path}")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def integrate_figure6_model_comparison(
    posteriors_dict: dict[str, np.ndarray],
    om_grid: np.ndarray,
    s8_grid: np.ndarray,
    true_om: float,
    true_s8: float,
    test_index: int,
    output_dir: Path,
) -> None:
    """Plot one Figure-6-style panel per model, side by side, for one test map.

    Each column shows the filled 2D posterior with credible-level contours,
    the marginal over ``om_grid`` on top and the marginal over ``s8_grid`` on
    the right.  Posteriors are indexed ``[i_om, i_s8]`` (mesh built with
    ``indexing="ij"``).  The figure is written to
    ``fig6_style_model_comparison_ix<test_index>.png`` inside ``output_dir``,
    which is created if missing.

    Args:
        posteriors_dict: Model name -> 2D posterior; each is normalized to unit sum.
        om_grid: Axis values for the first posterior dimension.
        s8_grid: Axis values for the second posterior dimension.
        true_om: True value marked along the ``om_grid`` axis.
        true_s8: True value marked along the ``s8_grid`` axis.
        test_index: Test-set index, used in the title and file name.
        output_dir: Destination directory.

    Raises:
        ValueError: If ``posteriors_dict`` is empty or a posterior has
            non-positive / non-finite total mass.
    """
    n_models = len(posteriors_dict)
    if n_models == 0:
        raise ValueError("posteriors_dict is empty; nothing to compare.")

    # Normalize up front so a bad posterior fails before any figure is created.
    normalized = {}
    for model_name, posterior in posteriors_dict.items():
        posterior = np.asarray(posterior)
        total = float(posterior.sum())
        if not np.isfinite(total) or total <= 0.0:
            raise ValueError(
                f"Posterior for model {model_name!r} has total mass {total}; cannot normalize."
            )
        normalized[model_name] = posterior / posterior.sum()

    output_dir = Path(output_dir)
    # Consistent with integrate_figure6_with_ddpm2: savefig needs the directory.
    output_dir.mkdir(parents=True, exist_ok=True)
    # Figure width grows with the model count but is capped at four columns' worth.
    fig = plt.figure(figsize=(8 * max(1, min(n_models, 4)), 8))
    gs_outer = gridspec.GridSpec(1, n_models, figure=fig, wspace=0.32)
    # The mesh depends only on the axes, so build it once for all models.
    P1, P2 = np.meshgrid(om_grid, s8_grid, indexing="ij")
    contour_colors = ["darkblue", "steelblue"]
    contour_widths = [2.5, 2.0]

    for idx, (model_name, posterior_norm) in enumerate(normalized.items()):
        gs_sub = gridspec.GridSpecFromSubplotSpec(
            2,
            2,
            subplot_spec=gs_outer[0, idx],
            width_ratios=[4, 1],
            height_ratios=[1, 4],
            hspace=0.06,
            wspace=0.06,
        )
        sigma_levels = compute_sigma_levels(posterior_norm, [0.683, 0.954])
        # contour() requires strictly increasing levels; sort, de-duplicate
        # and drop non-finite values so degenerate posteriors do not raise.
        levels = np.unique(np.asarray(sigma_levels, dtype=float))
        levels = levels[np.isfinite(levels)]

        ax_main = fig.add_subplot(gs_sub[1, 0])
        ax_top = fig.add_subplot(gs_sub[0, 0], sharex=ax_main)

        ax_main.contourf(P1, P2, posterior_norm, levels=20, cmap="Blues", alpha=0.85)
        if levels.size:
            ax_main.contour(
                P1,
                P2,
                posterior_norm,
                levels=levels,
                colors=contour_colors[: levels.size],
                linewidths=contour_widths[: levels.size],
            )

        ax_main.scatter(true_om, true_s8, s=120, c="red", marker="x", linewidths=3, zorder=10)
        ax_main.set_xlabel(r"$\Omega_m$", fontsize=13)
        ax_main.set_ylabel(r"$\sigma_8$", fontsize=13)
        ax_main.set_title(model_name, fontsize=13, pad=10, fontweight="bold")
        ax_main.grid(True, alpha=0.3)
        ax_main.set_xlim(om_grid[0], om_grid[-1])
        ax_main.set_ylim(s8_grid[0], s8_grid[-1])

        # Marginal over the second axis -> distribution along om_grid.
        marginal = posterior_norm.sum(axis=1)
        marginal /= marginal.sum() + 1e-30
        ax_top.fill_between(om_grid, 0.0, marginal, alpha=0.6, color="steelblue")
        ax_top.axvline(true_om, color="red", linestyle="--", linewidth=2.5)
        ax_top.set_xlim(om_grid[0], om_grid[-1])
        ax_top.set_ylim(0, marginal.max() * 1.12)
        ax_top.tick_params(labelbottom=False)
        ax_top.grid(True, alpha=0.25)

        # Marginal over the first axis -> distribution along s8_grid.
        ax_sb = fig.add_subplot(gs_sub[1, 1], sharey=ax_main)
        marginal_s = posterior_norm.sum(axis=0)
        marginal_s /= marginal_s.sum() + 1e-30
        ax_sb.fill_betweenx(s8_grid, 0.0, marginal_s, alpha=0.6, color="steelblue")
        ax_sb.axhline(true_s8, color="red", linestyle="--", linewidth=2.5)
        ax_sb.set_ylim(s8_grid[0], s8_grid[-1])
        # The y axis is shared with the main panel; hide the duplicate tick
        # labels as the multi-anchor figure does.
        ax_sb.tick_params(labelleft=False)

    fig.suptitle(
        f"Model Comparison (Figure 6 Style) — Test ix={test_index}",
        fontsize=15,
        y=0.995,
        fontweight="bold",
    )

    save_path = output_dir / f"fig6_style_model_comparison_ix{test_index}.png"
    fig.savefig(save_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved model comparison: {save_path}")
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def print_integration_guide() -> None:
    """Print the import lines a caller needs to use this module's helpers."""
    snippet_lines = (
        "# Add imports next to posterior code:",
        "from figure6_2409_style import create_figure6_style_plot",
        "from ddpm_figure6_integration import (",
        "integrate_figure6_with_ddpm2,",
        "integrate_figure6_with_multi_anchor,",
        "integrate_figure6_model_comparison,",
        ")",
    )
    print("\n".join(snippet_lines))
|
cross_model/scripts/ddpm_posterior_six_anchors.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Surrogate P(k) likelihood posteriors on ($\\Omega_m$, $\\sigma_8$) for six test anchors.
|
| 4 |
+
|
| 5 |
+
For each model:
|
| 6 |
+
• DDPM-2 — standard 2D marginal: sweep ($\\Omega_m$, $\\sigma_8$) while only two labels exist.
|
| 7 |
+
• DDPM-6 — same 2D sweep, but astrophysical / extra dimensions 2–5 are fixed in two cases:
|
| 8 |
+
- **extra_lower**: each of dims 2–5 fixed to the LHS **minimum** (from training labels)
|
| 9 |
+
- **extra_upper**: each fixed to the LHS **maximum**
|
| 10 |
+
|
| 11 |
+
The observed HI map is always the CAMELS test MAP at that anchor index.
|
| 12 |
+
|
| 13 |
+
This does not import ``compare_ddpm_models.py``; it only shares the same conventions and paths.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import argparse
|
| 19 |
+
import gc
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Dict, Tuple
|
| 23 |
+
|
| 24 |
+
import matplotlib
|
| 25 |
+
|
| 26 |
+
matplotlib.use("Agg")
|
| 27 |
+
import matplotlib.pyplot as plt
|
| 28 |
+
import numpy as np
|
| 29 |
+
import torch
|
| 30 |
+
|
| 31 |
+
MODELS_ROOT = Path(__file__).resolve().parents[1]
|
| 32 |
+
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
|
| 33 |
+
if str(CODE_6.resolve()) not in sys.path:
|
| 34 |
+
sys.path.insert(0, str(CODE_6))
|
| 35 |
+
|
| 36 |
+
import evaluate_conditional as ec # noqa: E402
|
| 37 |
+
import eval_model as em # noqa: E402
|
| 38 |
+
from figure9_posterior import build_cosmo_grid, log_pk_observed # noqa: E402
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _fmt_title(lab: np.ndarray) -> str:
|
| 42 |
+
t = np.asarray(lab, dtype=float).ravel()
|
| 43 |
+
if t.size <= 2:
|
| 44 |
+
return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f}"
|
| 45 |
+
tail = ", ".join(f"{float(v):.3g}" for v in t[2:])
|
| 46 |
+
return rf"$\Omega_m$={t[0]:.3f}, $\sigma_8$={t[1]:.3f} | " + tail
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _train_label_path(data_dir: Path) -> Path:
|
| 50 |
+
for name in ("train_labels_LH.npy", "train_labels_LH_2.npy"):
|
| 51 |
+
p = data_dir / name
|
| 52 |
+
if p.is_file():
|
| 53 |
+
return p
|
| 54 |
+
raise FileNotFoundError(f"No train_labels_LH*.npy under {data_dir}")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def tail_lhs_bounds(data_dir: Path) -> Tuple[np.ndarray, np.ndarray]:
    """Return per-column (min, max) of training-label columns 2–5 as float32.

    The labels are loaded from the file located by ``_train_label_path``.

    Raises:
        ValueError: If the label array has fewer than six columns.
    """
    train_labels = np.load(_train_label_path(data_dir))
    if train_labels.shape[1] < 6:
        raise ValueError(f"Expected ≥6 label columns, got {train_labels.shape}")
    tail_columns = train_labels[:, 2:6]
    lower = tail_columns.min(axis=0).astype(np.float32)
    upper = tail_columns.max(axis=0).astype(np.float32)
    return lower, upper
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def posterior_weights(
    obs: np.ndarray,
    full: np.ndarray,
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    lab_mean: np.ndarray,
    lab_std: np.ndarray,
    normalize: bool,
    model: torch.nn.Module,
    *,
    H: int,
    W: int,
    device: torch.device,
    grid: int,
    batch_sz: int,
    ddim_steps: int,
    box_size: float = 25.0,
    sigma_log_pk: float = 0.25,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Score every grid label against the observed map and return posterior weights.

    For each row of ``full`` a map is sampled with ``em.sample_batch``; its
    log power spectrum is compared with the observed one via a mean squared
    difference, which is turned into a Gaussian log-score with width
    ``sigma_log_pk``.  Scores are shifted by their maximum before
    exponentiation (numerical stability) and normalized to sum to one.

    Args:
        obs: Observed map; its last dimension gives the pixel count per side.
        full: ``(ngrid, label_dim)`` label rows, ``ngrid`` a perfect square.
        om_ax, s8_ax: Axis vectors used to build the returned meshgrids.
        lab_mean, lab_std, normalize: Forwarded to ``em.sample_batch``.
        model: Network forwarded to ``em.sample_batch``.
        H, W, device, ddim_steps: Forwarded to ``em.sample_batch``.
        grid: Not used here; the grid side is derived from ``full.shape[0]``.
            Kept so existing callers keep working.
        batch_sz: Number of label rows sampled per call.
        box_size: Box side length passed to the power-spectrum helpers
            (previously hard-coded to 25.0).
        sigma_log_pk: Width of the Gaussian score in log P(k)
            (previously hard-coded to 0.25).

    Returns:
        ``(Wmap, OM, S8)``: ``Wmap`` has shape ``(g, g)`` with ``g*g == ngrid``;
        ``OM`` and ``S8`` are ``np.meshgrid(om_ax, s8_ax, indexing="ij")``.

    Raises:
        ValueError: If ``full`` is empty or not a square grid, or if
            ``batch_sz`` / ``sigma_log_pk`` is not positive.
    """
    if batch_sz < 1:
        raise ValueError(f"batch_sz must be >= 1, got {batch_sz}")
    if sigma_log_pk <= 0:
        raise ValueError(f"sigma_log_pk must be positive, got {sigma_log_pk}")

    ngrid = full.shape[0]
    if ngrid == 0:
        raise ValueError("Label grid is empty; nothing to score.")
    g = int(round(np.sqrt(ngrid)))
    if g * g != ngrid:
        raise ValueError(f"Expected square grid, got ngrid={ngrid}")

    npix = int(obs.shape[-1])
    dl = box_size / npix
    dk, _ = ec.PowerSpectrum(em.images01_to_log_nhi(obs), N=npix, dl=dl)
    valid = dk > 0
    # NOTE(review): log_pd is compared against pkc[:, valid] below, so
    # log_pk_observed presumably applies the same dk > 0 selection - confirm.
    log_pd = log_pk_observed(obs, box_size, dk)

    scores = []
    for j0 in range(0, ngrid, batch_sz):
        chunk = full[j0 : j0 + batch_sz]
        imgs = em.sample_batch(
            model,
            chunk,
            lab_mean,
            lab_std,
            normalize,
            H,
            W,
            device,
            ddim_steps,
            False,  # NOTE(review): meaning of this positional flag is defined in eval_model
        )
        _, pkc = em.per_map_power_spectra_log(imgs, box_size)
        log_pg = np.log(pkc[:, valid] + 1e-30)  # epsilon guards log(0)
        mse = np.mean((log_pd[np.newaxis, :] - log_pg) ** 2, axis=1)
        scores.append(-mse / (2.0 * sigma_log_pk**2))

    sc = np.concatenate(scores)
    sc -= sc.max()  # largest score becomes 0 so exp() cannot overflow
    # NOTE(review): reshape assumes rows of `full` are ordered to match the
    # indexing="ij" meshgrids returned below - confirm against build_cosmo_grid.
    Wmap = np.exp(sc).reshape(g, g)
    Wmap /= Wmap.sum()

    OM, S8 = np.meshgrid(om_ax, s8_ax, indexing="ij")
    return Wmap, OM, S8
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def build_full_grid_2d(
    labels_split: np.ndarray,
    grid: int,
    tail: np.ndarray | None,
    lab_dim: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build the label rows for a 2D sweep over label columns 0 and 1.

    The sweep range for each of the two columns is the min/max found in
    ``labels_split``, widened by 2 % of the span on both sides, and is passed
    to ``build_cosmo_grid``.

    Args:
        labels_split: Label array; only columns 0 and 1 are read.
        grid: Grid points per axis, forwarded to ``build_cosmo_grid``.
        tail: ``None`` for the 2-parameter model; otherwise four constants
            written into columns 2–5 of every row.
        lab_dim: Number of label columns in the output.

    Returns:
        ``(full, om_ax, s8_ax)``: ``full`` is float32 with shape
        ``(ngrid, lab_dim)``; the axes are those returned by ``build_cosmo_grid``.

    Raises:
        ValueError: If ``lab_dim < 2``, if a tail is given with
            ``lab_dim == 2`` or ``lab_dim < 6``, or if the tail is not of
            shape ``(4,)``.
    """
    # Validate arguments before doing any work; explicit raises survive -O.
    if lab_dim < 2:
        raise ValueError(f"lab_dim must be >= 2, got {lab_dim}")
    if lab_dim == 2 and tail is not None:
        raise ValueError("lab_dim==2 implies no extra tail.")
    if tail is not None:
        tail = np.asarray(tail, dtype=np.float32)
        if tail.shape != (4,):
            raise ValueError(f"tail must have shape (4,), got {tail.shape}")
        if lab_dim < 6:
            raise ValueError(f"A tail fills columns 2-5, so lab_dim must be >= 6, got {lab_dim}")

    lo0 = float(labels_split[:, 0].min())
    hi0 = float(labels_split[:, 0].max())
    lo1 = float(labels_split[:, 1].min())
    hi1 = float(labels_split[:, 1].max())
    # The tiny epsilon keeps the padding non-zero for a degenerate (single-value) range.
    pad0 = 0.02 * (hi0 - lo0 + 1e-12)
    pad1 = 0.02 * (hi1 - lo1 + 1e-12)
    om_ax, s8_ax, _, _, grid2 = build_cosmo_grid(
        grid, lo0 - pad0, hi0 + pad0, lo1 - pad1, hi1 + pad1
    )

    out = np.zeros((grid2.shape[0], lab_dim), dtype=np.float32)
    out[:, 0] = grid2[:, 0]
    out[:, 1] = grid2[:, 1]
    if tail is not None:
        out[:, 2:6] = tail[np.newaxis, :]
    return out, om_ax, s8_ax
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def plot_posterior_panel(
    ax,
    Wmap: np.ndarray,
    OM: np.ndarray,
    S8: np.ndarray,
    tom: float,
    ts8: float,
    title: str,
    *,
    suptext: str | None = None,
) -> None:
    """Draw one posterior panel on ``ax``.

    Shows ``Wmap`` as filled contours over the ``OM``/``S8`` meshes with a
    colorbar, marks the true point (red ``x``) and the weight-averaged point
    (black ``+``), and optionally writes ``suptext`` in the top-left corner.
    """
    weighted_om = Wmap * OM
    weighted_s8 = Wmap * S8
    mean_om = float(weighted_om.sum())
    mean_s8 = float(weighted_s8.sum())

    filled = ax.contourf(OM, S8, Wmap, levels=12, cmap="Blues")
    plt.colorbar(filled, ax=ax, fraction=0.046, pad=0.04)

    ax.scatter(tom, ts8, s=55, c="r", marker="x", zorder=6, label="true")
    ax.scatter(mean_om, mean_s8, s=60, c="k", marker="+", zorder=6, label="post. mean")

    ax.set_xlabel(r"$\Omega_m$")
    ax.set_ylabel(r"$\sigma_8$")
    ax.legend(fontsize=7)
    ax.set_title(title, fontsize=8)

    if not suptext:
        return
    ax.text(0.02, 0.98, suptext, transform=ax.transAxes, fontsize=7, va="top", color="#333")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def run_ddpm2_panels(
    out_dir: Path,
    images: np.ndarray,
    labels: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    cfg: Dict,
    model: torch.nn.Module,
    device: torch.device,
    anchor_ix: np.ndarray,
    grid: int,
    ddim_steps: int,
    batch_sz: int,
) -> None:
    """Compute and plot the 2-parameter posterior for each anchor on a 2×3 figure.

    For every index in ``anchor_ix`` the map ``images[ix]`` is scored over the
    2D label grid with ``posterior_weights`` and drawn with
    ``plot_posterior_panel``.  The figure is saved as
    ``posterior_six_anchors_ddpm2.png`` in ``out_dir``.
    """
    use_normalized_labels = bool(cfg.get("normalize_labels", True))
    height = int(images.shape[-2])
    width = int(images.shape[-1])
    # Keyword arguments shared by every posterior_weights call.
    sampler_kwargs = dict(
        H=height,
        W=width,
        device=device,
        grid=grid,
        batch_sz=batch_sz,
        ddim_steps=ddim_steps,
    )

    fig, axes = plt.subplots(2, 3, figsize=(14, 9), squeeze=False)
    for panel_no, test_ix in enumerate(anchor_ix.ravel()):
        panel = axes[panel_no // 3, panel_no % 3]
        observed = images[test_ix]
        truth = labels[test_ix].astype(np.float32)
        label_rows, om_ax, s8_ax = build_full_grid_2d(labels, grid, tail=None, lab_dim=2)
        weights, om_mesh, s8_mesh = posterior_weights(
            observed,
            label_rows,
            om_ax,
            s8_ax,
            mean,
            std,
            use_normalized_labels,
            model,
            **sampler_kwargs,
        )
        plot_posterior_panel(
            panel,
            weights,
            om_mesh,
            s8_mesh,
            float(truth[0]),
            float(truth[1]),
            f"test ix={test_ix}\n{_fmt_title(truth)}",
        )

    plt.suptitle(
        r"DDPM-2 surrogate posterior on $(\Omega_m,\,\sigma_8)$ — six CAMELS anchors",
        fontsize=11,
        y=0.995,
    )
    plt.tight_layout(rect=(0, 0, 1, 0.97))
    target = out_dir / "posterior_six_anchors_ddpm2.png"
    fig.savefig(target, dpi=170, bbox_inches="tight")
    plt.close(fig)
    print("Saved", target)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def run_ddpm6_case(
    out_dir: Path,
    *,
    suffix: str,
    tail_fixed: np.ndarray,
    tail_name: str,
    images: np.ndarray,
    labels: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    cfg: Dict,
    model: torch.nn.Module,
    device: torch.device,
    anchor_ix: np.ndarray,
    grid: int,
    ddim_steps: int,
    batch_sz: int,
) -> None:
    """Compute and plot the 6-parameter posterior for each anchor on a 2×3 figure.

    Label columns 0 and 1 are swept while columns 2–5 are held at
    ``tail_fixed``.  ``tail_name`` is shown in each panel and in the figure
    title; ``suffix`` selects the output file name
    ``posterior_six_anchors_ddpm6_<suffix>.png`` in ``out_dir``.
    """
    use_normalized_labels = bool(cfg.get("normalize_labels", True))
    height = int(images.shape[-2])
    width = int(images.shape[-1])
    # Keyword arguments shared by every posterior_weights call.
    sampler_kwargs = dict(
        H=height,
        W=width,
        device=device,
        grid=grid,
        batch_sz=batch_sz,
        ddim_steps=ddim_steps,
    )

    fig, axes = plt.subplots(2, 3, figsize=(14, 9), squeeze=False)
    for panel_no, test_ix in enumerate(anchor_ix.ravel()):
        panel = axes[panel_no // 3, panel_no % 3]
        observed = images[test_ix]
        truth = labels[test_ix].astype(np.float32)
        label_rows, om_ax, s8_ax = build_full_grid_2d(
            labels, grid, tail=tail_fixed, lab_dim=6
        )
        weights, om_mesh, s8_mesh = posterior_weights(
            observed,
            label_rows,
            om_ax,
            s8_ax,
            mean,
            std,
            use_normalized_labels,
            model,
            **sampler_kwargs,
        )
        plot_posterior_panel(
            panel,
            weights,
            om_mesh,
            s8_mesh,
            float(truth[0]),
            float(truth[1]),
            f"test ix={test_ix}",
            suptext=tail_name,
        )

    plt.suptitle(
        r"DDPM-6 — $(\Omega_m,\,\sigma_8)$ sweep; dims 2–5 fixed (" + tail_name + ")",
        fontsize=11,
        y=0.995,
    )
    plt.tight_layout(rect=(0, 0, 1, 0.96))
    target = out_dir / f"posterior_six_anchors_ddpm6_{suffix}.png"
    fig.savefig(target, dpi=170, bbox_inches="tight")
    plt.close(fig)
    print("Saved", target)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def load_model(bundle_args: Path, ckpt: Path, device: torch.device):
    """Build a model from a training config, load its checkpoint, set eval mode.

    Returns:
        ``(model, cfg)`` where ``cfg`` is what ``ec.load_training_config``
        returned for ``bundle_args``.
    """
    training_cfg = ec.load_training_config(str(bundle_args))
    network = ec.build_model(training_cfg, device)
    ec.load_checkpoint(network, str(ckpt), device)
    network.eval()
    return network, training_cfg
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def main() -> None:
    """Command-line entry point: six-anchor posteriors for DDPM-2 and/or DDPM-6.

    Parses arguments, picks six evenly spaced anchor indices over the chosen
    split, then writes one DDPM-2 figure and two DDPM-6 figures (columns 2–5
    fixed at the training minima and maxima respectively) to the output
    directory.  ``--ddpm2-only`` / ``--ddpm6-only`` skip the other model.
    """
    p = argparse.ArgumentParser(
        description="Six-anchor surrogate posteriors: DDPM-2 and DDPM-6 (extra dims min vs max)."
    )
    p.add_argument("--output-dir", type=Path, default=MODELS_ROOT / "ddpm_posterior_six_anchors_out")
    p.add_argument("--data-2param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"))
    p.add_argument("--data-6param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"))
    p.add_argument(
        "--bundle-2param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
    )
    p.add_argument(
        "--bundle-6param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
    )
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument("--grid", type=int, default=14, help="Grid points per Ωm–σ8 axis.")
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument(
        "--ddpm2-only",
        action="store_true",
        help="Only compute DDPM-2 figure (skip loading DDPM-6).",
    )
    p.add_argument(
        "--ddpm6-only",
        action="store_true",
        help="Only compute DDPM-6 figures (skip loading DDPM-2).",
    )
    args = p.parse_args()

    # Reject contradictory flags before any directory creation or data loading.
    if args.ddpm6_only and args.ddpm2_only:
        raise SystemExit("Use at most one of --ddpm2-only / --ddpm6-only.")

    out_dir = Path(args.output_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)

    # Both splits are always loaded: the anchor indices are derived from the
    # shorter of the two so that DDPM-2 and DDPM-6 use the same test indices
    # regardless of which model is being run.
    imgs2, lab2 = ec.load_split(data2, args.split)
    imgs6, lab6 = ec.load_split(data6, args.split)

    n = min(len(lab2), len(lab6))
    if n < 1:
        raise SystemExit(f"Split '{args.split}' is empty; cannot choose anchors.")
    anchor_ix = np.linspace(0, n - 1, num=6, dtype=int)

    if not args.ddpm6_only:
        print(">>> DDPM-2 (six anchors)...")
        # Label stats and checkpoint are needed only on this path.
        mean2, std2 = ec.load_label_stats(data2)
        ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
        args_json_2 = args.bundle_2param / "args.json"
        model2, cfg2 = load_model(args_json_2, ck2, device)
        run_ddpm2_panels(
            out_dir,
            imgs2,
            lab2,
            mean2,
            std2,
            cfg2,
            model2,
            device,
            anchor_ix,
            args.grid,
            args.ddim_steps,
            args.batch_size,
        )
        # Release the model before the next one is loaded.
        del model2
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not args.ddpm2_only:
        # Training-label bounds and 6-parameter stats are needed only here, so
        # --ddpm2-only no longer requires the 6-parameter training labels.
        low_tail, hi_tail = tail_lhs_bounds(data6)
        print("LHS tails (dims 2–5): min", low_tail, "max", hi_tail)
        mean6, std6 = ec.load_label_stats(data6)
        ck6 = args.bundle_6param / "best_model.pt"
        args_json_6 = args.bundle_6param / "args.json"

        print(">>> DDPM-6 — extra dims at LHS minima (six anchors)...")
        model6, cfg6 = load_model(args_json_6, ck6, device)
        # Arguments common to both fixed-tail cases.
        shared = dict(
            images=imgs6,
            labels=lab6,
            mean=mean6,
            std=std6,
            cfg=cfg6,
            model=model6,
            device=device,
            anchor_ix=anchor_ix,
            grid=args.grid,
            ddim_steps=args.ddim_steps,
            batch_sz=args.batch_size,
        )
        run_ddpm6_case(
            out_dir,
            suffix="extra_lower",
            tail_fixed=low_tail,
            tail_name="min",
            **shared,
        )
        print(">>> DDPM-6 — extra dims at LHS maxima (six anchors)...")
        run_ddpm6_case(
            out_dir,
            suffix="extra_upper",
            tail_fixed=hi_tail,
            tail_name="max",
            **shared,
        )
        # Drop the dict as well: it holds a reference to the model.
        del shared, model6
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"Done. Outputs in {out_dir}")
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
if __name__ == "__main__":
|
| 451 |
+
main()
|
cross_model/scripts/ddpm_triangle_integration.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Surrogate posterior on $(\\Omega_m, \\sigma_8)$ → triangle/MCMC-style chains for one test map.
|
| 4 |
+
|
| 5 |
+
Loads the same surrogate likelihood used in ``ddpm_posterior_six_anchors``, resamples discrete
|
| 6 |
+
posterior masses to ``--n-hist`` correlated $(\\Omega_m,\\sigma_8)$ pairs, and writes ``.npz``.
|
| 7 |
+
|
| 8 |
+
DDPM-2: sweeps $(\\Omega_m,\\sigma_8)$.
|
| 9 |
+
DDPM-6: dims 2–5 fixed per ``--six-tail-mode`` (``truth`` uses the test-map labels 2–5; ``min``/``max``
|
| 10 |
+
use LHS extrema from training labels).
|
| 11 |
+
|
| 12 |
+
If you replace this file with a copy from your machine (Downloads), keep argparse compatible or wrap it.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import sys
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
_SCRIPTS = Path(__file__).resolve().parent
|
| 25 |
+
if str(_SCRIPTS) not in sys.path:
|
| 26 |
+
sys.path.insert(0, str(_SCRIPTS))
|
| 27 |
+
|
| 28 |
+
import ddpm_posterior_six_anchors as dps # noqa: E402
|
| 29 |
+
|
| 30 |
+
MODELS_ROOT = Path(__file__).resolve().parents[1]
|
| 31 |
+
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
|
| 32 |
+
if str(CODE_6.resolve()) not in sys.path:
|
| 33 |
+
sys.path.insert(0, str(CODE_6.resolve()))
|
| 34 |
+
|
| 35 |
+
import evaluate_conditional as ec # noqa: E402
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _tail_vec(
|
| 39 |
+
mode: str,
|
| 40 |
+
lab_full: np.ndarray,
|
| 41 |
+
data6: Path,
|
| 42 |
+
) -> np.ndarray | None:
|
| 43 |
+
if lab_full.size <= 2:
|
| 44 |
+
return None
|
| 45 |
+
if mode == "truth":
|
| 46 |
+
return lab_full[2:6].astype(np.float32)
|
| 47 |
+
low, hi = dps.tail_lhs_bounds(data6)
|
| 48 |
+
if mode == "min":
|
| 49 |
+
return low
|
| 50 |
+
if mode == "max":
|
| 51 |
+
return hi
|
| 52 |
+
raise ValueError("six-tail-mode must be truth|min|max")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def main() -> None:
    """Compute a surrogate posterior for one map and save resampled (Ωm, σ8) pairs.

    Steps:
      1. Parse CLI options and pick data dir / bundle / checkpoint defaults
         according to ``--label-dim``.
      2. Load the requested split and select the observation at ``--test-index``.
      3. For ``--label-dim 6`` determine the fixed tail via ``_tail_vec``.
      4. Build the 2-D grid and evaluate ``dps.posterior_weights`` on it.
      5. Normalise the weights, draw ``--n-hist`` grid cells with replacement
         and write everything to a compressed ``.npz``.

    Raises:
        SystemExit: If ``--test-index`` is out of range, or ``--label-dim 6`` is
            used with labels that do not have 6 columns.
        RuntimeError: If the posterior weights are non-finite or sum to zero.
    """
    p = argparse.ArgumentParser(description="DDPM surrogate posterior → resampled Ωm σ8 chains (.npz).")
    p.add_argument(
        "--label-dim",
        type=int,
        choices=[2, 6],
        required=True,
        help="Which model to use.",
    )
    p.add_argument(
        "--bundle",
        type=Path,
        default=None,
        # Help text now names the directories that are actually used as defaults below.
        help=(
            "Checkpoint bundle dir with args.json "
            "(default: notebook_model_weights/2param_epoch200 or notebook_model_weights/6param_best)."
        ),
    )
    p.add_argument(
        "--checkpoint-name",
        type=str,
        default=None,
        help="Checkpoint file under bundle (defaults: DDPM2 epoch200, DDPM6 best_model).",
    )
    p.add_argument(
        "--data-dir",
        type=Path,
        default=None,
        help="LH data dir matching label_dim (default: params_2 vs params_6).",
    )
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument("--test-index", type=int, default=56, help="Index into split for CAMELS observation.")
    p.add_argument("--grid", type=int, default=14)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument(
        "--n-hist",
        type=int,
        default=10_000,
        help="Resampled posterior pairs (with replacement).",
    )
    p.add_argument(
        "--six-tail-mode",
        type=str,
        default="truth",
        choices=["truth", "min", "max"],
        help="Applies only to label_dim==6 — how dims 2–5 are fixed.",
    )
    p.add_argument(
        "--output",
        "-o",
        type=Path,
        required=True,
        help="Output .npz path.",
    )
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()

    ld = args.label_dim
    if ld == 2:
        data_dir = args.data_dir or Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2")
        bundle = args.bundle or MODELS_ROOT / "notebook_model_weights" / "2param_epoch200"
        ck_name = args.checkpoint_name or "checkpoint_epoch_200.pt"
    else:
        data_dir = args.data_dir or Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6")
        bundle = args.bundle or MODELS_ROOT / "notebook_model_weights" / "6param_best"
        ck_name = args.checkpoint_name or "best_model.pt"

    rng = np.random.default_rng(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    imgs, labs = ec.load_split(data_dir, args.split)
    ix = int(args.test_index)
    if not (0 <= ix < len(labs)):
        raise SystemExit(f"test-index {ix} out of range for split ({len(labs)} rows)")
    lab_t = labs[ix].astype(np.float64)
    obs = imgs[ix]
    ckpt = bundle / ck_name
    args_json = bundle / "args.json"
    mean, std = ec.load_label_stats(data_dir)

    tail = None
    if ld == 6:
        # The 6-parameter model needs the four trailing labels pinned to something.
        if lab_t.shape[0] != 6:
            raise SystemExit("--label-dim 6 requires labels with 6 columns in data-dir")
        tail = _tail_vec(args.six_tail_mode, lab_t, Path(data_dir))
    model, cfg = dps.load_model(args_json, ckpt, device)
    normalize = bool(cfg.get("normalize_labels", True))
    H = int(obs.shape[-2])
    W = int(obs.shape[-1])
    gsz = args.grid

    full, om_ax, s8_ax = dps.build_full_grid_2d(labs, gsz, tail=tail, lab_dim=ld)
    Wmap, OM, S8 = dps.posterior_weights(
        obs,
        full,
        om_ax,
        s8_ax,
        mean,
        std,
        normalize,
        model,
        H=H,
        W=W,
        device=device,
        grid=gsz,
        batch_sz=args.batch_size,
        ddim_steps=args.ddim_steps,
    )

    wflat = np.clip(Wmap.ravel().astype(np.float64), 0.0, None)
    total = wflat.sum()
    # np.clip keeps NaN, and `NaN <= 0` is False, so non-finite masses must be
    # rejected explicitly; otherwise they only surface as an opaque error in rng.choice.
    if not np.isfinite(total):
        raise RuntimeError("Posterior masses are not finite (NaN/inf in posterior map).")
    if total <= 0:
        raise RuntimeError("Posterior masses collapsed to zero.")
    wflat /= total
    omapflat = OM.ravel()
    s8flat = S8.ravel()
    # Draw flat grid-cell indices with probability equal to the posterior mass.
    draws = rng.choice(wflat.size, size=args.n_hist, replace=True, p=wflat)
    samp_om = omapflat[draws].astype(np.float64)
    samp_s8 = s8flat[draws].astype(np.float64)

    out = Path(args.output).resolve()
    out.parent.mkdir(parents=True, exist_ok=True)
    tag = f"ddpm{ld}_{args.six_tail_mode}" if ld == 6 else "ddpm2"
    np.savez_compressed(
        out,
        omega_m=samp_om,
        sigma_8=samp_s8,
        samples=np.column_stack([samp_om, samp_s8]),
        truth_Omega_m=float(lab_t[0]),
        truth_sigma_8=float(lab_t[1]),
        posterior_map=Wmap,
        OM=OM,
        S8=S8,
        index=np.array(ix, dtype=np.int32),
        label_dim=np.array(ld, dtype=np.int16),
        meta_tag=np.array(tag, dtype="U128"),
        six_tail_mode=np.array(args.six_tail_mode if ld == 6 else "", dtype="U16"),
    )
    print("Saved", out, "pairs:", args.n_hist, "device:", device)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
if __name__ == "__main__":
|
| 194 |
+
main()
|
cross_model/scripts/figure6_2409_style.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Figure-6-inspired layout for 2-parameter posteriors (arXiv:2409.09101 style):
|
| 3 |
+
main 2D panel with 1D marginals on adjacent edges — marginal sums vs profiles (max) optional.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import numpy as np
|
| 12 |
+
from matplotlib import gridspec as mgs
|
| 13 |
+
|
| 14 |
+
from sigma_contour_utils import compute_sigma_levels
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def create_figure6_style_plot(
    Wmap: np.ndarray,
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    *,
    true_param1: float,
    true_param2: float,
    param1_label: str = r"$\Omega_m$",
    param2_label: str = r"$\sigma_8$",
    title: str = "",
    show_profile: bool = False,
    figsize: Tuple[float, float] = (10, 10),
):
    """
    Draw a 2-D posterior panel with 1-D curves on its top and right edges.

    Parameters
    ----------
    Wmap : (G, G) posterior masses on the grid; axis 0 follows ``om_ax`` and
        axis 1 follows ``s8_ax`` (meshgrid with indexing='ij').
    om_ax, s8_ax : 1-D grids aligned with axes 0 and 1 of ``Wmap``.
    true_param1, true_param2 : reference values marked in all three panels.
    param1_label, param2_label : axis labels of the main panel.
    title : text placed in front of the "(Marginal)" / "(Profile)" suffix.
    show_profile :
        False → edge curves are sums (*marginal*) over the other parameter.
        True  → edge curves are maxima (*profile*) over the other parameter.
        Either way each curve is scaled to a peak of one.
    figsize : figure size in inches.

    Returns
    -------
    The created matplotlib figure.
    """
    density = np.asarray(Wmap, dtype=np.float64)
    density = density / (density.sum() + 1e-30)
    mesh_om, mesh_s8 = np.meshgrid(om_ax, s8_ax, indexing="ij")

    # Same reduction for both edges: max for a profile, sum for a marginal.
    collapse = np.max if show_profile else np.sum
    curve_om = np.asarray(collapse(density, axis=1), dtype=np.float64)
    curve_s8 = np.asarray(collapse(density, axis=0), dtype=np.float64)
    curve_om = curve_om / (curve_om.max() + 1e-30)
    curve_s8 = curve_s8 / (curve_s8.max() + 1e-30)

    fig = plt.figure(figsize=figsize)
    layout = mgs.GridSpec(
        nrows=2,
        ncols=2,
        figure=fig,
        width_ratios=[4.0, 1.05],
        height_ratios=[1.05, 4.0],
        wspace=0.035,
        hspace=0.035,
        left=0.12,
        right=0.98,
        bottom=0.1,
        top=0.92,
    )
    ax_main = fig.add_subplot(layout[1, 0])
    ax_top = fig.add_subplot(layout[0, 0], sharex=ax_main)
    ax_r = fig.add_subplot(layout[1, 1], sharey=ax_main)
    # The top-right cell is created only to be blanked out.
    fig.add_subplot(layout[0, 1]).axis("off")

    # --- main 2-D panel ---
    sigma_levels = compute_sigma_levels(density, [0.683, 0.954])
    ax_main.contourf(mesh_om, mesh_s8, density, levels=20, cmap="Blues", alpha=0.88)
    distinct_levels = set(sigma_levels)
    if len(distinct_levels) >= 2:
        # Contour lines are only drawn when the two levels are distinct.
        ax_main.contour(
            mesh_om,
            mesh_s8,
            density,
            levels=sigma_levels,
            colors=["darkblue", "steelblue"],
            linewidths=[2.0, 1.5],
        )
    ax_main.scatter(
        true_param1,
        true_param2,
        s=120,
        c="red",
        marker="x",
        linewidths=2.8,
        zorder=15,
        label="true",
    )
    ax_main.set_xlabel(param1_label, fontsize=13)
    ax_main.set_ylabel(param2_label, fontsize=13)
    ax_main.grid(True, alpha=0.28)
    ax_main.legend(fontsize=8, loc="upper right")

    # --- top edge: curve along the first parameter ---
    ax_top.fill_between(om_ax, 0.0, curve_om, alpha=0.62, color="steelblue")
    ax_top.axvline(true_param1, color="red", ls="--", lw=2.0)
    ax_top.set_ylim(0.0, float(np.max(curve_om) * 1.12))
    if show_profile:
        top_label = "$P(\\mathrm{prof.})$"
    else:
        top_label = "$P(\\mathrm{margin.})$"
    ax_top.set_ylabel(top_label, fontsize=10)
    ax_top.tick_params(labelbottom=False)
    ax_top.grid(True, alpha=0.25)

    # --- right edge: curve along the second parameter ---
    ax_r.fill_betweenx(s8_ax, 0.0, curve_s8, alpha=0.62, color="steelblue")
    ax_r.axhline(true_param2, color="red", ls="--", lw=2.0)
    ax_r.set_xlim(0.0, float(np.max(curve_s8) * 1.12))
    ax_r.set_xlabel("$P$", fontsize=10)
    ax_r.tick_params(labelleft=False)
    ax_r.grid(True, alpha=0.25)

    kind = "Profile" if show_profile else "Marginal"
    fig.suptitle(f"{title} ({kind})", fontsize=14, fontweight="bold", y=0.98)

    plt.setp(ax_top.get_xticklabels(), visible=False)

    return fig
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def create_comparison_marginal_vs_profile(
    Wmap: np.ndarray,
    om_ax: np.ndarray,
    s8_ax: np.ndarray,
    *,
    true_param1: float,
    true_param2: float,
    param1_label: str = r"$\Omega_m$",
    param2_label: str = r"$\sigma_8$",
    title: str = "",
    figsize: Tuple[float, float] = (10, 4.2),
):
    """Plot marginal (sum) against profile (max) curves in two side-by-side panels.

    The left panel shows the curves along ``om_ax`` (axis 0 of ``Wmap``), the
    right panel along ``s8_ax`` (axis 1). Marginals are normalised to unit sum,
    profiles to unit peak. ``Wmap`` itself is never modified.

    Parameters
    ----------
    Wmap : (G, G) posterior masses; axis 0 follows ``om_ax``, axis 1 ``s8_ax``.
    om_ax, s8_ax : 1-D grids used as x-axes of the two panels.
    true_param1, true_param2 : reference values drawn as vertical dotted lines.
    param1_label, param2_label : x-axis labels of the left / right panel.
    title : figure title.
    figsize : figure size in inches.

    Returns
    -------
    The created matplotlib figure.
    """
    # np.asarray returns the caller's own array when it is already float64, so
    # normalise out-of-place: an in-place `/=` would rescale the caller's Wmap.
    p = np.asarray(Wmap, dtype=np.float64)
    p = p / (p.sum() + 1e-30)

    # The reductions below allocate fresh arrays, so in-place scaling is safe here.
    marg_om = p.sum(axis=1)
    marg_s8 = p.sum(axis=0)
    prof_om = np.max(p, axis=1)
    prof_s8 = np.max(p, axis=0)
    marg_om /= marg_om.sum() + 1e-30
    marg_s8 /= marg_s8.sum() + 1e-30
    prof_om /= prof_om.max() + 1e-30
    prof_s8 /= prof_s8.max() + 1e-30

    fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=False)
    for ax, xaxis, marg, prof, xlab, xv in zip(
        axes,
        (om_ax, s8_ax),
        (marg_om, marg_s8),
        (prof_om, prof_s8),
        (param1_label, param2_label),
        (true_param1, true_param2),
    ):
        ax.plot(xaxis, marg, lw=2.0, ls="-", label="marginal")
        ax.plot(xaxis, prof, lw=2.0, ls="--", label="profile")
        ax.axvline(xv, color="crimson", ls=":", lw=1.8)
        ax.set_xlabel(xlab, fontsize=12)
        ax.set_ylabel("norm. density", fontsize=10)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
    fig.suptitle(title, fontsize=12, fontweight="bold")
    # Leave headroom at the top for the suptitle.
    fig.tight_layout(rect=(0, 0, 1, 0.93))
    return fig
|
cross_model/scripts/run_ddpm_comparison.sh
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --account=l40sfree
#SBATCH --partition=l40s
#SBATCH --nodes=1
#SBATCH --ntasks=8
#SBATCH --gres=gpu:l40s:1
#SBATCH --time=24:00:00
#SBATCH --job-name=ddpm_compare
#SBATCH --mail-user=mrpcol001@myuct.ac.za
#SBATCH --output=slurm-ddpm-compare-%j.out
#SBATCH --error=slurm-ddpm-compare-%j.err

# Runs scripts/compare_ddpm_models.py to compare DDPM-2 with DDPM-6.
# The cluster setup follows the 6-parameter training job:
#   /scratch/mrpcol001/Diffusion_job/april_26/ddpm_hi_lh6/scripts/shell/train_conditional_lh6.sh
#
# Usage
#   Submit (from any directory):
#     sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_comparison.sh
#
#   Arguments after the script name are forwarded to compare_ddpm_models.py, e.g. LHS off:
#     sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_comparison.sh --skip-lhs-r2
#
#   Choose another output directory (optional):
#     sbatch --export=OUTPUT_DIR=/scratch/mrpcol001/Diffusion_job/Models/ddpm_comparison_out_ab \
#       /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_comparison.sh
#
#   Supply a DDPM-2 train/val log for the combined loss plot (optional; when
#   SLURM_2PARAM is unset, --slurm-2param is simply not passed):
#     sbatch --export=SLURM_2PARAM=/path/to/slurm-2param-ddpm.out \
#       /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_comparison.sh
#
#   Interactive run (same module as the training script):
#     module load python/miniconda3-py3.12-usr
#     bash .../run_ddpm_comparison.sh --skip-lhs-r2

set -euo pipefail

ROOT="/scratch/mrpcol001/Diffusion_job/Models"
cd "$ROOT"

module load python/miniconda3-py3.12-usr

OUT="${OUTPUT_DIR:-${ROOT}/ddpm_comparison_out}"
RULE="==============================================="

# Start banner; Slurm variables fall back to local values outside a job.
cat <<EOF
${RULE}
Job ID: ${SLURM_JOB_ID:-local}
Job Name: ${SLURM_JOB_NAME:-run_ddpm_comparison}
Node: ${SLURM_NODELIST:-$(hostname)}
GPU: ${CUDA_VISIBLE_DEVICES:-n/a}
Starting Time: $(date)
Comparison output: ${OUT}
Reference Slurm recipe: april_26/ddpm_hi_lh6/scripts/shell/train_conditional_lh6.sh
${RULE}
EOF

# Assemble the command as an array so paths and forwarded arguments keep their quoting.
cmd=(python "${ROOT}/scripts/compare_ddpm_models.py" --output-dir "${OUT}")
[[ -z "${SLURM_2PARAM:-}" ]] || cmd+=(--slurm-2param "${SLURM_2PARAM}")
cmd+=("$@")

"${cmd[@]}"

printf '%s\n' "${RULE}"
printf '%s\n' "Artifacts -> ${OUT}"
printf '%s\n' "Finished at: $(date)"
printf '%s\n' "${RULE}"
|
cross_model/scripts/run_ddpm_figure6.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --account=l40sfree
#SBATCH --partition=l40s
#SBATCH --nodes=1
#SBATCH --ntasks=8
#SBATCH --gres=gpu:l40s:1
#SBATCH --time=24:00:00
#SBATCH --job-name=ddpm_fig6
#SBATCH --mail-user=mrpcol001@myuct.ac.za
#SBATCH --output=slurm-ddpm-figure6-%j.out
#SBATCH --error=slurm-ddpm-figure6-%j.err

# Surrogate posteriors for DDPM-2 / DDPM-6 in a Figure 6 style (inspired by arXiv:2409.09101).
#
# Usage
#   sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_figure6.sh
#   sbatch --export=OUTPUT_DIR=/path/to/out,TEST_INDEX=42 .../run_ddpm_figure6.sh --no-six-grid
#
# OUTPUT_DIR and TEST_INDEX are optional environment overrides; any command-line
# arguments are forwarded to run_ddpm_figure6_suite.py.
set -euo pipefail

ROOT="/scratch/mrpcol001/Diffusion_job/Models"
cd "$ROOT"
module load python/miniconda3-py3.12-usr

OUT="${OUTPUT_DIR:-${ROOT}/ddpm_figure6_out}"
IDX="${TEST_INDEX:-56}"

printf '%s\n' "Job=${SLURM_JOB_ID:-local} OUT=${OUT} TEST_INDEX=${IDX}"

# Array form keeps the quoting of paths and forwarded arguments intact.
cmd=(python "${ROOT}/scripts/run_ddpm_figure6_suite.py" --output-dir "${OUT}" --test-index "${IDX}")
cmd+=("$@")
"${cmd[@]}"
|
cross_model/scripts/run_ddpm_figure6_suite.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Compute surrogate posteriors and emit Figure-6 style figures (arXiv:2409.09101-inspired).
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import gc
|
| 10 |
+
import sys
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
_SCRIPTS = Path(__file__).resolve().parent
|
| 17 |
+
MODELS_ROOT = Path(__file__).resolve().parents[1]
|
| 18 |
+
CODE_6 = MODELS_ROOT / "6param_ddpm_hi_lh6"
|
| 19 |
+
if str(_SCRIPTS) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(_SCRIPTS))
|
| 21 |
+
if str(CODE_6.resolve()) not in sys.path:
|
| 22 |
+
sys.path.insert(0, str(CODE_6.resolve()))
|
| 23 |
+
|
| 24 |
+
import evaluate_conditional as ec # noqa: E402
|
| 25 |
+
import ddpm_posterior_six_anchors as dps # noqa: E402
|
| 26 |
+
|
| 27 |
+
from ddpm_figure6_integration import ( # noqa: E402
|
| 28 |
+
integrate_figure6_model_comparison,
|
| 29 |
+
integrate_figure6_with_ddpm2,
|
| 30 |
+
integrate_figure6_with_multi_anchor,
|
| 31 |
+
print_integration_guide,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def main() -> None:
    """Compute surrogate posteriors and write the Figure-6 style figure suite.

    Two independent stages, each of which can be switched off from the CLI:

    * Single-map stage (skipped by ``--six-anchors-only``): for ``--test-index``
      evaluate DDPM-2 plus DDPM-6 with truth / min / max tails, emit the model
      comparison, and (unless ``--no-single-fig6``) the per-map figures.
    * Six-anchor stage (skipped by ``--no-six-grid``): evaluate DDPM-2 and
      DDPM-6 (truth tail) on six indices spread evenly over the split and emit
      the multi-anchor figures.

    ``--guide`` prints the integration notes and returns without computing.

    Raises:
        SystemExit: If ``--test-index`` is outside the rows shared by both splits.
    """
    p = argparse.ArgumentParser(description="DDPM Figure-6 style posterior suite.")
    p.add_argument("--output-dir", type=Path, default=MODELS_ROOT / "ddpm_figure6_out")
    p.add_argument("--data-2param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_2"))
    p.add_argument("--data-6param", type=Path, default=Path("/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"))
    p.add_argument(
        "--bundle-2param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "2param_epoch200",
    )
    p.add_argument(
        "--bundle-6param",
        type=Path,
        default=MODELS_ROOT / "notebook_model_weights" / "6param_best",
    )
    p.add_argument("--split", type=str, default="test", choices=["train", "val", "test"])
    p.add_argument("--test-index", type=int, default=56, help="Index for single comparison + per-map fig6.")
    p.add_argument("--grid", type=int, default=14)
    p.add_argument("--ddim-steps", type=int, default=50)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument(
        "--six-anchors-only",
        action="store_true",
        help="Only 2×3 multi-anchor plots (skip triple model comparison at --test-index).",
    )
    p.add_argument(
        "--no-six-grid",
        action="store_true",
        help="Skip multi-anchor 2×3 panels.",
    )
    p.add_argument(
        "--no-single-fig6",
        action="store_true",
        help="Skip per-map marginal/profile for test-index on DDPM-2 and DDPM-6 (truth tail).",
    )
    p.add_argument(
        "--guide",
        action="store_true",
        help="Print markdown-style integration notes and exit.",
    )
    args = p.parse_args()

    if args.guide:
        print_integration_guide()
        return

    out = Path(args.output_dir).resolve()
    out.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device:", device)

    data2 = Path(args.data_2param)
    data6 = Path(args.data_6param)
    imgs2, lab2 = ec.load_split(data2, args.split)
    imgs6, lab6 = ec.load_split(data6, args.split)
    # Only indices present in both splits can be compared.
    n = min(len(lab2), len(lab6))
    anchor_ix = np.linspace(0, n - 1, num=6, dtype=int)

    low_tail, hi_tail = dps.tail_lhs_bounds(data6)

    ck2 = args.bundle_2param / "checkpoint_epoch_200.pt"
    aj2 = args.bundle_2param / "args.json"
    ck6 = args.bundle_6param / "best_model.pt"
    aj6 = args.bundle_6param / "args.json"

    mean2, std2 = ec.load_label_stats(data2)
    mean6, std6 = ec.load_label_stats(data6)

    ix = int(args.test_index)
    if not (0 <= ix < n):
        raise SystemExit(f"--test-index {ix} invalid (max {n - 1})")

    # The DDPM-2 grid is built from the first two columns of the 6-parameter labels.
    lab_box = lab6[:, :2].copy()

    def grid_posterior(obs, full, om_axis, s8_axis, mean, std, normalize, model):
        """Return the posterior weight map of ``obs`` on the given grid."""
        weights, _, _ = dps.posterior_weights(
            obs,
            full,
            om_axis,
            s8_axis,
            mean,
            std,
            normalize,
            model,
            H=int(obs.shape[-2]),
            W=int(obs.shape[-1]),
            device=device,
            grid=args.grid,
            batch_sz=args.batch_size,
            ddim_steps=args.ddim_steps,
        )
        return weights

    if not args.six_anchors_only:
        print(">>> Loading models for ix=", ix, "...")
        m2, c2 = dps.load_model(aj2, ck2, device)
        m6, c6 = dps.load_model(aj6, ck6, device)
        normalize2 = bool(c2.get("normalize_labels", True))
        normalize6 = bool(c6.get("normalize_labels", True))

        obs2 = imgs2[ix]
        obs6 = imgs6[ix]
        lt2 = lab2[ix].astype(np.float64)
        lt6 = lab6[ix].astype(np.float64)
        ta2om, ta2s8 = float(lt2[0]), float(lt2[1])
        tom, ts8 = float(lt6[0]), float(lt6[1])

        full2, om_ax, s8_ax = dps.build_full_grid_2d(lab_box, args.grid, tail=None, lab_dim=2)
        Wm2 = grid_posterior(obs2, full2, om_ax, s8_ax, mean2, std2, normalize2, m2)

        full6truth, om6, s86 = dps.build_full_grid_2d(
            lab6, args.grid, tail=lab6[ix, 2:6].astype(np.float32), lab_dim=6
        )
        Wm6t = grid_posterior(obs6, full6truth, om6, s86, mean6, std6, normalize6, m6)

        full6lo, om_b, s8_b = dps.build_full_grid_2d(lab6, args.grid, tail=low_tail, lab_dim=6)
        Wm6lo = grid_posterior(obs6, full6lo, om_b, s8_b, mean6, std6, normalize6, m6)

        full6hi, om_c, s8_c = dps.build_full_grid_2d(lab6, args.grid, tail=hi_tail, lab_dim=6)
        Wm6hi = grid_posterior(obs6, full6hi, om_c, s8_c, mean6, std6, normalize6, m6)

        # The comparison figure draws every model on the DDPM-2 axes, so all three
        # DDPM-6 grids are checked against them with one common tolerance.
        grids_match = all(
            np.allclose(om_ax, om_other, rtol=0, atol=1e-12)
            and np.allclose(s8_ax, s8_other, rtol=0, atol=1e-12)
            for om_other, s8_other in ((om6, s86), (om_b, s8_b), (om_c, s8_c))
        )
        if not grids_match:
            print("Warning: Ωm–σ8 grids differ between setups; plotting uses DDPM-2 Ωm/σ8 axes.")

        integrate_figure6_model_comparison(
            {
                "DDPM-2": Wm2,
                "DDPM-6 (truth-tail)": Wm6t,
                "DDPM-6 (min-tail)": Wm6lo,
                "DDPM-6 (max-tail)": Wm6hi,
            },
            om_ax,
            s8_ax,
            tom,
            ts8,
            ix,
            out,
        )

        if not args.no_single_fig6:
            integrate_figure6_with_ddpm2(Wm2, om_ax, s8_ax, ta2om, ta2s8, ix, out, model_name="DDPM-2")
            # The DDPM-6 posterior is drawn against the axes it was evaluated on.
            integrate_figure6_with_ddpm2(Wm6t, om6, s86, tom, ts8, ix, out, model_name="DDPM-6-truth")

        # Drop the models before the next stage loads its own copies.
        del m2, m6
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # --- Six anchors: multi grids for DDPM-2 + DDPM-6 truth tail ---
    if not args.no_six_grid:
        print(">>> Six-anchor Figure 6 grids...")
        model2, cfg2 = dps.load_model(aj2, ck2, device)
        model6, cfg6 = dps.load_model(aj6, ck6, device)
        nz2 = bool(cfg2.get("normalize_labels", True))
        nz6 = bool(cfg6.get("normalize_labels", True))

        post2: list[np.ndarray] = []
        post6: list[np.ndarray] = []
        truths: list[tuple[float, float]] = []
        indices: list[int] = []

        # The DDPM-2 grid does not depend on the anchor, so build it once.
        f2, oa, sa = dps.build_full_grid_2d(lab_box, args.grid, tail=None, lab_dim=2)

        for jx in anchor_ix.ravel():
            indices.append(int(jx))
            o2 = imgs2[jx]
            lb2 = lab2[jx].astype(np.float64)
            post2.append(grid_posterior(o2, f2, oa, sa, mean2, std2, nz2, model2))

            o6 = imgs6[jx]
            lb6 = lab6[jx]
            # The tail is the anchor's own labels 2..5, so this grid changes per anchor.
            tail_truth = lb6.astype(np.float32)[2:6]
            f6, oa6, sa6 = dps.build_full_grid_2d(lab6, args.grid, tail=tail_truth, lab_dim=6)
            post6.append(grid_posterior(o6, f6, oa6, sa6, mean6, std6, nz6, model6))
            truths.append((float(lb2[0]), float(lb2[1])))

        integrate_figure6_with_multi_anchor(
            post2,
            oa,
            sa,
            truths,
            indices,
            out,
            model_name="DDPM-2",
        )

        truths6 = [(float(lab6[int(j)][0]), float(lab6[int(j)][1])) for j in anchor_ix]

        integrate_figure6_with_multi_anchor(
            post6,
            oa6,
            sa6,
            truths6,
            indices,
            out,
            model_name="DDPM-6-truth-tail",
        )

        del model2, model6
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"Done. Outputs in {out}")
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
if __name__ == "__main__":
|
| 315 |
+
main()
|
cross_model/scripts/run_ddpm_posterior_corrected.sh
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --account=l40sfree
|
| 3 |
+
#SBATCH --partition=l40s
|
| 4 |
+
#SBATCH --nodes=1
|
| 5 |
+
#SBATCH --ntasks=8
|
| 6 |
+
#SBATCH --gres=gpu:l40s:1
|
| 7 |
+
#SBATCH --time=48:00:00
|
| 8 |
+
#SBATCH --job-name=ddpm_post_corr
|
| 9 |
+
#SBATCH --mail-user=mrpcol001@myuct.ac.za
|
| 10 |
+
#SBATCH --output=slurm-ddpm-post-corr-%j.out
|
| 11 |
+
#SBATCH --error=slurm-ddpm-post-corr-%j.err
|
| 12 |
+
|
| 13 |
+
# Prior / likelihood / posterior visualization pipeline (ddpm_posterior_corrected.py).
|
| 14 |
+
# Separate from poster.py — default output dir is ddpm_posterior_corrected_fullviz_out.
|
| 15 |
+
#
|
| 16 |
+
# Submit:
|
| 17 |
+
# sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_posterior_corrected.sh
|
| 18 |
+
#
|
| 19 |
+
# Extra args pass through to the Python script:
|
| 20 |
+
# sbatch .../run_ddpm_posterior_corrected.sh --ddpm2-only --grid 20 --n-ddpm-samples 4 --no-ppc
|
| 21 |
+
#
|
| 22 |
+
# Override dirs:
|
| 23 |
+
# sbatch --export=OUTPUT_DIR=/path/to/out,CUSTOM_LOG=/path/run.log \
|
| 24 |
+
# .../run_ddpm_posterior_corrected.sh
|
| 25 |
+
#
|
| 26 |
+
# Interactive:
|
| 27 |
+
# module load python/miniconda3-py3.12-usr
|
| 28 |
+
# bash .../run_ddpm_posterior_corrected.sh --help
|
| 29 |
+
|
| 30 |
+
set -euo pipefail
|
| 31 |
+
|
| 32 |
+
ROOT="/scratch/mrpcol001/Diffusion_job/Models"
|
| 33 |
+
cd "$ROOT"
|
| 34 |
+
|
| 35 |
+
module load python/miniconda3-py3.12-usr
|
| 36 |
+
|
| 37 |
+
OUT="${OUTPUT_DIR:-${ROOT}/ddpm_posterior_corrected_fullviz_out}"
|
| 38 |
+
mkdir -p "$OUT"
|
| 39 |
+
|
| 40 |
+
# Copy of stdout/stderr for progress (also appears in Slurm .out/.err)
|
| 41 |
+
RUN_LOG="${CUSTOM_LOG:-${OUT}/run_log.txt}"
|
| 42 |
+
|
| 43 |
+
echo "==============================================="
|
| 44 |
+
echo "Job ID: ${SLURM_JOB_ID:-local}"
|
| 45 |
+
echo "Node: ${SLURM_NODELIST:-$(hostname)}"
|
| 46 |
+
echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
|
| 47 |
+
echo "Started: $(date)"
|
| 48 |
+
echo "Output dir: ${OUT}"
|
| 49 |
+
echo "Progress log (tee): ${RUN_LOG}"
|
| 50 |
+
echo "==============================================="
|
| 51 |
+
|
| 52 |
+
set -o pipefail
|
| 53 |
+
python -u "${ROOT}/ddpm_posterior_corrected.py" --output-dir "${OUT}" "$@" 2>&1 | tee -a "${RUN_LOG}"
|
| 54 |
+
|
| 55 |
+
echo "==============================================="
|
| 56 |
+
echo "Finished: $(date)"
|
| 57 |
+
echo "Figures & log → ${OUT}"
|
| 58 |
+
echo "==============================================="
|
cross_model/scripts/run_ddpm_posterior_six_anchors.sh
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --account=l40sfree
|
| 3 |
+
#SBATCH --partition=l40s
|
| 4 |
+
#SBATCH --nodes=1
|
| 5 |
+
#SBATCH --ntasks=8
|
| 6 |
+
#SBATCH --gres=gpu:l40s:1
|
| 7 |
+
#SBATCH --time=48:00:00
|
| 8 |
+
#SBATCH --job-name=ddpm_post6
|
| 9 |
+
#SBATCH --mail-user=mrpcol001@myuct.ac.za
|
| 10 |
+
#SBATCH --output=slurm-ddpm-posterior-six-%j.out
|
| 11 |
+
#SBATCH --error=slurm-ddpm-posterior-six-%j.err
|
| 12 |
+
|
| 13 |
+
# Six-anchor surrogate posteriors (DDPM-2 + DDPM-6 with extra dims min/max).
|
| 14 |
+
#
|
| 15 |
+
# Submit from anywhere:
|
| 16 |
+
# sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_posterior_six_anchors.sh
|
| 17 |
+
#
|
| 18 |
+
# Override output directory:
|
| 19 |
+
# sbatch --export=OUTPUT_DIR=/scratch/mrpcol001/Diffusion_job/Models/my_post_out \
|
| 20 |
+
# /scratch/mrpcol001/Diffusion_job/Models/scripts/run_ddpm_posterior_six_anchors.sh
|
| 21 |
+
#
|
| 22 |
+
# Extra CLI passes through to ddpm_posterior_six_anchors.py, e.g. only DDPM-6 panels:
|
| 23 |
+
# sbatch .../run_ddpm_posterior_six_anchors.sh --ddpm6-only --grid 12
|
| 24 |
+
#
|
| 25 |
+
# Interactive:
|
| 26 |
+
# module load python/miniconda3-py3.12-usr
|
| 27 |
+
# bash .../run_ddpm_posterior_six_anchors.sh
|
| 28 |
+
|
| 29 |
+
set -euo pipefail
|
| 30 |
+
|
| 31 |
+
ROOT="/scratch/mrpcol001/Diffusion_job/Models"
|
| 32 |
+
cd "$ROOT"
|
| 33 |
+
|
| 34 |
+
module load python/miniconda3-py3.12-usr
|
| 35 |
+
|
| 36 |
+
OUT="${OUTPUT_DIR:-${ROOT}/ddpm_posterior_six_anchors_out}"
|
| 37 |
+
|
| 38 |
+
echo "==============================================="
|
| 39 |
+
echo "Job ID: ${SLURM_JOB_ID:-local}"
|
| 40 |
+
echo "Job Name: ${SLURM_JOB_NAME:-run_ddpm_posterior_six_anchors}"
|
| 41 |
+
echo "Node: ${SLURM_NODELIST:-$(hostname)}"
|
| 42 |
+
echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
|
| 43 |
+
echo "Starting Time: $(date)"
|
| 44 |
+
echo "Posterior output: ${OUT}"
|
| 45 |
+
echo "==============================================="
|
| 46 |
+
|
| 47 |
+
python "$ROOT/scripts/ddpm_posterior_six_anchors.py" --output-dir "$OUT" "$@"
|
| 48 |
+
|
| 49 |
+
echo "==============================================="
|
| 50 |
+
echo "Artifacts -> ${OUT}"
|
| 51 |
+
echo "Finished at: $(date)"
|
| 52 |
+
echo "==============================================="
|
cross_model/scripts/run_poster.sh
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --account=l40sfree
|
| 3 |
+
#SBATCH --partition=l40s
|
| 4 |
+
#SBATCH --nodes=1
|
| 5 |
+
#SBATCH --ntasks=8
|
| 6 |
+
#SBATCH --gres=gpu:l40s:1
|
| 7 |
+
#SBATCH --time=24:00:00
|
| 8 |
+
#SBATCH --job-name=ddpm_poster
|
| 9 |
+
#SBATCH --mail-user=mrpcol001@myuct.ac.za
|
| 10 |
+
#SBATCH --output=slurm-ddpm-poster-%j.out
|
| 11 |
+
#SBATCH --error=slurm-ddpm-poster-%j.err
|
| 12 |
+
|
| 13 |
+
# Corrected six-anchor surrogate posteriors (poster.py): DDPM-2 + DDPM-6 with
|
| 14 |
+
# stochastic averaging, calibrated sigma_pk, MC marginalisation for 6-param, etc.
|
| 15 |
+
#
|
| 16 |
+
# Submit from anywhere:
|
| 17 |
+
# sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_poster.sh
|
| 18 |
+
#
|
| 19 |
+
# Override output directory:
|
| 20 |
+
# sbatch --export=OUTPUT_DIR=/scratch/mrpcol001/Diffusion_job/Models/my_poster_out \
|
| 21 |
+
# /scratch/mrpcol001/Diffusion_job/Models/scripts/run_poster.sh
|
| 22 |
+
#
|
| 23 |
+
# Extra CLI passes through to poster.py, e.g. DDPM-2 only (faster debug):
|
| 24 |
+
# sbatch .../run_poster.sh --ddpm2-only --grid 14 --n-pk-samples 4 --n-marg-samples 1 --no-ppc
|
| 25 |
+
#
|
| 26 |
+
# Interactive (same module as other Models scripts):
|
| 27 |
+
# module load python/miniconda3-py3.12-usr
|
| 28 |
+
# bash /scratch/mrpcol001/Diffusion_job/Models/scripts/run_poster.sh --help
|
| 29 |
+
|
| 30 |
+
set -euo pipefail
|
| 31 |
+
|
| 32 |
+
ROOT="/scratch/mrpcol001/Diffusion_job/Models"
|
| 33 |
+
cd "$ROOT"
|
| 34 |
+
|
| 35 |
+
module load python/miniconda3-py3.12-usr
|
| 36 |
+
|
| 37 |
+
OUT="${OUTPUT_DIR:-${ROOT}/ddpm_posterior_corrected_out}"
|
| 38 |
+
|
| 39 |
+
echo "==============================================="
|
| 40 |
+
echo "Job ID: ${SLURM_JOB_ID:-local}"
|
| 41 |
+
echo "Job Name: ${SLURM_JOB_NAME:-run_poster}"
|
| 42 |
+
echo "Node: ${SLURM_NODELIST:-$(hostname)}"
|
| 43 |
+
echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
|
| 44 |
+
echo "Starting Time: $(date)"
|
| 45 |
+
echo "Poster output: ${OUT}"
|
| 46 |
+
echo "==============================================="
|
| 47 |
+
|
| 48 |
+
python "$ROOT/poster.py" --output-dir "$OUT" "$@"
|
| 49 |
+
|
| 50 |
+
echo "==============================================="
|
| 51 |
+
echo "Artifacts -> ${OUT}"
|
| 52 |
+
echo "Finished at: $(date)"
|
| 53 |
+
echo "==============================================="
|
cross_model/scripts/run_posterior_inference.sh
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --account=l40sfree
|
| 3 |
+
#SBATCH --partition=l40s
|
| 4 |
+
#SBATCH --nodes=1
|
| 5 |
+
#SBATCH --ntasks=8
|
| 6 |
+
#SBATCH --gres=gpu:l40s:1
|
| 7 |
+
#SBATCH --time=48:00:00
|
| 8 |
+
#SBATCH --job-name=vlb_infer
|
| 9 |
+
#SBATCH --mail-user=mrpcol001@myuct.ac.za
|
| 10 |
+
#SBATCH --output=slurm-vlb-infer-%j.out
|
| 11 |
+
#SBATCH --error=slurm-vlb-infer-%j.err
|
| 12 |
+
|
| 13 |
+
# VLB / Mudur-style posterior_inference.py (pure inference-time L_t surfaces).
|
| 14 |
+
#
|
| 15 |
+
# Defaults match bundled 6-param checkpoint + LH test data (override via env).
|
| 16 |
+
#
|
| 17 |
+
# Submit:
|
| 18 |
+
# sbatch /scratch/mrpcol001/Diffusion_job/Models/scripts/run_posterior_inference.sh
|
| 19 |
+
#
|
| 20 |
+
# Defaults (posterior_inference.py): n_fields=9, grid_size=10000 (needs --allow_huge_grid),
|
| 21 |
+
# mosaic figure posterior_L0_mosaic_3x3.png at ~10000×10000 px.
|
| 22 |
+
# Override grid without huge scan, e.g.: --grid_size 50 (then --allow_huge_grid not needed)
|
| 23 |
+
# Smoke test:
|
| 24 |
+
# sbatch .../run_posterior_inference.sh --n_fields 1 --grid_size 25 --t_subset 0 --batch_size 16
|
| 25 |
+
#
|
| 26 |
+
# Custom checkpoint / args / data:
|
| 27 |
+
# sbatch --export=CHECKPOINT=/path/best_model.pt,TRAINING_ARGS=/path/args.json,DATA_DIR=/path/params_6 \
|
| 28 |
+
# .../run_posterior_inference.sh --grid_size 40
|
| 29 |
+
#
|
| 30 |
+
# Logs: Slurm .out/.err plus OUTPUT_DIR/run_log.txt (override CUSTOM_LOG).
|
| 31 |
+
|
| 32 |
+
set -euo pipefail
|
| 33 |
+
|
| 34 |
+
ROOT="/scratch/mrpcol001/Diffusion_job/Models"
|
| 35 |
+
cd "$ROOT"
|
| 36 |
+
|
| 37 |
+
module load python/miniconda3-py3.12-usr
|
| 38 |
+
|
| 39 |
+
PY="${ROOT}/6param_ddpm_hi_lh6/posterior_inference.py"
|
| 40 |
+
OUT="${OUTPUT_DIR:-${ROOT}/vlb_inference_outputs}"
|
| 41 |
+
|
| 42 |
+
CHK="${CHECKPOINT:-${ROOT}/notebook_model_weights/6param_best/best_model.pt}"
|
| 43 |
+
ARGS="${TRAINING_ARGS:-${ROOT}/notebook_model_weights/6param_best/args.json}"
|
| 44 |
+
DATA="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"
|
| 45 |
+
|
| 46 |
+
mkdir -p "${OUT}"
|
| 47 |
+
RUN_LOG="${CUSTOM_LOG:-${OUT}/run_log.txt}"
|
| 48 |
+
|
| 49 |
+
echo "==============================================="
|
| 50 |
+
echo "Job ID: ${SLURM_JOB_ID:-local}"
|
| 51 |
+
echo "Node: ${SLURM_NODELIST:-$(hostname)}"
|
| 52 |
+
echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
|
| 53 |
+
echo "Started: $(date)"
|
| 54 |
+
echo "Python: ${PY}"
|
| 55 |
+
echo "checkpoint: ${CHK}"
|
| 56 |
+
echo "training_args: ${ARGS}"
|
| 57 |
+
echo "data_dir: ${DATA}"
|
| 58 |
+
echo "output_dir: ${OUT}"
|
| 59 |
+
echo "Progress log: ${RUN_LOG}"
|
| 60 |
+
echo "==============================================="
|
| 61 |
+
|
| 62 |
+
set -o pipefail
|
| 63 |
+
python -u "${PY}" \
|
| 64 |
+
--checkpoint "${CHK}" \
|
| 65 |
+
--training_args "${ARGS}" \
|
| 66 |
+
--data_dir "${DATA}" \
|
| 67 |
+
--output_dir "${OUT}" \
|
| 68 |
+
--allow_huge_grid \
|
| 69 |
+
"$@" 2>&1 | tee -a "${RUN_LOG}"
|
| 70 |
+
|
| 71 |
+
echo "==============================================="
|
| 72 |
+
echo "Finished: $(date)"
|
| 73 |
+
echo "Artifacts → ${OUT}"
|
| 74 |
+
echo "==============================================="
|
cross_model/scripts/run_triangle_ddpm_both.sh
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --account=l40sfree
|
| 3 |
+
#SBATCH --partition=l40s
|
| 4 |
+
#SBATCH --nodes=1
|
| 5 |
+
#SBATCH --ntasks=8
|
| 6 |
+
#SBATCH --gres=gpu:l40s:1
|
| 7 |
+
#SBATCH --time=12:00:00
|
| 8 |
+
#SBATCH --job-name=ddpm_triangle
|
| 9 |
+
#SBATCH --mail-user=mrpcol001@myuct.ac.za
|
| 10 |
+
#SBATCH --output=slurm-ddpm-triangle-%j.out
|
| 11 |
+
#SBATCH --error=slurm-ddpm-triangle-%j.err
|
| 12 |
+
|
| 13 |
+
# Run surrogate Ωm–σ8 chain export for DDPM-2 and DDPM-6, then a joint triangle plot.
|
| 14 |
+
# If you have your own copies of ddpm_triangle_integration.py / triangle_plot_posterior.py
|
| 15 |
+
# under $ROOT/scripts (e.g. copied from ~/Downloads), they override the repo versions.
|
| 16 |
+
#
|
| 17 |
+
# sbatch .../run_triangle_ddpm_both.sh
|
| 18 |
+
#
|
| 19 |
+
# sbatch --export=OUTPUT_DIR=/path/to/out,TEST_INDEX=56 .../run_triangle_ddpm_both.sh
|
| 20 |
+
#
|
| 21 |
+
# Interactive:
|
| 22 |
+
# module load python/miniconda3-py3.12-usr
|
| 23 |
+
# bash .../run_triangle_ddpm_both.sh
|
| 24 |
+
|
| 25 |
+
set -euo pipefail
|
| 26 |
+
|
| 27 |
+
ROOT="/scratch/mrpcol001/Diffusion_job/Models"
|
| 28 |
+
cd "$ROOT"
|
| 29 |
+
|
| 30 |
+
module load python/miniconda3-py3.12-usr
|
| 31 |
+
|
| 32 |
+
OUT="${OUTPUT_DIR:-${ROOT}/ddpm_triangle_out}"
|
| 33 |
+
TEST_IX="${TEST_INDEX:-56}"
|
| 34 |
+
GRID="${GRID_POINTS:-14}"
|
| 35 |
+
|
| 36 |
+
mkdir -p "${OUT}"
|
| 37 |
+
|
| 38 |
+
INTEG="${DDPM_TRIANGLE_INTEGRATION_PY:-${ROOT}/scripts/ddpm_triangle_integration.py}"
|
| 39 |
+
PLOT="${DDPM_TRIANGLE_POSTERIOR_PY:-${ROOT}/scripts/triangle_plot_posterior.py}"
|
| 40 |
+
|
| 41 |
+
for f in "$INTEG" "$PLOT"; do
|
| 42 |
+
if [[ ! -f "$f" ]]; then
|
| 43 |
+
echo "Missing: $f"
|
| 44 |
+
exit 1
|
| 45 |
+
fi
|
| 46 |
+
done
|
| 47 |
+
|
| 48 |
+
CHAIN2="${OUT}/chain_surrogate_ix${TEST_IX}_ddpm2.npz"
|
| 49 |
+
CHAIN6="${OUT}/chain_surrogate_ix${TEST_IX}_ddpm6_truth_tail.npz"
|
| 50 |
+
|
| 51 |
+
echo "==============================================="
|
| 52 |
+
echo "Job: ${SLURM_JOB_ID:-local} OUT=${OUT} test_ix=${TEST_IX}"
|
| 53 |
+
echo "==============================================="
|
| 54 |
+
|
| 55 |
+
python "$INTEG" \
|
| 56 |
+
--label-dim 2 \
|
| 57 |
+
--test-index "${TEST_IX}" \
|
| 58 |
+
--grid "${GRID}" \
|
| 59 |
+
-o "${CHAIN2}"
|
| 60 |
+
|
| 61 |
+
python "$INTEG" \
|
| 62 |
+
--label-dim 6 \
|
| 63 |
+
--test-index "${TEST_IX}" \
|
| 64 |
+
--six-tail-mode truth \
|
| 65 |
+
--grid "${GRID}" \
|
| 66 |
+
-o "${CHAIN6}"
|
| 67 |
+
|
| 68 |
+
python "$PLOT" \
|
| 69 |
+
-i "${CHAIN2}" "${CHAIN6}" \
|
| 70 |
+
--labels "DDPM-2" "DDPM-6" \
|
| 71 |
+
-o "${OUT}/triangle_ddpm2_ddpm6_ix${TEST_IX}.png"
|
| 72 |
+
|
| 73 |
+
echo "Chains: ${CHAIN2} ${CHAIN6}"
|
| 74 |
+
echo "Triangle: ${OUT}/triangle_ddpm2_ddpm6_ix${TEST_IX}.png"
|
| 75 |
+
echo "Finished: $(date)"
|
cross_model/scripts/sigma_contour_utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""HDR-style contour levels for 2D probability maps on a grid."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def compute_sigma_levels(
    posterior_norm: np.ndarray,
    credibility_mass: tuple[float, ...] | list[float],
) -> list[float]:
    """
    Highest-density containment levels for a gridded probability map.

    The map is flattened, normalised to unit sum and sorted by descending
    density. For each requested mass ``credibility_mass[j]`` the returned
    threshold is the density of the first cell at which the cumulative sum
    reaches that mass.

    Args:
        posterior_norm: Non-negative weights on a grid (any shape); it does
            not need to be normalised, the sum is divided out here.
        credibility_mass: Enclosed probability masses, e.g. ``(0.68, 0.95)``.

    Returns:
        One density threshold per requested mass, sorted ascending (the order
        matplotlib ``contour`` expects). When the map carries no usable mass
        (empty, all zero, negative total, or a NaN/inf total) every level is
        ``0.0``.
    """
    p = np.asarray(posterior_norm, dtype=np.float64).ravel()
    total = p.sum()
    # A NaN/inf total would otherwise slip past the ``<= 0`` test (NaN compares
    # False) and propagate NaN into every level, so treat it as degenerate too.
    if not np.isfinite(total) or total <= 0:
        return [0.0 for _ in credibility_mass]
    ps = np.sort(p / total)[::-1]  # densities, highest first
    cdf = np.cumsum(ps)
    last = len(ps) - 1
    out: list[float] = []
    for cred in credibility_mass:
        j = int(np.searchsorted(cdf, cred, side="left"))
        # Round-off can leave cdf[-1] marginally below a requested mass of 1.0;
        # clamp so the lookup stays inside the array.
        out.append(float(ps[min(max(j, 0), last)]))
    return sorted(out)
|
cross_model/scripts/triangle_plot_posterior.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Corner-style triangle plot for surrogate $(\\Omega_m,\\sigma_8)$ chains from ``ddpm_triangle_integration.py``.
|
| 4 |
+
|
| 5 |
+
Loads one or two ``.npz`` files (keys ``omega_m``, ``sigma_8`` / ``samples``, ``truth_*``) and draws
|
| 6 |
+
1D marginals + 2D density. If you substitute a script from your Downloads, keep ``--inputs``
|
| 7 |
+
and the expected ``.npz`` keys compatible.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
import matplotlib
|
| 16 |
+
|
| 17 |
+
matplotlib.use("Agg")
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _load_chain(path: Path) -> tuple[np.ndarray, np.ndarray, tuple[float, float] | None]:
    """Load one surrogate chain ``.npz`` file.

    Args:
        path: Location of the ``.npz`` archive.

    Returns:
        ``(omega_m, sigma_8, truth)``. The two sample arrays come from columns
        0 and 1 of ``samples`` when that key is present, otherwise from the
        flattened ``omega_m`` / ``sigma_8`` arrays. ``truth`` is
        ``(truth_Omega_m, truth_sigma_8)`` when both keys are stored, else
        ``None``.

    Raises:
        KeyError: If neither ``samples`` nor ``omega_m``/``sigma_8`` exist.
    """
    # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays, which
    # can execute arbitrary code. Only load chain files you produced yourself;
    # kept as-is because the writer's array dtypes are not visible from here.
    # The context manager closes the underlying zip handle (previously leaked).
    with np.load(path, allow_pickle=True) as d:
        if "samples" in d.files:
            s = np.asarray(d["samples"], dtype=np.float64)
            om, s8 = s[:, 0], s[:, 1]
        else:
            om = np.asarray(d["omega_m"], dtype=np.float64).ravel()
            s8 = np.asarray(d["sigma_8"], dtype=np.float64).ravel()
        truth = None
        if "truth_Omega_m" in d.files and "truth_sigma_8" in d.files:
            truth = (float(d["truth_Omega_m"]), float(d["truth_sigma_8"]))
    return om, s8, truth
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main() -> None:
    """Draw a corner-style figure for one or more surrogate chains and save it as PNG.

    Layout (2x2): the Ωm marginal sits above the joint Ωm–σ8 density, the σ8
    marginal sits to the right of the joint panel, and the top-right panel is
    blank. The first input is drawn as filled contours with a colorbar; later
    inputs are overlaid as line contours. Truth values stored in a chain file
    are marked with an "x" on the joint panel.

    Raises:
        SystemExit: If ``--labels`` is given with a count different from ``--inputs``.
    """
    parser = argparse.ArgumentParser(description="Triangle / corner plot for Ωm–σ8 surrogate chains.")
    parser.add_argument(
        "--inputs",
        "-i",
        nargs="+",
        type=Path,
        required=True,
        help="One or two .npz outputs from ddpm_triangle_integration.py",
    )
    parser.add_argument(
        "--labels",
        nargs="*",
        default=None,
        help="Legend entries (default: paths' stems).",
    )
    parser.add_argument(
        "--output",
        "-o",
        type=Path,
        default=None,
        # Help text now matches the default actually built further down.
        help="Output PNG (default: triangle_<input stems joined by '_'>.png next to first input).",
    )
    parser.add_argument("--bins-1d", type=int, default=40)
    parser.add_argument("--bins-2d", type=int, default=45)
    args = parser.parse_args()

    paths = [Path(x).resolve() for x in args.inputs]
    names = args.labels if args.labels else [path.stem for path in paths]
    if len(names) != len(paths):
        raise SystemExit("--labels count must match --inputs")

    colors = ("#1f77b4", "#d95f02", "#2ca02c")
    fig = plt.figure(figsize=(8.2, 8.0))

    ax00 = fig.add_axes([0.1, 0.55, 0.35, 0.35])
    ax_cont = fig.add_axes([0.1, 0.1, 0.35, 0.35])
    ax11 = fig.add_axes([0.55, 0.1, 0.35, 0.35])
    ax_blank = fig.add_axes([0.55, 0.55, 0.35, 0.35])
    ax_blank.axis("off")

    for i, path in enumerate(paths):
        om, s8, truth = _load_chain(path)
        c = colors[i % len(colors)]
        ax00.hist(
            om,
            bins=args.bins_1d,
            density=True,
            histtype="step",
            color=c,
            lw=2.0,
            label=names[i],
        )
        # This panel shares a row with the joint panel and its x-axis is
        # labelled "density", so the bars must run horizontally (σ8 on y).
        ax11.hist(
            s8,
            bins=args.bins_1d,
            density=True,
            histtype="step",
            color=c,
            lw=2.0,
            orientation="horizontal",
        )
        h2, xe, ye = np.histogram2d(om, s8, bins=args.bins_2d, density=True)
        xc = 0.5 * (xe[1:] + xe[:-1])
        yc = 0.5 * (ye[1:] + ye[:-1])
        # histogram2d returns h2[i, j] for x-bin i and y-bin j, which is the
        # same layout as an "ij" mesh. Plotting h2.T on this mesh (as before)
        # swapped the Ωm and σ8 axes of the density.
        X, Y = np.meshgrid(xc, yc, indexing="ij")
        Z = np.ma.masked_where(h2 <= 1e-20, h2)
        if i == 0 and np.ma.count(Z) > 0:
            cf = ax_cont.contourf(X, Y, Z, alpha=0.45, cmap="Blues")
            fig.colorbar(cf, ax=ax_cont, fraction=0.046, pad=0.04)
        elif np.ma.count(Z) > 0:
            ax_cont.contour(X, Y, Z, colors=[c], linewidths=[1.85])
        if truth:
            tx, ty = truth
            ax_cont.scatter(tx, ty, marker="x", s=88, color=c, zorder=6)

    ax00.set_title(r"$P(\Omega_m)$ marginal")
    ax00.set_ylabel("density")
    ax00.legend(fontsize=8, loc="upper right")
    ax_cont.set_title(r"$2D$ surrogate posterior density")
    ax_cont.set_xlabel(r"$\Omega_m$")
    ax_cont.set_ylabel(r"$\sigma_8$")
    ax11.set_title(r"$P(\sigma_8)$ marginal")
    ax11.set_xlabel("density")

    out = args.output or (paths[0].parent / ("triangle_" + "_".join(path.stem for path in paths) + ".png"))
    out.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out, dpi=170, bbox_inches="tight")
    plt.close(fig)
    print("Saved", out)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
|
| 128 |
+
main()
|
cross_model/submit_vlb_1000grid.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Submit 1000×1000 grid VLB inference job.
|
| 4 |
+
Usage: python3 submit_vlb_1000grid.py [--dry-run]
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import subprocess
|
| 8 |
+
import argparse
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
def main():
    """Validate local inputs, print a cost estimate, and submit the VLB job via ``sbatch``.

    The grid size, field count, batch size and output directory are handed to
    ``run_vlb_inference_1000grid.sh`` as environment variables.

    Returns:
        ``0`` when the job was submitted (or ``--dry-run`` was given), ``1``
        when a required path is missing or the submission failed.
    """
    # Local import so this block does not depend on the module's import list.
    import os

    parser = argparse.ArgumentParser(
        description="Submit high-resolution VLB inference (1000×1000 grid)"
    )
    parser.add_argument("--dry-run", action="store_true",
                        help="Print command without submitting")
    parser.add_argument("--grid-size", type=int, default=1000,
                        help="Grid resolution (default: 1000)")
    parser.add_argument("--n-fields", type=int, default=9,
                        help="Number of test fields (default: 9)")
    parser.add_argument("--job-name", default="vlb-infer-1000",
                        help="SLURM job name")
    parser.add_argument("--time", default="24:00:00",
                        help="SLURM time limit (default: 24:00:00)")
    parser.add_argument("--mem", default="32G",
                        help="Memory requirement (default: 32G)")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size (default: 32, reduce if OOM)")
    args = parser.parse_args()

    # Reject sizes that make no sense before they reach the scheduler.
    if args.grid_size < 1 or args.n_fields < 1 or args.batch_size < 1:
        parser.error("--grid-size, --n-fields and --batch-size must be positive")

    script_dir = Path(__file__).parent.resolve()
    checkpoint = script_dir / "notebook_model_weights/6param_best/best_model.pt"
    training_args = script_dir / "notebook_model_weights/6param_best/args.json"
    data_dir = script_dir.parent / "data/LH_data/params_6"
    output_dir = script_dir / f"vlb_inference_outputs_{args.grid_size}grid"

    # Validate paths
    for path, name in [
        (checkpoint, "checkpoint"),
        (training_args, "training_args"),
        (data_dir, "data_dir"),
    ]:
        if not path.exists():
            print(f"❌ Error: {name} not found at {path}")
            return 1

    cmd = [
        "sbatch",
        f"--job-name={args.job_name}",
        f"--time={args.time}",
        f"--mem={args.mem}",
        f"--output=slurm-vlb-infer-{args.grid_size}-%j.out",
        f"--error=slurm-vlb-infer-{args.grid_size}-%j.err",
        # Make propagation of the submitting environment explicit so the
        # overrides below reach the job script.
        "--export=ALL",
        str(script_dir / "run_vlb_inference_1000grid.sh"),
    ]

    # Environment variables that override defaults in the job script.
    # These must be passed through ``env=``: without a shell, a leading
    # ``NAME=value`` argv entry is treated as the program to execute.
    overrides = {
        "GRID_SIZE": str(args.grid_size),
        "N_FIELDS": str(args.n_fields),
        "BATCH_SIZE": str(args.batch_size),
        "OUTPUT_DIR": str(output_dir),
    }
    # Shell-style rendering, for display only.
    display_cmd = [f"{key}={value}" for key, value in overrides.items()] + cmd

    print("=" * 60)
    print(f"VLB Inference Submission — {args.grid_size}×{args.grid_size} Grid")
    print("=" * 60)
    print(f"Grid size:        {args.grid_size}×{args.grid_size}")
    print(f"Number of fields: {args.n_fields}")
    print(f"Batch size:       {args.batch_size}")
    print(f"Output dir:       {output_dir}")
    print(f"Memory:           {args.mem}")
    print(f"Time limit:       {args.time}")
    print("=" * 60)

    # Compute estimates
    # NOTE(review): timesteps/seeds are hard-coded here; presumably they mirror
    # the inference script's defaults — confirm against posterior_inference.py.
    grid_points = args.grid_size ** 2
    timesteps = 8
    seeds = 4
    per_field = grid_points * timesteps * seeds
    n_forward_passes = per_field * args.n_fields
    # NOTE(review): the per-field figure assumes 10,000 passes/minute while the
    # total assumes 7–10 million passes/hour; the two lines do not agree.
    print("\nComputation scale:")
    print(f"  Total forward passes: {n_forward_passes:,}")
    print(f"  Per field: {per_field:,}")
    print(f"  Est. time per field: ~{per_field // 10_000:,} min")
    # True division: floor division threw away the fraction that ``.1f`` prints.
    print(f"  Est. total time: ~{n_forward_passes / 10_000_000:.1f}-{n_forward_passes / 7_000_000:.1f} hours")

    print(f"\nCommand: {' '.join(display_cmd)}\n")

    if args.dry_run:
        print("✓ Dry-run (not submitted)")
        return 0

    try:
        result = subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            text=True,
            env={**os.environ, **overrides},
        )
    except FileNotFoundError as e:
        # sbatch is not on PATH (e.g. running off-cluster).
        print(f"❌ Submission failed: {e}")
        return 1
    except subprocess.CalledProcessError as e:
        print(f"❌ Submission failed: {e}")
        print(e.stderr)
        return 1
    print("✓ Job submitted!")
    print(result.stdout)
    return 0
|
| 104 |
+
|
| 105 |
+
if __name__ == "__main__":
    # ``exit`` is a helper added by the ``site`` module for interactive use and
    # is not guaranteed to exist (e.g. under ``python -S``); raise SystemExit
    # directly so main()'s return value becomes the process exit status.
    raise SystemExit(main())
|
scripts/shell/evaluate_conditional_lh6.sh
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --account=l40sfree
#SBATCH --partition=l40s
#SBATCH --nodes=1
#SBATCH --ntasks=8
#SBATCH --gres=gpu:l40s:1
#SBATCH --time=04:00:00
#SBATCH --job-name=ddpm_hi_lh6_eval
#SBATCH --mail-user=mrpcol001@myuct.ac.za
#SBATCH --output=slurm-eval-%j.out
#SBATCH --error=slurm-eval-%j.err

# Evaluate conditional DDPM (6 CAMELS LH parameters).
# Submit:
#   sbatch /scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6/scripts/shell/evaluate_conditional_lh6.sh
#
# Optional overrides (example):
#   sbatch --export=CHECKPOINT=/path/to/best_model.pt,OUTPUT_DIR=/path/to/eval_out evaluate_conditional_lh6.sh
#
# Environment variables read: DATA_DIR, CHECKPOINT, OUTPUT_DIR, TRAINING_ARGS
# (TRAINING_ARGS is only forwarded when non-empty).

# Stop on the first failure so a crashed evaluation is not followed by the
# "completed" banner and a zero exit status. Same strict mode as the other
# Slurm launchers in this project.
set -euo pipefail

REPO="/scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6"
cd "${REPO}" || exit 1

module load python/miniconda3-py3.12-usr

DATA_DIR="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"
# Default: trained run kept under april_26 (large artifacts not duplicated here).
CHECKPOINT="${CHECKPOINT:-/scratch/mrpcol001/Diffusion_job/april_26/ddpm_hi_lh6/outputs_conditional_6param_20260413_132226/checkpoints/best_model.pt}"
OUTPUT_DIR="${OUTPUT_DIR:-${REPO}/evaluation_outputs_6param}"
TRAINING_ARGS="${TRAINING_ARGS:-}"

echo "==============================================="
echo "Job ID: ${SLURM_JOB_ID:-local}"
echo "Job Name: ${SLURM_JOB_NAME:-evaluate_conditional_lh6}"
echo "Node: ${SLURM_NODELIST:-$(hostname)}"
echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
echo "Starting Time: $(date)"
echo "CHECKPOINT: ${CHECKPOINT}"
echo "DATA_DIR:   ${DATA_DIR}"
echo "OUTPUT_DIR: ${OUTPUT_DIR}"
echo "==============================================="

# Build the command as an array so paths containing spaces stay single arguments.
EVAL_ARGS=(
  python evaluate_conditional.py
  --checkpoint "${CHECKPOINT}"
  --data_dir "${DATA_DIR}"
  --output_dir "${OUTPUT_DIR}"
  --split test
  --num_samples 8
  --ddim_steps 50
)

if [[ -n "${TRAINING_ARGS}" ]]; then
  EVAL_ARGS+=(--training_args "${TRAINING_ARGS}")
fi

"${EVAL_ARGS[@]}"

# Reached only when the evaluation command above exited successfully.
echo "==============================================="
echo "Evaluation completed at: $(date)"
echo "Plots and evaluation_data.npz under: ${OUTPUT_DIR}"
echo "==============================================="
|
scripts/shell/plot_r2_cosmology_lhs.sh
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --account=l40sfree
#SBATCH --partition=l40s
#SBATCH --nodes=1
#SBATCH --ntasks=8
#SBATCH --gres=gpu:l40s:1
#SBATCH --time=12:00:00
#SBATCH --job-name=ddpm_r2_lhs
#SBATCH --mail-user=mrpcol001@myuct.ac.za
#SBATCH --output=slurm-r2-lhs-%j.out
#SBATCH --error=slurm-r2-lhs-%j.err

# Latin-hypercube R² figure (plot_r2_cosmology_lhs.py): μ(P) and σ(P) vs (Ωm, σ8).
#
# Submit (full DDIM run — slow):
#   sbatch /scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6/scripts/shell/plot_r2_cosmology_lhs.sh
#
# Plot only from saved NPZ (fast):
#   sbatch --export=FROM_NPZ=/path/to/r2_lhs_data.npz /scratch/.../plot_r2_cosmology_lhs.sh
#
# Optional env vars:
#   CHECKPOINT, DATA_DIR, OUTPUT_PNG, SAVE_NPZ, LHS_N, MAPS_PER_POINT, DDIM_STEPS, SEED
#
# The script exits with the Python process's status so that a failed run is
# not reported as a successful job.

REPO="/scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6"
cd "${REPO}" || exit 1

module load python/miniconda3-py3.12-usr

# Every setting can be overridden from the environment; these are the defaults.
DATA_DIR="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"
CHECKPOINT="${CHECKPOINT:-/scratch/mrpcol001/Diffusion_job/april_26/ddpm_hi_lh6/outputs_conditional_6param_20260413_132226/checkpoints/best_model.pt}"
OUTPUT_PNG="${OUTPUT_PNG:-${REPO}/ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png}"
FROM_NPZ="${FROM_NPZ:-}"
SAVE_NPZ="${SAVE_NPZ:-}"
LHS_N="${LHS_N:-50}"
MAPS_PER_POINT="${MAPS_PER_POINT:-15}"
DDIM_STEPS="${DDIM_STEPS:-50}"
SEED="${SEED:-42}"

echo "==============================================="
echo "Job ID: ${SLURM_JOB_ID:-local}"
echo "Job Name: ${SLURM_JOB_NAME:-plot_r2_cosmology_lhs}"
echo "Node: ${SLURM_NODELIST:-$(hostname)}"
echo "GPU: ${CUDA_VISIBLE_DEVICES:-n/a}"
echo "Starting Time: $(date)"
echo "OUTPUT_PNG: ${OUTPUT_PNG}"
echo "FROM_NPZ: ${FROM_NPZ:-(none — full compute)}"
echo "==============================================="

# Build the command as an array so paths containing spaces stay intact.
PY_ARGS=(
  python plot_r2_cosmology_lhs.py
  --output "${OUTPUT_PNG}"
  --lhs-n "${LHS_N}"
  --maps-per-point "${MAPS_PER_POINT}"
  --ddim-steps "${DDIM_STEPS}"
  --seed "${SEED}"
)

if [[ -n "${FROM_NPZ}" ]]; then
  # Replay mode: no model or data needed, only the saved arrays.
  PY_ARGS+=(--from-npz "${FROM_NPZ}")
else
  PY_ARGS+=(--checkpoint "${CHECKPOINT}" --data-dir "${DATA_DIR}")
  if [[ -n "${SAVE_NPZ}" ]]; then
    PY_ARGS+=(--save-npz "${SAVE_NPZ}")
  fi
fi

"${PY_ARGS[@]}"
STATUS=$?

echo "==============================================="
if [[ ${STATUS} -eq 0 ]]; then
  echo "Finished at: $(date)"
  echo "Figure: ${OUTPUT_PNG}"
else
  echo "FAILED (exit code ${STATUS}) at: $(date)"
fi
echo "==============================================="

# Propagate the Python exit status; otherwise the trailing echo would make
# the batch script (and therefore the job) report success unconditionally.
exit "${STATUS}"
|
scripts/shell/train_conditional_lh6.sh
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
#SBATCH --account=l40sfree
#SBATCH --partition=l40s
#SBATCH --nodes=1
#SBATCH --ntasks=8
#SBATCH --gres=gpu:l40s:1
#SBATCH --time=48:00:00
#SBATCH --job-name=ddpm_hi_lh6
#SBATCH --mail-user=mrpcol001@myuct.ac.za
#SBATCH --output=slurm-%j.out
#SBATCH --error=slurm-%j.err

# Conditional DDPM training — 6 CAMELS LH parameters (ddpm_hi_lh6).
# Submit from anywhere:
#   sbatch /scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6/scripts/shell/train_conditional_lh6.sh
#
# Override data path (optional): any folder containing *_LH_6.npy and *_labels_LH.npy
#   sbatch --export=DATA_DIR=/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6 train_conditional_lh6.sh
#
# The script exits with the training process's status so that a failed run is
# not reported as a successful job.

# Abort if the project directory is missing: train_conditional.py and the
# relative --output_dir below are both resolved against it.
cd /scratch/mrpcol001/Diffusion_job/Models/6param_ddpm_hi_lh6 || exit 1

module load python/miniconda3-py3.12-usr

# Same LH_data layout as DDPM_HI_Emulation_improved (params_2 for 2 labels → params_6 here).
DATA_DIR="${DATA_DIR:-/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6}"

echo "==============================================="
echo "Job ID: $SLURM_JOB_ID"
echo "Job Name: $SLURM_JOB_NAME"
echo "Node: $SLURM_NODELIST"
echo "GPU: $CUDA_VISIBLE_DEVICES"
echo "Starting Time: $(date)"
echo "Conditional diffusion training (ddpm_hi_lh6, 6 labels)"
echo "DATA_DIR: ${DATA_DIR}"
echo "==============================================="

python train_conditional.py \
  --label_dim 6 \
  --timesteps 1500 \
  --use_ddim \
  --ddim_steps 50 \
  --normalize_labels \
  --batch_size 8 \
  --epochs 200 \
  --lr 2e-4 \
  --early_stop_patience 100 \
  --sample_every 10 \
  --base_channels 64 \
  --channel_multipliers 1 2 4 8 \
  --attention_levels 2 3 \
  --data_dir "${DATA_DIR}" \
  --output_dir outputs_conditional_6param \
  --use_amp
STATUS=$?

# To resume: point --resume at checkpoints/checkpoint_epoch_N.pt and set --epochs to the
# new total; add --resume_refresh_scheduler if extending past the original epoch count.

echo "==============================================="
if [[ ${STATUS} -eq 0 ]]; then
  echo "Training completed at: $(date)"
else
  echo "Training FAILED (exit code ${STATUS}) at: $(date)"
fi
echo "==============================================="

# Propagate the training exit status; otherwise the trailing echo would make
# the batch script (and therefore the job) report success unconditionally.
exit "${STATUS}"
|
src/eval_model.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Helpers for ddpm_cond_eval.ipynb: R², P(k) on log N_HI fields, DDIM batches.
|
| 3 |
+
|
| 4 |
+
Uses evaluate_conditional for PowerSpectrum, sampling, and label z-scoring.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from matplotlib.colors import LinearSegmentedColormap
|
| 11 |
+
|
| 12 |
+
import evaluate_conditional as ec
|
| 13 |
+
|
| 14 |
+
LO_LOG, HI_LOG = 14.0, 22.0
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def r2_score_1d(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the coefficient of determination R² of *y_pred* against *y_true*.

    Both inputs are converted to float64 and flattened, then
    ``R² = 1 - SS_res / SS_tot`` is evaluated.

    Degenerate case: when ``y_true`` is numerically constant
    (``SS_tot < 1e-30``) the ratio is undefined, and this helper returns
    ``0.0`` if the residual is also below ``1e-30`` and ``-inf`` otherwise.
    """
    truth = np.asarray(y_true, dtype=np.float64).ravel()
    guess = np.asarray(y_pred, dtype=np.float64).ravel()

    residual_ss = np.sum((truth - guess) ** 2)
    total_ss = np.sum((truth - np.mean(truth)) ** 2)

    degenerate = total_ss < 1e-30
    if not degenerate:
        return float(1.0 - residual_ss / total_ss)
    # Constant reference data: fall back to the fixed convention above.
    if residual_ss < 1e-30:
        return 0.0
    return float("-inf")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def cmap_r2_hiflow() -> LinearSegmentedColormap:
    """Build the 256-level "r2_hiflow" colormap used for the R² panels.

    The ramp runs green → yellow → red → purple → dark blue
    (styled after HIFlow Fig. 7).
    """
    anchor_colors = ["#00a651", "#ffcc00", "#e74c3c", "#7d3c98", "#0d1b5c"]
    return LinearSegmentedColormap.from_list("r2_hiflow", anchor_colors, N=256)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def images01_to_log_nhi(img01: np.ndarray, lo: float = LO_LOG, hi: float = HI_LOG) -> np.ndarray:
    """Rescale unit-interval maps onto the ``[lo, hi]`` range as float64.

    Values are first clipped to ``[0, 1]`` and then mapped affinely, so 0
    becomes ``lo`` and 1 becomes ``hi``. Per the function name, the result is
    interpreted as log10(N_HI / cm^-2).
    """
    unit_values = np.clip(img01, 0.0, 1.0).astype(np.float64)
    return lo + (hi - lo) * unit_values
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def per_map_power_spectra_log(
    images_01: np.ndarray, box_size: float = 25.0, lo: float = LO_LOG, hi: float = HI_LOG
) -> tuple[np.ndarray, np.ndarray]:
    """Return ``(dk, Pk)`` with ``Pk`` of shape ``(N, n_bins)`` for a stack of maps.

    Each map is first converted with :func:`images01_to_log_nhi`, then passed
    to ``ec.PowerSpectrum`` with ``N`` set to the size of the last axis and
    ``dl = box_size / N``.

    Args:
        images_01: Stack of maps; the first axis indexes maps.
        box_size: Side length of the box, divided by the pixel count to get
            the pixel size handed to ``ec.PowerSpectrum``.
        lo, hi: Range forwarded to :func:`images01_to_log_nhi`.

    Returns:
        ``dk`` as returned for the first map, and the per-map spectra stacked
        along a new leading axis.

    Raises:
        ValueError: If ``images_01`` contains no maps.
    """
    logf = images01_to_log_nhi(images_01, lo, hi)
    if logf.shape[0] == 0:
        raise ValueError("images_01 must contain at least one map")
    npix = logf.shape[-1]
    dl = box_size / npix
    # Evaluate every map exactly once; the bin axis is taken from the first
    # result instead of recomputing the first map's spectrum separately.
    spectra = [ec.PowerSpectrum(field, N=npix, dl=dl) for field in logf]
    dk = spectra[0][0]
    pks = np.stack([result[1] for result in spectra])
    return dk, pks
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def sample_batch(
    model: torch.nn.Module,
    labels_np: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    normalize_labels: bool,
    height: int,
    width: int,
    device: torch.device,
    ddim_steps: int,
    progress: bool,
) -> np.ndarray:
    """Draw one DDIM batch of single-channel maps conditioned on ``labels_np``.

    ``labels_np`` is expected to have shape (B, label_dim), with ``label_mean``
    and ``label_std`` of length label_dim. When ``normalize_labels`` is true
    the labels go through ``ec.prepare_labels_for_model``; otherwise they are
    passed to the model as a plain float tensor. The sampler output is
    converted with ``ec.from_model_output`` before being returned.
    """
    raw_labels = np.asarray(labels_np, dtype=np.float32)
    mu = np.asarray(label_mean, dtype=np.float32)
    sigma = np.asarray(label_std, dtype=np.float32)

    conditioning = (
        ec.prepare_labels_for_model(raw_labels, mu, sigma)
        if normalize_labels
        else torch.from_numpy(raw_labels).float()
    ).to(device)

    sampler_kwargs = {
        "labels": conditioning,
        "channels": 1,
        "height": height,
        "width": width,
        "device": device,
        "progress": progress,
        "use_ddim": True,
        "ddim_steps": ddim_steps,
    }
    # Sampling is inference only, so gradient tracking is disabled.
    with torch.no_grad():
        generated = model.sample(**sampler_kwargs)
    return ec.from_model_output(generated)
|
src/figure9_posterior.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Figure-9 style surrogate posteriors: build (Ωm, σ8) grids and log P(k) for observed maps.
|
| 3 |
+
|
| 4 |
+
Used by ddpm_cond_eval.ipynb. Sampling and P(k) live in eval_model.py.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
import eval_model as em
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def build_cosmo_grid(
    g: int,
    om_lo: float,
    om_hi: float,
    s8_lo: float,
    s8_hi: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build a regular ``g × g`` grid over (Ωm, σ8).

    Returns:
        ``om_axis`` and ``s8_axis`` (float64, length ``g``), the two ``(g, g)``
        mesh arrays in "ij" indexing (Ωm varies along axis 0), and
        ``grid_labels``: a float32 ``(g*g, 2)`` array of (Ωm, σ8) pairs in
        row-major order of that mesh.
    """
    om_axis, s8_axis = (
        np.linspace(lower, upper, g, dtype=np.float64)
        for lower, upper in ((om_lo, om_hi), (s8_lo, s8_hi))
    )
    og, sg = np.meshgrid(om_axis, s8_axis, indexing="ij")
    flat_pairs = np.column_stack((og.ravel(), sg.ravel()))
    return om_axis, s8_axis, og, sg, flat_pairs.astype(np.float32)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def log_pk_observed(img01: np.ndarray, box_size: float, dk: np.ndarray) -> np.ndarray:
    """Return the natural log of one map's P(k), restricted to bins with ``dk > 0``.

    The map is given a leading batch axis and passed through
    ``em.per_map_power_spectra_log``. A tiny offset (1e-30) is added before
    taking the log.

    Raises:
        ValueError: If the spectrum's bin count differs from ``len(dk)``.
    """
    _, spectra = em.per_map_power_spectra_log(img01[np.newaxis, ...], box_size)
    keep = dk > 0
    n_bins = spectra.shape[1]
    if n_bins != len(dk):
        raise ValueError("P(k) bin count mismatch vs dk")
    single_map_pk = spectra[0, keep]
    return np.log(single_map_pk + 1e-30)
|
src/plot_r2_cosmology_lhs.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Reproduce the Latin-hypercube R² figure (μ(P) and σ(P)) in (Ωm, σ8) with a layout
|
| 4 |
+
that avoids colorbar / suptitle overlap.
|
| 5 |
+
|
| 6 |
+
Usage (full run — slow):
|
| 7 |
+
python plot_r2_cosmology_lhs.py --output ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png
|
| 8 |
+
|
| 9 |
+
Replay plot only from saved arrays:
|
| 10 |
+
python plot_r2_cosmology_lhs.py --from-npz r2_lhs_data.npz --output out.png
|
| 11 |
+
|
| 12 |
+
Defaults match ddpm_conditional / evaluate_conditional 6-param setup.
|
| 13 |
+
"""
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import matplotlib
|
| 20 |
+
|
| 21 |
+
matplotlib.use("Agg")
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
from matplotlib.cm import ScalarMappable
|
| 26 |
+
from matplotlib.colors import Normalize
|
| 27 |
+
from matplotlib.gridspec import GridSpec
|
| 28 |
+
|
| 29 |
+
import evaluate_conditional as ec
|
| 30 |
+
import eval_model as em
|
| 31 |
+
|
| 32 |
+
_SCRIPT_DIR = Path(__file__).resolve().parent
|
| 33 |
+
_DEFAULT_CKPT = _SCRIPT_DIR / "outputs_conditional_6param_20260413_132226/checkpoints/best_model.pt"
|
| 34 |
+
_DEFAULT_DATA = "/scratch/mrpcol001/Diffusion_job/data/LH_data/params_6"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def latin_hypercube_scaled(
    n: int, lo: np.ndarray, hi: np.ndarray, rng: np.random.Generator
) -> np.ndarray:
    """Draw ``n`` Latin-hypercube points in the box ``[lo, hi]`` (float32, shape (n, d)).

    The unit interval is cut into ``n`` equal strata per dimension, one
    uniform draw is placed in each stratum, and every column is then shuffled
    independently so strata are paired at random across dimensions.
    """
    n_dims = int(lo.shape[0])
    jitter = rng.random((n, n_dims))

    edges = np.linspace(0.0, 1.0, n + 1)
    lower_edges = edges[:-1, np.newaxis]
    bin_widths = np.diff(edges)[:, np.newaxis]
    unit_samples = lower_edges + jitter * bin_widths

    # Rows of the transpose are views onto the columns, so the in-place
    # shuffle permutes each dimension of ``unit_samples`` independently.
    for column in unit_samples.T:
        rng.shuffle(column)

    extent = (hi - lo).astype(np.float64)
    return (lo + unit_samples * extent).astype(np.float32)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def compute_lhs_r2(
    model: torch.nn.Module,
    images_split: np.ndarray,
    labels_split: np.ndarray,
    label_mean: np.ndarray,
    label_std: np.ndarray,
    device: torch.device,
    lhs_n: int,
    maps_per_point: int,
    batch_size: int,
    box_size_mpc: float,
    ddim_steps: int,
    seed: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Score generated vs. real power spectra at Latin-hypercube label points.

    For each of ``lhs_n`` points drawn inside the per-dimension min/max of
    ``labels_split``:

    1. the ``maps_per_point`` maps whose labels are nearest to the point are
       taken from ``images_split`` as the reference set;
    2. ``maps_per_point`` maps are DDIM-sampled from ``model`` at the point's
       labels, in chunks of at most ``batch_size``;
    3. R² is computed between the mean spectra and between the standard
       deviation spectra of the two sets, skipping the first bin.

    Args:
        model: Conditional diffusion model exposing ``sample(...)``.
        images_split: Reference maps; the last two axes are height and width.
        labels_split: Labels aligned with ``images_split``, shape (N, label_dim).
        label_mean, label_std: Statistics forwarded to
            ``ec.prepare_labels_for_model``.
        device: Device used for sampling.
        lhs_n: Number of Latin-hypercube points.
        maps_per_point: Real and generated maps compared at each point.
        batch_size: Upper bound on the sampling chunk size.
        box_size_mpc: Box side length; divided by the pixel count to get the
            pixel size passed to ``ec.PowerSpectrum``.
        ddim_steps: Number of DDIM steps.
        seed: Seed of the generator that draws the hypercube points.

    Returns:
        ``(lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b)``.

    Raises:
        ValueError: If ``maps_per_point`` or ``batch_size`` is smaller than 1.
    """
    if maps_per_point < 1 or batch_size < 1:
        raise ValueError("maps_per_point and batch_size must both be >= 1")

    lo_b = labels_split.min(axis=0)
    hi_b = labels_split.max(axis=0)
    rng = np.random.default_rng(seed)
    lhs_pts = latin_hypercube_scaled(lhs_n, lo_b, hi_b, rng)

    h, w = int(images_split.shape[-2]), int(images_split.shape[-1])
    bs = min(batch_size, maps_per_point)
    npix = int(images_split.shape[-1])
    dl = box_size_mpc / npix

    def pk_stack(imgs: np.ndarray) -> np.ndarray:
        # One spectrum per map, stacked to shape (n_maps, n_bins).
        return np.stack([ec.PowerSpectrum(im, N=npix, dl=dl)[1] for im in imgs], axis=0)

    # NaN marks points that have not been evaluated.
    r2_mu_arr = np.full(lhs_n, np.nan, dtype=np.float64)
    r2_sig_arr = np.full(lhs_n, np.nan, dtype=np.float64)
    model.eval()

    for ti in range(lhs_n):
        theta = lhs_pts[ti]
        # NOTE(review): neighbours are ranked by Euclidean distance in the
        # units of ``labels_split``; if those are raw (un-standardised)
        # parameters, the widest-ranged ones dominate — confirm this is intended.
        dist = np.linalg.norm(labels_split - theta, axis=1)
        nn_idx = np.argsort(dist)[:maps_per_point]
        real_batch = images_split[nn_idx]

        rep = np.tile(theta[None, :], (maps_per_point, 1))
        gen_chunks = []
        for j in range(0, maps_per_point, bs):
            chunk = rep[j : j + bs]
            bt = ec.prepare_labels_for_model(chunk, label_mean, label_std).to(device)
            with torch.no_grad():
                g = model.sample(
                    labels=bt,
                    channels=1,
                    height=h,
                    width=w,
                    device=device,
                    progress=False,
                    use_ddim=True,
                    ddim_steps=ddim_steps,
                )
            gen_chunks.append(ec.from_model_output(g))
        gen_batch = np.concatenate(gen_chunks, axis=0)

        pk_r = pk_stack(real_batch)
        pk_g = pk_stack(gen_batch)
        # Skip the first bin when scoring.
        km = np.arange(pk_r.shape[1], dtype=int) > 0
        mu_r, mu_g = pk_r.mean(axis=0), pk_g.mean(axis=0)
        sr, sg = pk_r.std(axis=0), pk_g.std(axis=0)
        r2_mu_arr[ti] = em.r2_score_1d(mu_r[km], mu_g[km])
        r2_sig_arr[ti] = em.r2_score_1d(sr[km], sg[km])

    return lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def plot_r2_cosmology_figure(
    lhs_pts: np.ndarray,
    r2_mu_arr: np.ndarray,
    r2_sig_arr: np.ndarray,
    lo_b: np.ndarray,
    hi_b: np.ndarray,
    out_path: Path,
    r2_vmin: float = 0.90,
    r2_vmax: float = 1.0,
    lhs_n: int | None = None,
    maps_per_point: int | None = None,
    dpi: int = 160,
) -> None:
    """
    Two-panel scatter in (Ωm, σ8) with a dedicated colorbar column (no overlap with heatmap).

    The left panel is coloured by R² of μ(P), the right by R² of σ(P). Points
    with a non-finite R² are omitted, and colours are clipped to
    ``[r2_vmin, r2_vmax]``.

    Args:
        lhs_pts: Label points, shape (n, label_dim); column 0 is the x axis
            and column 1 (if present) the y axis.
        r2_mu_arr, r2_sig_arr: Per-point R² values, length n.
        lo_b, hi_b: Per-dimension label bounds used for the axis limits.
        out_path: Output image path; parent directories are created.
        r2_vmin, r2_vmax: Colour scale limits.
        lhs_n: Sample count shown in the title (defaults to ``len(r2_mu_arr)``).
        maps_per_point: Maps-per-point count shown in the title (defaults to 15).
        dpi: Resolution of the saved figure.
    """
    lhs_n = lhs_n if lhs_n is not None else len(r2_mu_arr)
    maps_per_point = maps_per_point if maps_per_point is not None else 15

    ldim = lhs_pts.shape[1]
    om_plot = lhs_pts[:, 0]
    # Size the 1-D fallback from the data itself: ``lhs_n`` only feeds the
    # title and may differ from the number of points actually supplied.
    s8_plot = lhs_pts[:, 1] if ldim >= 2 else np.zeros(len(lhs_pts))

    cmap = em.cmap_r2_hiflow()
    norm = Normalize(vmin=r2_vmin, vmax=r2_vmax)
    sm = ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])

    fig = plt.figure(figsize=(11.5, 4.9))
    # Left: data panels; narrow right strip: colorbar only (avoids fig.colorbar + tight_layout clash)
    gs = GridSpec(
        nrows=1,
        ncols=3,
        figure=fig,
        width_ratios=[1.0, 1.0, 0.065],
        wspace=0.26,
        left=0.07,
        right=0.98,
        top=0.82,
        bottom=0.14,
    )
    ax0 = fig.add_subplot(gs[0, 0])
    ax1 = fig.add_subplot(gs[0, 1], sharey=ax0)
    cax = fig.add_subplot(gs[0, 2])

    # Pad the limits by 2% of the range (plus an epsilon for a zero range).
    pad_x = 0.02 * (float(hi_b[0] - lo_b[0]) + 1e-6)
    ax0.set_xlim(float(lo_b[0]) - pad_x, float(hi_b[0]) + pad_x)
    ax1.set_xlim(float(lo_b[0]) - pad_x, float(hi_b[0]) + pad_x)
    if ldim >= 2:
        pad_y = 0.02 * (float(hi_b[1] - lo_b[1]) + 1e-6)
        ax0.set_ylim(float(lo_b[1]) - pad_y, float(hi_b[1]) + pad_y)

    for ax, r2v, subtitle in zip(
        (ax0, ax1),
        (r2_mu_arr, r2_sig_arr),
        (r"$R^2$ for $\mu(P)$", r"$R^2$ for $\sigma(P)$"),
    ):
        ok = np.isfinite(r2v)
        ax.scatter(
            om_plot[ok],
            s8_plot[ok],
            c=np.clip(r2v[ok], r2_vmin, r2_vmax),
            cmap=cmap,
            norm=norm,
            s=52,
            alpha=0.92,
            edgecolors="k",
            linewidths=0.35,
        )
        ax.set_xlabel(r"$\Omega_m$", fontsize=12)
        ax.set_title(subtitle, fontsize=11)
        ax.grid(True, alpha=0.25)

    ax0.set_ylabel(r"$\sigma_8$", fontsize=12)
    plt.setp(ax1.get_yticklabels(), visible=False)

    cb = fig.colorbar(sm, cax=cax)
    cb.set_label(r"$R^2$", fontsize=11)
    cax.tick_params(labelsize=9)

    fig.suptitle(
        r"Visual summary of $R^2$ (CAMELS vs conditional DDPM) vs cosmology — "
        + f"{lhs_n} Latin Hypercube samples; {maps_per_point} maps / point",
        fontsize=11,
        fontweight="bold",
        y=0.96,
    )

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Do not use bbox_inches="tight" — it rebalances the axes and can squeeze the colorbar into the panels.
    fig.savefig(out_path, dpi=dpi)
    plt.close(fig)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _resolve_training_args(checkpoint: Path) -> Path | None:
    """Locate the saved training arguments that belong to ``checkpoint``.

    The run directory is the checkpoint's grandparent when the checkpoint sits
    in a folder named "checkpoints", otherwise its parent. ``args.json`` is
    preferred over ``args.txt``; ``None`` is returned when neither exists.
    """
    ckpt_dir = checkpoint.parent
    run_dir = ckpt_dir.parent if ckpt_dir.name == "checkpoints" else ckpt_dir
    candidates = (run_dir / fname for fname in ("args.json", "args.txt"))
    return next((path for path in candidates if path.is_file()), None)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the LHS R² figure script."""
    parser = argparse.ArgumentParser(description="LHS R² cosmology figure (fixed colorbar layout)")

    # Inputs and outputs.
    parser.add_argument("--checkpoint", type=str, default=str(_DEFAULT_CKPT))
    parser.add_argument("--data-dir", type=str, default=_DEFAULT_DATA)
    parser.add_argument("--split", type=str, default="test", choices=("train", "val", "test"))
    parser.add_argument(
        "--output",
        type=str,
        default=str(_SCRIPT_DIR / "ddpm_eval_notebook_out/r2_cosmology_lhs50_ddpm.png"),
    )
    parser.add_argument(
        "--from-npz",
        type=str,
        default=None,
        help="Load lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b",
    )
    parser.add_argument(
        "--save-npz",
        type=str,
        default=None,
        help="After compute, save arrays for --from-npz replot",
    )

    # Plain numeric options: (flag, type, default), registered in this order.
    numeric_options = (
        ("--lhs-n", int, 50),
        ("--maps-per-point", int, 15),
        ("--batch-size", int, 8),
        ("--ddim-steps", int, 50),
        ("--box-size-mpc", float, 25.0),
        ("--seed", int, 42),
        ("--r2-vmin", float, 0.90),
        ("--r2-vmax", float, 1.0),
        ("--dpi", int, 160),
    )
    for flag, caster, fallback in numeric_options:
        parser.add_argument(flag, type=caster, default=fallback)

    return parser.parse_args()
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def main() -> None:
    """Compute (or reload) the per-point R² arrays and render the figure.

    With ``--from-npz`` the arrays are read from a saved archive and no model
    or data is touched. Otherwise the checkpoint is loaded, R² is computed on
    the chosen split, and the arrays are optionally written with
    ``--save-npz`` for later replay.

    Raises:
        FileNotFoundError: If the checkpoint path does not point to a file.
    """
    args = parse_args()
    out_path = Path(args.output)
    maps_per_point = args.maps_per_point

    if args.from_npz:
        # Close the archive once the arrays are read (indexing loads them eagerly).
        with np.load(args.from_npz, allow_pickle=False) as z:
            lhs_pts = z["lhs_pts"]
            r2_mu_arr = z["r2_mu_arr"]
            r2_sig_arr = z["r2_sig_arr"]
            lo_b = z["lo_b"]
            hi_b = z["hi_b"]
            # Archives written before this key existed fall back to the CLI value.
            if "maps_per_point" in z.files:
                maps_per_point = int(z["maps_per_point"])
    else:
        ckpt = Path(args.checkpoint).expanduser().resolve()
        if not ckpt.is_file():
            raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

        ta = _resolve_training_args(ckpt)
        config: dict = {}
        if ta is not None:
            config = ec.load_training_config(str(ta))
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = ec.build_model(config, device)
        ec.load_checkpoint(model, str(ckpt), device)

        data_dir = Path(args.data_dir)
        images_split, labels_split = ec.load_split(data_dir, args.split)
        label_mean, label_std = ec.load_label_stats(data_dir)

        lhs_pts, r2_mu_arr, r2_sig_arr, lo_b, hi_b = compute_lhs_r2(
            model,
            images_split,
            labels_split,
            label_mean,
            label_std,
            device,
            lhs_n=args.lhs_n,
            maps_per_point=args.maps_per_point,
            batch_size=args.batch_size,
            box_size_mpc=args.box_size_mpc,
            ddim_steps=args.ddim_steps,
            seed=args.seed,
        )

        if args.save_npz:
            # Create the target directory first so a missing folder cannot
            # discard the results of the (slow) computation above.
            Path(args.save_npz).parent.mkdir(parents=True, exist_ok=True)
            np.savez(
                args.save_npz,
                lhs_pts=lhs_pts,
                r2_mu_arr=r2_mu_arr,
                r2_sig_arr=r2_sig_arr,
                lo_b=lo_b,
                hi_b=hi_b,
                # Stored so that a replay can label the figure correctly.
                maps_per_point=np.asarray(args.maps_per_point),
            )
            print("Saved", args.save_npz)

    plot_r2_cosmology_figure(
        lhs_pts,
        r2_mu_arr,
        r2_sig_arr,
        lo_b,
        hi_b,
        out_path,
        r2_vmin=args.r2_vmin,
        r2_vmax=args.r2_vmax,
        # Take the count from the data so a replayed archive is not labelled
        # with the command-line default.
        lhs_n=len(lhs_pts),
        maps_per_point=maps_per_point,
        dpi=args.dpi,
    )
    print("Saved", out_path.resolve())


if __name__ == "__main__":
    main()
|
src/posterior_inference.py
ADDED
|
@@ -0,0 +1,895 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
posterior_inference.py — VLB-based cosmological inference (Mudur et al. 2023 §4 style).
|
| 4 |
+
Pure inference-time; frozen DDPM weights. Script lives next to diffusion_conditional.py.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import ast
|
| 11 |
+
import json
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import time
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, List, Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import matplotlib
|
| 19 |
+
matplotlib.use("Agg")
|
| 20 |
+
import matplotlib.gridspec as gridspec
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
import matplotlib.patheffects as mpathe
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
# ── Project imports ────────────────────────────────────────────────────────────
|
| 27 |
+
_ROOT = Path(__file__).resolve().parent
|
| 28 |
+
if (_ROOT / "diffusion_conditional.py").is_file():
|
| 29 |
+
sys.path.insert(0, str(_ROOT))
|
| 30 |
+
|
| 31 |
+
from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel
|
| 32 |
+
from unet_conditional import ConditionalUNet
|
| 33 |
+
|
| 34 |
+
# Global matplotlib style, applied once at import time to every figure made
# by this script: white backgrounds, thin dark axes, no top/right spines.
plt.rcParams.update({
    "figure.facecolor": "white", "axes.facecolor": "white",
    "axes.edgecolor": "#222", "axes.linewidth": 0.7,
    "axes.spines.top": False, "axes.spines.right": False,
    "font.family": "DejaVu Sans", "font.size": 9.5,
    "savefig.facecolor": "white",
})
|
| 41 |
+
|
| 42 |
+
# Plot colours; GEN_COLOR is used for the predicted-vs-true markers.
REAL_COLOR = "#CC3333"
GEN_COLOR = "#2266BB"
# Thresholds applied to the -2*Delta(ln L) surfaces when drawing contours.
# NOTE(review): 2.30 / 6.17 / 11.83 look like the two-parameter delta-chi^2
# values for 68.3% / 95.4% / 99.7% coverage — confirm this is the intent.
SIGMA_LEVELS = [2.30, 6.17, 11.83]
# Fill colour for each level, in the same order as SIGMA_LEVELS.
SIGMA_COLORS = ["#1a5c9e", "#5590d0", "#99c0ea"]
# Text drawn on each contour line (passed to clabel as the fmt mapping).
SIGMA_LABELS = {2.30: r"$1\sigma$", 6.17: r"$2\sigma$", 11.83: r"$3\sigma$"}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_config(path: str) -> Dict:
    """Load a training configuration from a JSON or ``key: value`` text file.

    Args:
        path: Configuration file. A ``.json`` suffix is parsed with
            ``json.load``; anything else is read line by line as
            ``key: value`` pairs.

    Returns:
        Mapping of option names to values. For text files each value is
        parsed as a Python literal when possible and kept as the stripped
        string otherwise. Lines without a ``:`` are ignored.
    """
    p = Path(path)
    if p.suffix == ".json":
        with open(p, encoding="utf-8") as f:
            return json.load(f)

    cfg: Dict = {}
    with open(p, encoding="utf-8") as f:
        for line in f:
            key, sep, raw = line.strip().partition(":")
            if not sep:
                continue
            key, raw = key.strip(), raw.strip()
            try:
                cfg[key] = ast.literal_eval(raw)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                # Not a Python literal (e.g. a bare word or a path): keep the
                # text. Only the errors literal_eval is documented to raise
                # are caught, so unrelated failures are not hidden.
                cfg[key] = raw
    return cfg
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def load_model(ckpt: str, cfg: Dict, device: torch.device) -> ConditionalDiffusionModel:
    """Build the conditional diffusion model from ``cfg`` and load frozen weights.

    Args:
        ckpt: Path to the checkpoint file.
        cfg: Training configuration; missing keys fall back to the defaults
            spelled out below.
        device: Device the model and checkpoint tensors are placed on.

    Returns:
        The model in eval mode with gradients disabled on every parameter.
    """
    network = ConditionalUNet(
        in_channels=1,
        out_channels=1,
        label_dim=int(cfg.get("label_dim", 2)),
        base_channels=int(cfg.get("base_channels", 64)),
        channel_multipliers=list(cfg.get("channel_multipliers", [1, 2, 4, 8])),
        attention_levels=list(cfg.get("attention_levels", [2, 3])),
        dropout=float(cfg.get("dropout", 0.1)),
    )
    schedule = GaussianDiffusion(
        timesteps=int(cfg.get("timesteps", 1500)),
        beta_start=float(cfg.get("beta_start", 1e-4)),
        beta_end=float(cfg.get("beta_end", 0.02)),
        schedule_type=str(cfg.get("schedule_type", "linear")),
    )
    model = ConditionalDiffusionModel(network, schedule).to(device)

    # weights_only=False unpickles arbitrary objects: only load trusted files.
    payload = torch.load(ckpt, map_location=device, weights_only=False)
    is_mapping = isinstance(payload, dict)
    if is_mapping and "ema_shadow" in payload:
        # Overlay EMA tensors on the current state; EMA entries whose name is
        # not in the model are ignored.
        merged = model.state_dict()
        merged.update({
            name: tensor
            for name, tensor in payload["ema_shadow"].items()
            if name in merged
        })
        model.load_state_dict(merged)
        print(" Loaded EMA weights")
    elif is_mapping and "model_state_dict" in payload:
        model.load_state_dict(payload["model_state_dict"])
    else:
        # Anything else is treated as a bare state dict.
        model.load_state_dict(payload)

    model.eval()
    for param in model.parameters():
        param.requires_grad_(False)
    return model
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def load_test_data(
    data_dir: str, n_fields: int, seed: int = 42
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load a random subset of test fields plus label normalisation statistics.

    File-name suffixes are chosen from what exists in ``data_dir``: labels use
    ``_2`` when ``train_labels_LH_2.npy`` is present, and images use ``_6``
    when ``train_LH.npy`` is absent.

    Args:
        data_dir: Directory holding the ``.npy`` arrays.
        n_fields: Number of test fields to draw, without replacement.
        seed: Seed for the subset selection.

    Returns:
        ``(images, labels, label_mu, label_std)``; the last two are the
        per-column mean and standard deviation of the training labels, with
        zero standard deviations replaced by 1.0.
    """
    root = Path(data_dir)
    label_suffix = "_2" if (root / "train_labels_LH_2.npy").exists() else ""
    image_suffix = "" if (root / "train_LH.npy").exists() else "_6"

    def _read(stem: str) -> np.ndarray:
        return np.load(root / f"{stem}.npy").astype(np.float32)

    images = _read(f"test_LH{image_suffix}")
    test_labels = _read(f"test_labels_LH{label_suffix}")
    train_labels = _read(f"train_labels_LH{label_suffix}")

    picks = np.random.default_rng(seed).choice(len(images), n_fields, replace=False)

    spread = train_labels.std(0)
    # Constant label columns would divide by zero when normalising.
    safe_spread = np.where(spread == 0, 1.0, spread)
    return images[picks], test_labels[picks], train_labels.mean(0), safe_spread
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def normal_kl(mean1, log_var1, mean2, log_var2):
    """Element-wise KL divergence between two diagonal Gaussians.

    Computes KL(N(mean1, exp(log_var1)) || N(mean2, exp(log_var2))) for each
    element; the inputs are expected to be broadcastable tensors.
    """
    variance_ratio = torch.exp(log_var1 - log_var2)
    scaled_sq_diff = ((mean1 - mean2) ** 2) * torch.exp(-log_var2)
    bracket = -1.0 + log_var2 - log_var1 + variance_ratio + scaled_sq_diff
    return 0.5 * bracket
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _approx_standard_normal_cdf(x):
|
| 129 |
+
return 0.5 * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def discretised_gaussian_log_likelihood(x_0, mean, log_var):
    """Per-element log-probability of ``x_0`` under a binned Gaussian.

    Each value is given the Gaussian mass of the interval ``x_0 ± 1/255``.
    Values below -0.999 get the whole lower tail and values above 0.999 the
    whole upper tail. Every probability is floored at 1e-12 before the log.
    NOTE(review): the 1/255 half-width and ±0.999 edges presumably assume
    data scaled to [-1, 1] with 8-bit quantisation — confirm for these maps.

    Args:
        x_0: Observed values.
        mean: Gaussian mean, broadcastable to ``x_0``.
        log_var: Log variance of the Gaussian, broadcastable to ``x_0``.

    Returns:
        Tensor of log-probabilities with the broadcast shape of the inputs.
    """
    half_bin = 1.0 / 255.0
    residual = x_0 - mean
    inv_std = torch.exp(-0.5 * log_var)

    upper_cdf = _approx_standard_normal_cdf(inv_std * (residual + half_bin))
    lower_cdf = _approx_standard_normal_cdf(inv_std * (residual - half_bin))

    log_lower_tail = torch.log(upper_cdf.clamp(min=1e-12))
    log_upper_tail = torch.log((1.0 - lower_cdf).clamp(min=1e-12))
    log_interior = torch.log((upper_cdf - lower_cdf).clamp(min=1e-12))

    interior_or_upper = torch.where(x_0 > 0.999, log_upper_tail, log_interior)
    return torch.where(x_0 < -0.999, log_lower_tail, interior_or_upper)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def predict_x_start_from_eps(diff: GaussianDiffusion, x_t: torch.Tensor,
                             t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor:
    """Recover the clean-image estimate implied by a noise prediction.

    Uses the schedule buffers ``recip_sqrt_alphas_cumprod`` and
    ``sqrt_recip_minus_one`` gathered at ``t``; intended to match
    ``GaussianDiffusion._predict_xstart_from_noise`` in diffusion_conditional.py.
    """
    scale_x = diff._extract(diff.recip_sqrt_alphas_cumprod, t, x_t.shape)
    scale_eps = diff._extract(diff.sqrt_recip_minus_one, t, x_t.shape)
    return scale_x * x_t - scale_eps * eps
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def q_posterior_mean_var(diff: GaussianDiffusion, x_start: torch.Tensor,
                         x_t: torch.Tensor, t: torch.Tensor):
    """Return mean, variance and clipped log-variance of the step posterior.

    The mean is a linear mix of ``x_start`` and ``x_t`` weighted by the
    schedule's ``posterior_mean_coef1`` / ``posterior_mean_coef2``; the
    variance terms are read directly from the schedule buffers at ``t``.

    Returns:
        ``(mean, variance, log_variance_clipped)``.
    """
    shape = x_t.shape
    weight_start = diff._extract(diff.posterior_mean_coef1, t, shape)
    weight_noisy = diff._extract(diff.posterior_mean_coef2, t, shape)
    posterior_mean = weight_start * x_start + weight_noisy * x_t
    posterior_var = diff._extract(diff.posterior_variance, t, shape)
    posterior_log_var = diff._extract(diff.posterior_log_variance_clipped, t, shape)
    return posterior_mean, posterior_var, posterior_log_var
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
@torch.no_grad()
def compute_L_t(
    model: ConditionalDiffusionModel,
    x_0: torch.Tensor,
    labels_n: torch.Tensor,
    t: int,
    fixed_eps: torch.Tensor,
) -> torch.Tensor:
    """Evaluate one variational-bound term per batch element.

    ``x_0`` is noised with the supplied ``fixed_eps``, the model predicts the
    noise, and the implied clean image (clamped to [-1, 1]) defines the model
    step distribution. For ``t == 0`` the result is the negative discretised
    log-likelihood of ``x_0``; otherwise it is the KL between the true step
    posterior and the model's.

    NOTE(review): ``t == 0`` is evaluated at diffusion index 1, the same index
    the KL branch uses for ``t == 1`` — confirm this matches the indexing
    convention of ``GaussianDiffusion``.

    Args:
        model: Conditional diffusion model exposing ``.diffusion``.
        x_0: Clean fields, summed over dims 1-3 (so 4-D input is expected).
        labels_n: Normalised conditioning labels, one row per batch element.
        t: Timestep of the term to evaluate.
        fixed_eps: Noise used to build the noised input (same shape as x_0).

    Returns:
        Tensor with one value per batch element.
    """
    diff = model.diffusion
    step = 1 if t == 0 else t
    steps = torch.full((x_0.shape[0],), step, device=x_0.device, dtype=torch.long)

    # Forward-noise x_0 to the chosen step with the caller's fixed noise.
    alpha_bar = diff._extract(diff.alphas_cumprod, steps, x_0.shape)
    noisy = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1.0 - alpha_bar) * fixed_eps

    # Model step distribution from the predicted noise.
    eps_hat = model(noisy, steps, labels_n)
    x0_hat = predict_x_start_from_eps(diff, noisy, steps, eps_hat).clamp(-1, 1)
    model_mean, _, model_log_var = q_posterior_mean_var(diff, x0_hat, noisy, steps)

    if t == 0:
        log_p = discretised_gaussian_log_likelihood(x_0, model_mean, model_log_var)
        return -log_p.sum(dim=(1, 2, 3))

    true_mean, _, true_log_var = q_posterior_mean_var(diff, x_0, noisy, steps)
    kl = normal_kl(true_mean, true_log_var, model_mean, model_log_var)
    return kl.sum(dim=(1, 2, 3))
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@torch.no_grad()
def compute_L_T_analytic(diff: GaussianDiffusion, x_0: torch.Tensor) -> torch.Tensor:
    """KL between the fully-noised distribution of ``x_0`` and a unit Gaussian.

    The noised distribution at the last schedule index has mean
    ``sqrt(alpha_bar) * x_0`` and variance ``1 - alpha_bar``. The term does
    not involve the network.

    Returns:
        Tensor with one value per batch element (summed over dims 1-3).
    """
    last_index = diff.timesteps - 1
    steps = torch.full((x_0.shape[0],), last_index, device=x_0.device, dtype=torch.long)
    alpha_bar_T = diff._extract(diff.alphas_cumprod, steps, x_0.shape)

    q_mean = torch.sqrt(alpha_bar_T) * x_0
    # Floor the variance so the log stays finite.
    q_log_var = torch.log((1.0 - alpha_bar_T).clamp(min=1e-30))

    prior_mean = torch.zeros_like(q_mean)
    prior_log_var = torch.zeros_like(q_log_var)
    per_pixel = normal_kl(q_mean, q_log_var, prior_mean, prior_log_var)
    return per_pixel.sum(dim=(1, 2, 3))
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def build_eval_grid(
    Om_true: float,
    s8_true: float,
    grid_size: int,
    span: float = 0.1,
    Om_range: Tuple[float, float] = (0.10, 0.50),
    s8_range: Tuple[float, float] = (0.60, 1.00),
) -> Tuple[np.ndarray, np.ndarray]:
    """Build 1-D evaluation axes centred on the true parameters.

    Each axis covers ``true ± span`` clipped to the allowed range and is
    sampled with ``grid_size`` evenly spaced points.

    Returns:
        ``(Om_1d, s8_1d)``.
    """
    def _axis(centre: float, bounds: Tuple[float, float]) -> np.ndarray:
        lower = max(centre - span, bounds[0])
        upper = min(centre + span, bounds[1])
        return np.linspace(lower, upper, grid_size)

    return _axis(Om_true, Om_range), _axis(s8_true, s8_range)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
@torch.no_grad()
def evaluate_vlb_surface(
    model: ConditionalDiffusionModel,
    x_0: torch.Tensor,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    label_mu: np.ndarray,
    label_std: np.ndarray,
    t_values: List[int],
    n_seeds: int = 4,
    batch_size: int = 32,
    label_dim: int = 2,
    fixed_seed: int = 0,
    device: Optional[torch.device] = None,
) -> Dict[int, np.ndarray]:
    """Evaluate the bound terms of one field over an (Omega_m, sigma_8) grid.

    Every grid point becomes a conditioning label. When ``label_dim > 2`` the
    remaining label columns are filled with the matching entries of
    ``label_mu`` (so they normalise to zero). For each requested timestep the
    term is averaged over ``n_seeds`` noise draws, and the same draws are
    reused at every grid point.

    Args:
        model: Frozen conditional diffusion model.
        x_0: The observed field; it is expanded along dim 0, so a leading
            batch dimension of 1 is expected.
        Om_grid, s8_grid: 1-D parameter axes.
        label_mu, label_std: Statistics used to normalise the raw labels.
        t_values: Timesteps to evaluate.
        n_seeds: Number of noise draws to average over.
        batch_size: Grid points per forward pass.
        label_dim: Number of conditioning labels the model expects.
        fixed_seed: Seed for the noise generator.
        device: Target device; defaults to ``x_0.device``.

    Returns:
        ``{t: surface}`` with each surface shaped ``(len(Om_grid), len(s8_grid))``.
    """
    device = device or x_0.device
    n_om, n_s8 = len(Om_grid), len(s8_grid)
    n_pts = n_om * n_s8

    # "ij" indexing keeps Omega_m on axis 0 of the flattened grid.
    om_mesh, s8_mesh = np.meshgrid(Om_grid, s8_grid, indexing="ij")
    raw_labels = np.column_stack([om_mesh.ravel(), s8_mesh.ravel()])
    n_extra = label_dim - 2
    if n_extra > 0:
        filler = np.zeros((n_pts, n_extra), dtype=np.float32)
        for col in range(n_extra):
            filler[:, col] = label_mu[2 + col]
        raw_labels = np.concatenate([raw_labels, filler], axis=1)
    scaled = (raw_labels - label_mu) / label_std
    norm_labels_t = torch.from_numpy(scaled.astype(np.float32)).to(device)

    totals = {t: np.zeros(n_pts, dtype=np.float64) for t in t_values}

    height, width = x_0.shape[-2], x_0.shape[-1]
    generator = torch.Generator(device=device).manual_seed(fixed_seed)
    noise_draws = [
        torch.randn(1, 1, height, width, generator=generator, device=device)
        for _ in range(n_seeds)
    ]

    for fixed_eps in noise_draws:
        for t in t_values:
            for start in range(0, n_pts, batch_size):
                stop = min(start + batch_size, n_pts)
                count = stop - start
                L_t = compute_L_t(
                    model,
                    x_0.expand(count, -1, -1, -1),
                    norm_labels_t[start:stop],
                    t=t,
                    fixed_eps=fixed_eps.expand(count, -1, -1, -1),
                )
                totals[t][start:stop] += L_t.cpu().numpy() / n_seeds

    return {t: totals[t].reshape(n_om, n_s8) for t in t_values}
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def marginal_from_neg2dL(
    neg2dL: np.ndarray, Om_grid: np.ndarray, s8_grid: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, Tuple[float, float]]:
    """Turn a -2*Delta(ln L) surface into normalised 1-D marginals.

    Args:
        neg2dL: Surface with Omega_m on axis 0 and sigma_8 on axis 1.
        Om_grid, s8_grid: Parameter values along each axis.

    Returns:
        ``(Om_marginal, s8_marginal, (Om_pred, s8_pred))`` where each marginal
        sums to one and the predictions are the grid values at the peak of
        the respective marginal.
    """
    log_like = -0.5 * neg2dL
    # Shift by the maximum before exponentiating for numerical stability.
    weights = np.exp(log_like - log_like.max())

    om_post = weights.sum(axis=1)
    om_post = om_post / om_post.sum()
    s8_post = weights.sum(axis=0)
    s8_post = s8_post / s8_post.sum()

    peak = (float(Om_grid[np.argmax(om_post)]), float(s8_grid[np.argmax(s8_post)]))
    return om_post, s8_post, peak
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def credible_interval_68(values: np.ndarray, probs: np.ndarray) -> Tuple[float, float, float]:
    """Return the median and the 16th/84th percentiles of a discrete distribution.

    The cumulative sum of ``probs`` is normalised to end at one and the
    quantiles are linearly interpolated on ``values``.

    Returns:
        ``(median, lo, hi)``.
    """
    cdf = np.cumsum(probs)
    cdf /= cdf[-1]
    median, lo, hi = (
        float(np.interp(level, cdf, values)) for level in (0.50, 0.16, 0.84)
    )
    return median, lo, hi
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def fig_contours_per_t(
    surfaces: Dict[int, np.ndarray],
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    out_path: Path,
    dpi: int = 200,
) -> None:
    """Plot, for every timestep, the contour where -2*Delta(ln L) equals 2.30.

    Args:
        surfaces: ``{t: L_t surface}`` with Omega_m on axis 0.
        Om_grid, s8_grid: Axis values of the surfaces.
        Om_true, s8_true: True parameters, marked with a red cross.
        out_path: Destination image file.
        dpi: Figure resolution.
    """
    fig, ax = plt.subplots(figsize=(7, 6.5), dpi=dpi)
    palette = plt.cm.viridis(np.linspace(0.05, 0.95, len(surfaces)))

    for (t, L_surf), colour in zip(sorted(surfaces.items()), palette):
        delta = 2.0 * (L_surf - L_surf.min())
        # Transposed so the sigma_8 axis runs along the rows for contour().
        ax.contour(
            Om_grid, s8_grid, delta.T,
            levels=[2.30], colors=[colour], linewidths=[1.6], linestyles=["-"],
        )
        # Empty line: exists only to give the contour a legend entry.
        ax.plot([], [], color=colour, lw=1.8, label=f"t={t}")

    ax.plot(Om_true, s8_true, "r+", ms=18, mew=2.5, label="True", zorder=10)

    ax.set_xlabel(r"$\Omega_m$", fontsize=12)
    ax.set_ylabel(r"$\sigma_8$", fontsize=12)
    ax.set_title(
        r"$-2\Delta\ln\hat{L}_t$ — $1\sigma$ contour per timestep"
        "\n(Mudur-style) smaller $t$ → tighter constraint",
        fontweight="bold", fontsize=10,
    )
    ax.legend(fontsize=8, loc="best", ncol=2, framealpha=0.92)
    ax.grid(alpha=0.18)
    ax.set_xlim(Om_grid[0], Om_grid[-1])
    ax.set_ylim(s8_grid[0], s8_grid[-1])

    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path}")
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
def _L0_posterior_smoothed(
|
| 347 |
+
L0_surface: np.ndarray, smooth_sigma: float = 0.6,
|
| 348 |
+
):
|
| 349 |
+
from scipy.ndimage import gaussian_filter as gf
|
| 350 |
+
|
| 351 |
+
neg2dL = 2.0 * (L0_surface - L0_surface.min())
|
| 352 |
+
surface_sm = gf(neg2dL, sigma=smooth_sigma)
|
| 353 |
+
return surface_sm, neg2dL
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def draw_L0_posterior_main_panel(
    ax,
    surface_sm: np.ndarray,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    Om_pred: float,
    s8_pred: float,
    *,
    clabel_fontsize: float = 9.5,
    marker_ms: float = 16,
) -> None:
    """Draw the 2-D posterior panel (filled sigma regions plus markers) on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        surface_sm: Smoothed -2*Delta(ln L) surface with Omega_m on axis 0.
        Om_grid, s8_grid: Axis values of the surface.
        Om_true, s8_true: True parameters (red cross and dotted guide lines).
        Om_pred, s8_pred: Point estimate (white triangle, labelled "MAP").
        clabel_fontsize: Font size of the contour labels.
        marker_ms: Marker size of the true-value cross; the estimate marker
            scales with it.

    No legend is created here; callers build it from the handles left on ``ax``.
    """
    # Faint background shading of the whole surface.
    ax.contourf(Om_grid, s8_grid, surface_sm.T, levels=60, cmap="Blues_r",
                vmin=0, vmax=SIGMA_LEVELS[-1] * 3, extend="max", alpha=0.55)
    # Fill from the outermost level inwards so inner regions are drawn on top.
    for lv, co in zip(reversed(SIGMA_LEVELS), reversed(SIGMA_COLORS)):
        ax.contourf(Om_grid, s8_grid, surface_sm.T,
                    levels=[0, lv], colors=[co], alpha=0.78)

    # White outlines at each level, labelled via SIGMA_LABELS.
    cs = ax.contour(
        Om_grid, s8_grid, surface_sm.T,
        levels=SIGMA_LEVELS,
        colors=["white", "white", "white"],
        linewidths=[2.2, 1.6, 1.2],
        linestyles=["-", "--", "-."],
    )
    ax.clabel(cs, fmt=SIGMA_LABELS, inline=True, fontsize=clabel_fontsize, colors="white")

    # Guide lines and markers for the true values and the point estimate.
    ax.axvline(Om_true, color="red", lw=0.7, ls=":", alpha=0.6)
    ax.axhline(s8_true, color="red", lw=0.7, ls=":", alpha=0.6)
    ax.plot(Om_true, s8_true, "r+", ms=marker_ms, mew=2.5, zorder=6, label="True")
    ax.plot(Om_pred, s8_pred, "w^", ms=max(6, marker_ms * 0.55), mew=1.2, zorder=6, label="MAP")
    ax.set_xlim(Om_grid[0], Om_grid[-1])
    ax.set_ylim(s8_grid[0], s8_grid[-1])
    ax.grid(alpha=0.18)
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def fig_posterior_L0_mosaic_3x3(
    out_dir: Path,
    n_fields: int,
    out_path: Path,
    mosaic_side_px: int = 10_000,
    panel_inches: float = 4.0,
) -> None:
    """Render up to nine per-field L_0 posteriors as a 3×3 mosaic.

    Each panel is rebuilt from ``out_dir / "fieldNN_surfaces.npz"``, which
    must contain ``L_t0``, ``Om_grid``, ``s8_grid``, ``Om_true`` and
    ``s8_true``. Unused panels are hidden and the figure title reports the
    number of panels actually drawn.

    Args:
        out_dir: Directory holding the per-field ``.npz`` files.
        n_fields: Number of fields available; at most nine are plotted.
        out_path: Destination image file.
        mosaic_side_px: Target side length of the square image in pixels.
        panel_inches: Size of each panel in inches; the dpi is derived as
            ``mosaic_side_px / (3 * panel_inches)``.
    """
    from matplotlib.patches import Patch

    n_plot = min(n_fields, 9)
    fig_side = panel_inches * 3
    dpi = mosaic_side_px / fig_side
    fig, axes = plt.subplots(
        3, 3, figsize=(fig_side, fig_side), dpi=dpi,
        squeeze=False,
    )
    for idx, ax in enumerate(axes.flat):
        if idx >= n_plot:
            ax.set_visible(False)
            continue
        # Context manager closes the archive's file handle after reading.
        with np.load(out_dir / f"field{idx:02d}_surfaces.npz") as nz:
            L0 = np.asarray(nz["L_t0"])
            Om_grid = np.asarray(nz["Om_grid"])
            s8_grid = np.asarray(nz["s8_grid"])
            Om_true = float(nz["Om_true"])
            s8_true = float(nz["s8_true"])
        surface_sm, _ = _L0_posterior_smoothed(L0, smooth_sigma=0.6)
        _, _, (Om_pred, s8_pred) = marginal_from_neg2dL(surface_sm, Om_grid, s8_grid)
        draw_L0_posterior_main_panel(
            ax, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
            clabel_fontsize=7.0, marker_ms=11,
        )
        if idx == 0:
            # Only the first panel carries the legend; the other panels never
            # get one because the panel helper does not create a legend.
            legend_patches = [
                Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"),
                Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"),
                Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"),
            ]
            hs, ls_ = ax.get_legend_handles_labels()
            ax.legend(
                handles=legend_patches + hs,
                labels=[p.get_label() for p in legend_patches] + ls_,
                fontsize=6, loc="upper right", framealpha=0.9,
            )
        ax.set_title(
            rf"field {idx}: $\Omega_m^{{\rm true}}={Om_true:.3f}$, $\sigma_8^{{\rm true}}={s8_true:.3f}$",
            fontsize=8,
        )
        ax.set_xlabel(r"$\Omega_m$", fontsize=8)
        ax.set_ylabel(r"$\sigma_8$", fontsize=8)

    # Report the number of panels actually drawn rather than a fixed "9".
    n_shown = max(n_plot, 0)
    noun = "field" if n_shown == 1 else "fields"
    fig.suptitle(
        rf"VLB $L_0$ posterior (2D) — {n_shown} test {noun}",
        fontsize=11, fontweight="bold", y=0.995,
    )
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path} (≈ {mosaic_side_px}×{mosaic_side_px} px)")
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def fig_main_posterior(
    L0_surface: np.ndarray,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    out_path: Path,
    dpi: int = 200,
) -> Tuple[float, float, Tuple[float, float], Tuple[float, float]]:
    """Plot the L_0 posterior with its two marginals and save it.

    The surface is converted to a smoothed -2*Delta(ln L) map, marginalised
    along each axis, and drawn as a joint panel with the Omega_m marginal on
    top and the sigma_8 marginal on the right.

    Args:
        L0_surface: L_0 values with Omega_m on axis 0 and sigma_8 on axis 1.
        Om_grid, s8_grid: Axis values of the surface.
        Om_true, s8_true: True parameters of the field.
        out_path: Destination image file.
        dpi: Figure resolution.

    Returns:
        ``(Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi))``: the marginal
        peaks and the 16th-84th percentile bounds of each marginal.
    """
    from matplotlib.patches import Patch

    surface_sm, _ = _L0_posterior_smoothed(L0_surface, smooth_sigma=0.6)

    Om_marg, s8_marg, (Om_pred, s8_pred) = marginal_from_neg2dL(
        surface_sm, Om_grid, s8_grid
    )
    Om_med, Om_lo, Om_hi = credible_interval_68(Om_grid, Om_marg)
    s8_med, s8_lo, s8_hi = credible_interval_68(s8_grid, s8_marg)

    # 2×2 layout: joint panel bottom-left, marginals above and to the right.
    fig = plt.figure(figsize=(8.5, 8.5), dpi=dpi)
    gs = gridspec.GridSpec(2, 2, width_ratios=[4, 1], height_ratios=[1, 4],
                           hspace=0.05, wspace=0.05,
                           left=0.10, right=0.95, top=0.95, bottom=0.08)
    ax_main = fig.add_subplot(gs[1, 0])
    ax_top = fig.add_subplot(gs[0, 0], sharex=ax_main)
    ax_rt = fig.add_subplot(gs[1, 1], sharey=ax_main)

    draw_L0_posterior_main_panel(
        ax_main, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
    )

    ax_main.set_xlabel(r"$\Omega_m$", fontsize=11)
    ax_main.set_ylabel(r"$\sigma_8$", fontsize=11)

    # Top panel: Omega_m marginal, true value (red), peak (white on black
    # stroke) and the 16-84% band; the title quotes median and offsets.
    ax_top.fill_between(Om_grid, 0, Om_marg, color=SIGMA_COLORS[1], alpha=0.6)
    ax_top.plot(Om_grid, Om_marg, color=SIGMA_COLORS[0], lw=1.4)
    ax_top.axvline(Om_true, color="red", lw=1.0, ls=":")
    ax_top.axvline(Om_pred, color="white", lw=1.5, ls="--",
                   path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
    ax_top.axvspan(Om_lo, Om_hi, color=SIGMA_COLORS[0], alpha=0.18, label=r"68% CI")
    ax_top.set_yticks([])
    ax_top.tick_params(labelbottom=False)
    ax_top.set_title(
        rf"$\Omega_m={Om_med:.3f}^{{+{Om_hi-Om_med:.3f}}}_{{-{Om_med-Om_lo:.3f}}}$"
        rf" (true: {Om_true:.3f})",
        fontsize=9,
    )

    # Right panel: the same for sigma_8, drawn sideways.
    ax_rt.fill_betweenx(s8_grid, 0, s8_marg, color=SIGMA_COLORS[1], alpha=0.6)
    ax_rt.plot(s8_marg, s8_grid, color=SIGMA_COLORS[0], lw=1.4)
    ax_rt.axhline(s8_true, color="red", lw=1.0, ls=":")
    ax_rt.axhline(s8_pred, color="white", lw=1.5, ls="--",
                  path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
    ax_rt.axhspan(s8_lo, s8_hi, color=SIGMA_COLORS[0], alpha=0.18)
    ax_rt.set_xticks([])
    ax_rt.tick_params(labelleft=False)
    ax_rt.set_ylabel(
        rf"$\sigma_8={s8_med:.3f}^{{+{s8_hi-s8_med:.3f}}}_{{-{s8_med-s8_lo:.3f}}}$"
        rf" (true: {s8_true:.3f})",
        fontsize=9, rotation=270, labelpad=15,
    )
    ax_rt.yaxis.set_label_position("right")

    # Legend: sigma-region patches first, then the marker handles already on
    # the joint panel.
    legend_patches = [
        Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"),
        Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"),
        Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"),
    ]
    hs, ls_ = ax_main.get_legend_handles_labels()
    ax_main.legend(
        handles=legend_patches + hs,
        labels=[p.get_label() for p in legend_patches] + ls_,
        fontsize=8, loc="upper right", framealpha=0.92,
    )

    fig.suptitle(
        r"VLB Posterior using $L_0$ — joint and marginal distributions",
        fontsize=10, fontweight="bold", y=0.99,
    )
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path}")
    return Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi)
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def fig_pred_vs_true(pred_results: List[Dict], out_path: Path, dpi: int = 200) -> None:
    """Scatter predicted against true Omega_m and sigma_8 with error bars.

    Args:
        pred_results: One dict per field with keys ``Om_true``, ``Om_pred``,
            ``Om_lo``, ``Om_hi`` and the ``s8_*`` equivalents.
        out_path: Destination image file.
        dpi: Figure resolution.

    Error bars run from the prediction to the lo/hi bounds and are clipped at
    zero length. The RMSE of each parameter is shown on its panel and printed.
    """
    def _column(key: str) -> np.ndarray:
        return np.array([row[key] for row in pred_results])

    panels = []
    for prefix, symbol, limits in (
        ("Om", r"$\Omega_m$", (0.10, 0.50)),
        ("s8", r"$\sigma_8$", (0.60, 1.00)),
    ):
        truth = _column(f"{prefix}_true")
        estimate = _column(f"{prefix}_pred")
        below = estimate - _column(f"{prefix}_lo")
        above = _column(f"{prefix}_hi") - estimate
        rmse = np.sqrt(((estimate - truth) ** 2).mean())
        panels.append((truth, estimate, below, above, symbol, limits, rmse))

    fig, axes = plt.subplots(1, 2, figsize=(11, 5), dpi=dpi)

    for ax, (truth, estimate, below, above, symbol, limits, rmse) in zip(axes, panels):
        ax.errorbar(
            truth, estimate, yerr=[np.maximum(below, 0), np.maximum(above, 0)],
            fmt="o", color=GEN_COLOR, ecolor=SIGMA_COLORS[1],
            elinewidth=1.2, capsize=3, ms=6,
            label="DDPM-VLB inference (68% CI)",
        )
        ax.plot(limits, limits, "k--", lw=1.0, alpha=0.5, label="Identity")
        ax.set_xlabel(f"True {symbol}", fontsize=11)
        ax.set_ylabel(f"Predicted {symbol}", fontsize=11)
        ax.set_xlim(*limits)
        ax.set_ylim(*limits)
        ax.grid(alpha=0.2)
        ax.legend(fontsize=9, loc="lower right")
        ax.text(
            0.04, 0.92, f"RMSE = {rmse:.4f}",
            transform=ax.transAxes, fontsize=10,
            bbox=dict(facecolor="white", edgecolor="#ccc", alpha=0.92, pad=4),
        )
        ax.set_title(f"{symbol}: predicted vs true", fontweight="bold", fontsize=10)

    fig.suptitle(
        "VLB Parameter Inference: predicted vs true\n"
        r"Error bars = 68% CI from $L_0$ marginal posterior",
        fontsize=10, fontweight="bold", y=1.01,
    )
    plt.tight_layout()
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)

    rmse_Om, rmse_s8 = panels[0][-1], panels[1][-1]
    print(f" Saved -> {out_path}")
    print(f" RMSE: Omega_m={rmse_Om:.4f} sigma_8={rmse_s8:.4f}")
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def fig_posterior_and_contours_combined(
    surfaces: Dict[int, np.ndarray],
    L0_surface: np.ndarray,
    Om_grid: np.ndarray,
    s8_grid: np.ndarray,
    Om_true: float,
    s8_true: float,
    out_path: Path,
    dpi: int = 200,
) -> Tuple[float, float, Tuple[float, float], Tuple[float, float]]:
    """Save one figure with per-timestep contours (left) and the L_0 posterior (right).

    Args:
        surfaces: ``{t: L_t surface}`` used for the left-hand contour panel.
        L0_surface: L_0 values used for the right-hand posterior panels.
        Om_grid, s8_grid: Axis values shared by all surfaces.
        Om_true, s8_true: True parameters of the field.
        out_path: Destination image file.
        dpi: Figure resolution.

    Returns:
        ``(Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi))``: the marginal
        peaks and the 16th-84th percentile bounds of each marginal.
    """
    from matplotlib.patches import Patch

    surface_sm, _ = _L0_posterior_smoothed(L0_surface, smooth_sigma=0.6)
    Om_marg, s8_marg, (Om_pred, s8_pred) = marginal_from_neg2dL(
        surface_sm, Om_grid, s8_grid
    )
    Om_med, Om_lo, Om_hi = credible_interval_68(Om_grid, Om_marg)
    s8_med, s8_lo, s8_hi = credible_interval_68(s8_grid, s8_marg)

    # Four columns: contour panel, narrow spacer, joint posterior, marginal.
    fig = plt.figure(figsize=(16, 7), dpi=dpi)
    gs = gridspec.GridSpec(2, 4, width_ratios=[4, 0.3, 4, 1], height_ratios=[1, 4],
                           hspace=0.08, wspace=0.10,
                           left=0.08, right=0.96, top=0.94, bottom=0.08)

    # Left panel: Contours per timestep
    ax_contours = fig.add_subplot(gs[:, 0])
    cmap = plt.cm.viridis
    n_t = len(surfaces)
    colors = cmap(np.linspace(0.05, 0.95, n_t))

    for (t, L_surf), col in zip(sorted(surfaces.items()), colors):
        neg2dL = 2.0 * (L_surf - L_surf.min())
        ax_contours.contour(
            Om_grid, s8_grid, neg2dL.T,
            levels=[2.30], colors=[col], linewidths=[1.6], linestyles=["-"],
        )
        # Empty line: exists only to give the contour a legend entry.
        ax_contours.plot([], [], color=col, lw=1.8, label=f"t={t}")

    ax_contours.plot(Om_true, s8_true, "r+", ms=18, mew=2.5, label="True", zorder=10)
    ax_contours.set_xlabel(r"$\Omega_m$", fontsize=12)
    ax_contours.set_ylabel(r"$\sigma_8$", fontsize=12)
    ax_contours.set_title(
        r"$-2\Delta\ln\hat{L}_t$ — $1\sigma$ contours per timestep",
        fontweight="bold", fontsize=11,
    )
    ax_contours.legend(fontsize=8, loc="best", ncol=1, framealpha=0.92)
    ax_contours.grid(alpha=0.18)
    ax_contours.set_xlim(Om_grid[0], Om_grid[-1])
    ax_contours.set_ylim(s8_grid[0], s8_grid[-1])

    # Right panel: Posterior L_0 (similar to fig_main_posterior layout)
    ax_main = fig.add_subplot(gs[1, 2])
    ax_top = fig.add_subplot(gs[0, 2], sharex=ax_main)
    ax_rt = fig.add_subplot(gs[1, 3], sharey=ax_main)

    draw_L0_posterior_main_panel(
        ax_main, surface_sm, Om_grid, s8_grid, Om_true, s8_true, Om_pred, s8_pred,
    )

    ax_main.set_xlabel(r"$\Omega_m$", fontsize=11)
    ax_main.set_ylabel(r"$\sigma_8$", fontsize=11)

    # Omega_m marginal above the joint panel.
    ax_top.fill_between(Om_grid, 0, Om_marg, color=SIGMA_COLORS[1], alpha=0.6)
    ax_top.plot(Om_grid, Om_marg, color=SIGMA_COLORS[0], lw=1.4)
    ax_top.axvline(Om_true, color="red", lw=1.0, ls=":")
    ax_top.axvline(Om_pred, color="white", lw=1.5, ls="--",
                   path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
    ax_top.axvspan(Om_lo, Om_hi, color=SIGMA_COLORS[0], alpha=0.18, label=r"68% CI")
    ax_top.set_yticks([])
    ax_top.tick_params(labelbottom=False)
    ax_top.set_title(
        rf"$\Omega_m={Om_med:.3f}^{{+{Om_hi-Om_med:.3f}}}_{{-{Om_med-Om_lo:.3f}}}$"
        rf" (true: {Om_true:.3f})",
        fontsize=9,
    )

    # sigma_8 marginal to the right of the joint panel, drawn sideways.
    ax_rt.fill_betweenx(s8_grid, 0, s8_marg, color=SIGMA_COLORS[1], alpha=0.6)
    ax_rt.plot(s8_marg, s8_grid, color=SIGMA_COLORS[0], lw=1.4)
    ax_rt.axhline(s8_true, color="red", lw=1.0, ls=":")
    ax_rt.axhline(s8_pred, color="white", lw=1.5, ls="--",
                  path_effects=[mpathe.withStroke(linewidth=2.5, foreground="black")])
    ax_rt.axhspan(s8_lo, s8_hi, color=SIGMA_COLORS[0], alpha=0.18)
    ax_rt.set_xticks([])
    ax_rt.tick_params(labelleft=False)
    ax_rt.set_ylabel(
        rf"$\sigma_8={s8_med:.3f}^{{+{s8_hi-s8_med:.3f}}}_{{-{s8_med-s8_lo:.3f}}}$"
        rf" (true: {s8_true:.3f})",
        fontsize=9, rotation=270, labelpad=15,
    )
    ax_rt.yaxis.set_label_position("right")

    # Legend: sigma-region patches first, then the marker handles already on
    # the joint panel.
    legend_patches = [
        Patch(facecolor=SIGMA_COLORS[0], label=r"$1\sigma$"),
        Patch(facecolor=SIGMA_COLORS[1], label=r"$2\sigma$"),
        Patch(facecolor=SIGMA_COLORS[2], label=r"$3\sigma$"),
    ]
    hs, ls_ = ax_main.get_legend_handles_labels()
    ax_main.legend(
        handles=legend_patches + hs,
        labels=[p.get_label() for p in legend_patches] + ls_,
        fontsize=8, loc="upper right", framealpha=0.92,
    )

    fig.suptitle(
        r"VLB Inference: $L_t$ contours (left) and $L_0$ posterior (right)",
        fontsize=12, fontweight="bold", y=0.99,
    )
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    plt.close(fig)
    print(f" Saved -> {out_path}")
    return Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi)
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the VLB inference script.

    ``--grid_size`` defaults to 50 so that a plain invocation passes the
    ``> 300`` safety check performed in ``main``. The previous default of
    10_000 was always rejected unless ``--allow_huge_grid`` was given.

    Returns:
        The parsed namespace.
    """
    p = argparse.ArgumentParser(
        description="VLB-based parameter inference for trained conditional DDPM.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("--checkpoint", required=True)
    p.add_argument("--training_args", default=None)
    p.add_argument("--data_dir", default="./data/params_2")
    p.add_argument("--output_dir", default="vlb_inference_outputs")
    p.add_argument("--n_fields", type=int, default=9)
    p.add_argument(
        "--grid_size", type=int, default=50,
        help="Ωm×σ8 evaluation grid resolution (each side). Values > 300 require "
             "--allow_huge_grid (very long runs for large grid_size).",
    )
    p.add_argument(
        "--allow_huge_grid", action="store_true",
        help="Required when --grid_size > 300 (avoids accidental multi-week GPU jobs).",
    )
    p.add_argument(
        "--mosaic_side_px", type=int, default=10_000,
        help="Pixel width/height of posterior_L0_mosaic_3x3.png (square).",
    )
    p.add_argument(
        "--mosaic_panel_inches", type=float, default=4.0,
        help="Matplotlib size (inches) of each 3×3 panel; dpi = mosaic_side_px / (3× this).",
    )
    p.add_argument("--span", type=float, default=0.10)
    p.add_argument("--t_subset", type=int, nargs="+",
                   default=[0, 1, 2, 5, 8, 10, 15, 20])
    p.add_argument("--n_seeds", type=int, default=4)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--device", default="auto")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--dpi", type=int, default=200)
    return p.parse_args()
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def autodetect_args() -> Optional[str]:
    """Locate a saved training-args file below the current directory.

    Searches ``outputs_conditional_*/args.json`` first and
    ``outputs_conditional_*/args.txt`` second. For the first pattern that has
    any match, the match with the largest ``os.path.getctime`` is returned.

    Returns:
        The chosen path as a string, or ``None`` when neither pattern matches.
    """
    patterns = ("outputs_conditional_*/args.json", "outputs_conditional_*/args.txt")
    for pattern in patterns:
        matches = list(Path(".").glob(pattern))
        if not matches:
            continue
        # max() keeps the first of equally-new candidates, like a stable
        # descending sort followed by taking element 0.
        newest = max(matches, key=os.path.getctime)
        return str(newest)
    return None
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def main() -> None:
    """Run VLB-based posterior inference for a set of test fields.

    Steps, in order:
      1. Parse CLI arguments and refuse grids larger than 300 per side unless
         ``--allow_huge_grid`` was given (exit status 2).
      2. Seed torch and numpy, choose the device, create the output directory.
      3. Resolve the training-args file (auto-detected when not supplied),
         load the config and the model.
      4. For every field: build an evaluation grid around the true
         (Om, s8) labels, evaluate the loss surfaces, save them to
         ``fieldNN_surfaces.npz`` and write the per-field figures.
      5. Write the summary figure/npz (needs at least two predictions) and the
         3x3 mosaic (needs at least nine fields with saved surfaces).

    Raises:
        SystemExit: for an oversized grid without ``--allow_huge_grid``.
        FileNotFoundError: when no training-args file can be found.
    """
    args = parse_args()

    # Guard against accidentally launching an enormous grid evaluation.
    if not args.allow_huge_grid and args.grid_size > 300:
        refusal = (
            "\nRefusing --grid_size={} (> 300) without --allow_huge_grid.\n"
            "A 10_000×10_000 grid is roughly (200)^2 ≈ 40_000× more forward passes per\n"
            "field than 50×50. For a quick run use e.g. --grid_size 50; for a high-res\n"
            "summary figure use the default --mosaic_side_px without increasing grid_size.\n"
            "To proceed anyway: add --allow_huge_grid\n"
        )
        print(refusal.format(args.grid_size))
        raise SystemExit(2)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"\nDevice: {device}")

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Fall back to the newest saved training args when none were passed.
    if args.training_args is None:
        detected = autodetect_args()
        args.training_args = detected
        if detected is None:
            raise FileNotFoundError("Cannot find args.json — pass --training_args")
        print(f" Auto-detected args: {args.training_args}")

    cfg = load_config(args.training_args)
    print("\nLoading model ...")
    model = load_model(args.checkpoint, cfg, device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {param_count:,} T={model.diffusion.timesteps}")

    print(f"\nLoading {args.n_fields} test fields ...")
    test_imgs, test_labels, label_mu, label_std = load_test_data(
        args.data_dir, args.n_fields, seed=args.seed,
    )
    print(f" Image shape: {test_imgs.shape[1:]}")
    print(f" Label dim: {test_labels.shape[1]}")
    print(f" Label μ/σ: {label_mu} / {label_std}")

    n_groups = args.grid_size ** 2 * len(args.t_subset) * args.n_seeds
    print(
        f"\nEvaluating L_t on {args.grid_size}x{args.grid_size} grid for "
        f"{len(args.t_subset)} timesteps × {args.n_seeds} seeds ..."
    )
    print(
        f" -> {n_groups:,} "
        f"forward-pass groups per field (× seeds averaged)"
    )

    pred_results = []
    label_dim = int(cfg.get("label_dim", 2))

    for fi in range(args.n_fields):
        # Columns 0 and 1 of the label row are used as the true Om and s8.
        Om_true = float(test_labels[fi, 0])
        s8_true = float(test_labels[fi, 1])
        print(f"\n [{fi+1}/{args.n_fields}] field with "
              f"Om={Om_true:.3f}, s8={s8_true:.3f}")

        # Rescale the image with x*2-1 and add a channel axis before moving
        # it to the device (presumably [0, 1] -> [-1, 1]; confirm against
        # load_test_data).
        field_img = test_imgs[fi:fi + 1] * 2.0 - 1.0
        x_0 = torch.from_numpy(field_img).unsqueeze(1).to(device)

        Om_grid, s8_grid = build_eval_grid(Om_true, s8_true, args.grid_size, args.span)

        tic = time.time()
        surfaces = evaluate_vlb_surface(
            model=model,
            x_0=x_0,
            Om_grid=Om_grid,
            s8_grid=s8_grid,
            label_mu=label_mu,
            label_std=label_std,
            t_values=args.t_subset,
            n_seeds=args.n_seeds,
            batch_size=args.batch_size,
            label_dim=label_dim,
            fixed_seed=args.seed + fi,
            device=device,
        )
        print(f" Evaluation time: {time.time() - tic:.1f}s")

        # One npz per field: a surface per timestep plus grid and truth.
        payload = {f"L_t{t}": s for t, s in surfaces.items()}
        payload.update(
            Om_grid=Om_grid, s8_grid=s8_grid,
            Om_true=Om_true, s8_true=s8_true,
        )
        np.savez(out_dir / f"field{fi:02d}_surfaces.npz", **payload)

        has_t0 = 0 in surfaces

        if has_t0:
            # Combined figure: contours_per_t + posterior on same plot
            combined = fig_posterior_and_contours_combined(
                surfaces, surfaces[0], Om_grid, s8_grid, Om_true, s8_true,
                out_dir / f"field{fi:02d}_combined.png", dpi=args.dpi,
            )
            Om_pred, s8_pred, (Om_lo, Om_hi), (s8_lo, s8_hi) = combined
            pred_results.append({
                "Om_true": Om_true, "s8_true": s8_true,
                "Om_pred": Om_pred, "s8_pred": s8_pred,
                "Om_lo": Om_lo, "Om_hi": Om_hi,
                "s8_lo": s8_lo, "s8_hi": s8_hi,
            })

        # Also save individual figures for detailed inspection
        fig_contours_per_t(
            surfaces, Om_grid, s8_grid, Om_true, s8_true,
            out_dir / f"field{fi:02d}_contours_per_t.png", dpi=args.dpi,
        )

        if has_t0:
            fig_main_posterior(
                surfaces[0], Om_grid, s8_grid, Om_true, s8_true,
                out_dir / f"field{fi:02d}_posterior_L0.png", dpi=args.dpi,
            )

    if len(pred_results) >= 2:
        fig_pred_vs_true(pred_results, out_dir / "summary_pred_vs_true.png", dpi=args.dpi)
        summary = {
            key: np.array([row[key] for row in pred_results])
            for key in pred_results[0].keys()
        }
        np.savez(out_dir / "summary.npz", **summary)

    # The mosaic needs the saved surfaces of the first nine fields.
    first_nine_saved = all(
        (out_dir / f"field{i:02d}_surfaces.npz").is_file() for i in range(9)
    )
    if args.n_fields >= 9 and first_nine_saved:
        fig_posterior_L0_mosaic_3x3(
            out_dir, args.n_fields, out_dir / "posterior_L0_mosaic_3x3.png",
            mosaic_side_px=args.mosaic_side_px,
            panel_inches=args.mosaic_panel_inches,
        )

    print(f"\nAll outputs -> {out_dir.resolve()}/")
    for png in sorted(out_dir.glob("*.png")):
        print(f" {png.name}")


# Script entry point: only run the pipeline when executed directly.
if __name__ == "__main__":
    main()
|
src/train_conditional.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training script for conditional diffusion on CAMELS LH (6 cosmological parameters).
|
| 3 |
+
|
| 4 |
+
Same training theory as DDPM_HI_Emulation_improved (2-label): DDPM noise prediction,
|
| 5 |
+
DDIM sampling, ConditionalUNet with time + label embeddings, label z-score from train split,
|
| 6 |
+
EMA, optional AMP, cosine LR, early stopping.
|
| 7 |
+
|
| 8 |
+
Changes from original:
|
| 9 |
+
- EMA weights are now applied before validation and sampling
|
| 10 |
+
- Training args are saved to args.txt for evaluation script
|
| 11 |
+
- Fixed --normalize_labels and --use_ddim flags (were un-disableable)
|
| 12 |
+
- Added mixed-precision (AMP) training support
|
| 13 |
+
- Fixed loss averaging to be per-sample rather than per-batch
|
| 14 |
+
- Resuming calls torch.load(..., weights_only=False), i.e. full unpickling: only resume from trusted checkpoints
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import random
|
| 21 |
+
import time
|
| 22 |
+
|
| 23 |
+
import matplotlib.pyplot as plt
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
import torch.optim as optim
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
|
| 29 |
+
from dataset_conditional import DEFAULT_DATA_DIR, get_conditional_dataloaders
|
| 30 |
+
from diffusion_conditional import ConditionalDiffusionModel, GaussianDiffusion
|
| 31 |
+
from unet_conditional import ConditionalUNet
|
| 32 |
+
|
| 33 |
+
# Weights & Biases is an optional dependency: record whether it could be
# imported so callers can skip logging when it is missing.
try:
    import wandb
except ImportError:
    WANDB_AVAILABLE = False
    print("Warning: wandb not available. Install with: pip install wandb")
else:
    WANDB_AVAILABLE = True
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class EMA:
    """Exponential moving average (EMA) of a model's trainable parameters.

    ``shadow`` maps parameter name -> averaged tensor for every parameter
    with ``requires_grad=True``. ``apply_shadow`` temporarily writes the
    averaged values into the model (e.g. for validation or sampling) and
    ``restore`` puts the raw training weights back.

    Values are copied into the existing parameter storage instead of
    rebinding ``param.data``. Rebinding would make the live parameter share
    memory with the shadow (or backup) tensor, so any in-place change made
    while the averaged weights are applied would silently corrupt the EMA
    state.
    """

    def __init__(self, model, decay=0.9999):
        """Snapshot the current trainable parameters as the initial average.

        Args:
            model: module whose ``named_parameters()`` are tracked.
            decay: weight given to the previous average in each update.
        """
        self.model = model
        self.decay = decay
        self.shadow = {}
        # Raw weights saved by apply_shadow(); empty when nothing is pending,
        # so restore() is always safe to call.
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        """Blend the current weights into the average.

        For each tracked parameter:
        ``shadow = decay * shadow + (1 - decay) * param``.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * param.data

    def apply_shadow(self):
        """Back up the raw weights, then load the averaged weights into the model."""
        self.backup = {
            name: param.data.clone() for name, param in self.model.named_parameters() if param.requires_grad
        }
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # copy_ keeps the parameter's own storage; no aliasing with shadow.
                param.data.copy_(self.shadow[name])

    def restore(self):
        """Write the backed-up raw weights into the model and clear the backup.

        Does nothing when no backup is pending (``apply_shadow`` was not
        called, or ``restore`` already ran).
        """
        if not self.backup:
            return
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])
        self.backup = {}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def train_epoch(model, dataloader, optimizer, device, epoch, ema=None, use_wandb=False, scaler=None):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch: zero the gradients, compute ``model.get_loss(images, labels)``,
    backpropagate, clip the gradient norm to 1.0 and step the optimizer. When
    ``scaler`` is given the loss is computed under CUDA autocast and the
    scaler drives backward/step; gradients are unscaled before clipping.
    ``ema.update()`` is called after every optimizer step when ``ema`` is set.

    Args:
        model: object exposing ``train()``, ``parameters()`` and ``get_loss``.
        dataloader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over the model parameters.
        device: device the batches are moved to.
        epoch: epoch index, used for the progress label and wandb step.
        ema: optional EMA tracker.
        use_wandb: log the batch loss to wandb every 10th batch.
        scaler: optional gradient scaler enabling mixed precision.

    Returns:
        The loss averaged over samples (each batch weighted by its size).
    """
    model.train()
    running_loss = 0.0
    seen = 0
    progress = tqdm(dataloader, desc=f"Epoch {epoch}")

    for step, (images, labels) in enumerate(progress):
        images, labels = images.to(device), labels.to(device)
        n_in_batch = images.shape[0]

        optimizer.zero_grad()

        if scaler is None:
            # Full-precision path.
            loss = model.get_loss(images, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        else:
            # Mixed-precision path: unscale first so the clip threshold is
            # applied to the unscaled gradients.
            with torch.amp.autocast("cuda"):
                loss = model.get_loss(images, labels)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()

        if ema is not None:
            ema.update()

        batch_loss = loss.item()
        running_loss += batch_loss * n_in_batch
        seen += n_in_batch
        progress.set_postfix({"loss": f"{batch_loss:.4f}"})

        if use_wandb and step % 10 == 0:
            wandb.log({"batch_loss": batch_loss, "epoch": epoch, "batch": epoch * len(dataloader) + step})

    return running_loss / seen
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def validate(model, dataloader, device):
    """Return the sample-weighted mean of ``model.get_loss`` over ``dataloader``.

    The model is switched to eval mode and all batches are evaluated with
    gradient tracking disabled.

    Args:
        model: object exposing ``eval()`` and ``get_loss(images, labels)``.
        dataloader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.

    Returns:
        Sum of per-batch loss times batch size, divided by the sample count.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Validating"):
            images, labels = images.to(device), labels.to(device)
            n_in_batch = images.shape[0]
            batch_loss = model.get_loss(images, labels).item()
            loss_sum += batch_loss * n_in_batch
            n_seen += n_in_batch
    return loss_sum / n_seen
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def save_checkpoint(model, optimizer, ema, epoch, loss, save_dir, is_best=False, last_improvement_epoch=None, scheduler=None):
    """Write the training state for ``epoch`` into ``save_dir``.

    Files written:
      * ``checkpoint_latest.pt``: always.
      * ``best_model.pt``: when ``is_best`` is true.
      * ``checkpoint_epoch_<N>.pt``: when ``N = epoch + 1`` is a multiple of 20.

    The saved dict always holds ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``loss``; ``ema_shadow``,
    ``last_improvement_epoch`` and ``scheduler_state_dict`` are added only
    when the corresponding argument is not ``None``.
    """
    epoch_number = epoch + 1  # 1-based number used in messages and file names

    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }
    if ema is not None:
        state["ema_shadow"] = ema.shadow
    if last_improvement_epoch is not None:
        state["last_improvement_epoch"] = last_improvement_epoch
    if scheduler is not None:
        state["scheduler_state_dict"] = scheduler.state_dict()

    def _write(filename):
        torch.save(state, os.path.join(save_dir, filename))

    _write("checkpoint_latest.pt")

    if is_best:
        _write("best_model.pt")
        print(f"Saved best model at epoch {epoch_number}")

    if epoch_number % 20 == 0:
        _write(f"checkpoint_epoch_{epoch_number}.pt")

    print(f"Saved checkpoint at epoch {epoch_number}")
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def sample_images(model, diffusion, device, save_path, test_labels, ema=None, n_samples=8, epoch=0, use_ddim=True, ddim_steps=50, use_wandb=False):
    """Generate samples conditioned on ``test_labels`` and save them as a grid image.

    Args:
        model: network passed to ``diffusion.sample``; put into eval mode here.
        diffusion: object exposing ``sample(...)``.
        device: device used for the labels and for sampling.
        save_path: destination of the PNG grid.
        test_labels: tensor of conditioning labels; the first ``n_samples``
            rows are used.
        ema: optional EMA tracker; its averaged weights are applied only for
            the duration of sampling.
        n_samples: maximum number of images to generate. If ``test_labels``
            has fewer rows, only that many images are produced.
        epoch: epoch number shown in the figure title and logged to wandb.
        use_ddim, ddim_steps: forwarded to ``diffusion.sample``.
        use_wandb: also log the saved image to wandb.

    Raises:
        ValueError: if no conditioning label is available.
    """
    labels = test_labels[:n_samples].to(device)
    # Callers may pass a label batch shorter than n_samples; plot only what
    # was actually sampled instead of indexing past the end of `samples`.
    n_samples = len(labels)
    if n_samples == 0:
        raise ValueError("sample_images needs at least one conditioning label")

    if ema is not None:
        ema.apply_shadow()
    try:
        model.eval()
        with torch.no_grad():
            samples = diffusion.sample(
                model,
                labels=labels,
                channels=1,
                height=256,
                width=256,
                device=device,
                progress=True,
                use_ddim=use_ddim,
                ddim_steps=ddim_steps,
                eta=0.0,
            )
    finally:
        # Always hand the raw training weights back, even if sampling fails.
        if ema is not None:
            ema.restore()

    n_cols = min(n_samples, 4)
    n_rows = (n_samples + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5 * n_cols, 4.5 * n_rows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < n_samples:
            img = samples[i, 0].cpu().numpy()
            label_str = ", ".join(f"{v:.2f}" for v in labels[i].cpu().tolist())
            ax.imshow(img, vmin=-1, vmax=1)
            ax.set_title(label_str, fontsize=10)
        # Hide ticks/frames on every panel, including unused trailing ones.
        ax.axis("off")

    plt.suptitle(f"Generated Samples - Epoch {epoch}", fontsize=14)
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches="tight")

    if use_wandb:
        wandb.log({"generated_samples": wandb.Image(save_path), "epoch": epoch})
    plt.close()
    print(f"Saved samples to {save_path}")
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def save_training_args(args, output_dir):
    """Persist the parsed arguments next to the training outputs.

    Two UTF-8 files are written into ``output_dir``:
      * ``args.txt``: one ``key: value`` line per argument.
      * ``args.json``: the same mapping as indented JSON.

    Args:
        args: namespace whose ``vars(args)`` mapping is saved.
        output_dir: existing directory to write into.
    """
    settings = vars(args)

    txt_path = os.path.join(output_dir, "args.txt")
    with open(txt_path, "w", encoding="utf-8") as txt_file:
        txt_file.writelines(f"{name}: {setting}\n" for name, setting in settings.items())

    json_path = os.path.join(output_dir, "args.json")
    with open(json_path, "w", encoding="utf-8") as json_file:
        json.dump(settings, json_file, indent=2)

    print(f"Saved training args to {txt_path} and {json_path}")
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _build_arg_parser():
    """Build the command-line parser for the training script."""
    parser = argparse.ArgumentParser(description="Train conditional diffusion (LH 6-parameter)")
    # Model
    parser.add_argument("--label_dim", type=int, default=6)
    parser.add_argument("--base_channels", type=int, default=64)
    parser.add_argument("--channel_multipliers", type=int, nargs="+", default=[1, 2, 4, 8])
    parser.add_argument("--attention_levels", type=int, nargs="+", default=[2, 3])
    parser.add_argument("--dropout", type=float, default=0.1)
    # Diffusion
    parser.add_argument("--timesteps", type=int, default=1500)
    parser.add_argument("--beta_start", type=float, default=1e-4)
    parser.add_argument("--beta_end", type=float, default=0.02)
    parser.add_argument("--schedule_type", type=str, default="linear")
    # Training
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--ema_decay", type=float, default=0.9999)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--early_stop_patience", type=int, default=30)
    parser.add_argument(
        "--use_amp",
        action="store_true",
        default=False,
        help="Enable mixed-precision training (recommended for GPU)",
    )
    # Data
    parser.add_argument(
        "--data_dir",
        type=str,
        default=DEFAULT_DATA_DIR,
        help="Directory with *_LH_6.npy and *_labels_LH.npy (same rule as improved repo: e.g. .../LH_data/params_6)",
    )
    parser.add_argument("--normalize_labels", action=argparse.BooleanOptionalAction, default=True)
    # Output
    parser.add_argument("--output_dir", type=str, default="outputs_conditional_6param")
    parser.add_argument("--resume", type=str, default="")
    parser.add_argument(
        "--resume_refresh_scheduler",
        action="store_true",
        help="On resume, rebuild cosine LR scheduler for --epochs (last_epoch=start-1) instead of loading saved scheduler; use when extending training beyond the original epoch count",
    )
    parser.add_argument("--sample_every", type=int, default=10)
    parser.add_argument("--use_ddim", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--ddim_steps", type=int, default=50)
    # WandB
    parser.add_argument("--use_wandb", action="store_true", default=False)
    parser.add_argument("--wandb_project", type=str, default="ddpm_cosmology")
    parser.add_argument("--wandb_entity", type=str, default="")
    parser.add_argument("--wandb_run_name", type=str, default="")
    return parser


def _seed_everything(seed):
    """Seed Python, NumPy and torch RNGs; on CUDA also request deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _plot_losses(losses, output_dir):
    """Write the train/val loss curves (log-scaled y axis) to ``losses.png``."""
    plt.figure(figsize=(10, 5))
    plt.plot(losses["train"], label="Train Loss")
    plt.plot(losses["val"], label="Val Loss")
    plt.yscale("log")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training Progress")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(output_dir, "losses.png"), dpi=150)
    plt.close()


def main():
    """Train the conditional diffusion model end to end.

    Parses the CLI, seeds the RNGs, creates a timestamped output directory
    (with ``checkpoints/`` and ``samples/``), saves the arguments, builds the
    data loaders, model, optimizer, EMA and cosine LR scheduler, optionally
    resumes from a checkpoint, then runs the epoch loop: train, validate with
    EMA weights, checkpoint, early-stop check, periodic sampling and loss
    plotting. The loss plot is also written once after the loop so it
    reflects the final epoch even when training stops early or the epoch
    count is not a multiple of 5.
    """
    args = _build_arg_parser().parse_args()

    _seed_everything(42)

    use_wandb = args.use_wandb and WANDB_AVAILABLE
    if use_wandb:
        run_name = args.wandb_run_name or f"conditional_diffusion_{time.strftime('%Y%m%d_%H%M%S')}"
        wandb.init(project=args.wandb_project, entity=args.wandb_entity or None, name=run_name, config=vars(args))
        print(f"W&B run: {run_name}")

    # Every run (including resumed ones) writes into a fresh timestamped directory.
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    output_dir = f"{args.output_dir}_{timestamp}"
    checkpoint_dir = os.path.join(output_dir, "checkpoints")
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "samples"), exist_ok=True)

    save_training_args(args, output_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # AMP is only enabled when requested AND CUDA is present.
    scaler = torch.amp.GradScaler("cuda") if args.use_amp and torch.cuda.is_available() else None
    if scaler:
        print("Mixed-precision training enabled (AMP)")

    print("\nLoading data...")
    train_loader, val_loader, test_loader = get_conditional_dataloaders(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        normalize_labels=args.normalize_labels,
        label_dim=args.label_dim,
    )
    # Labels of the first test batch condition the periodic sample grids.
    _, test_labels = next(iter(test_loader))

    print("\nCreating model...")
    unet = ConditionalUNet(
        in_channels=1,
        out_channels=1,
        label_dim=args.label_dim,
        base_channels=args.base_channels,
        channel_multipliers=args.channel_multipliers,
        attention_levels=args.attention_levels,
        dropout=args.dropout,
    )
    diffusion = GaussianDiffusion(
        timesteps=args.timesteps,
        beta_start=args.beta_start,
        beta_end=args.beta_end,
        schedule_type=args.schedule_type,
    )

    model = ConditionalDiffusionModel(unet, diffusion).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    ema = EMA(model, decay=args.ema_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    start_epoch = 0
    best_val_loss = float("inf")
    last_improvement_epoch = -1
    if args.resume:
        print(f"Resuming from {args.resume}")
        # weights_only=False unpickles the whole file: only resume from
        # checkpoints you trust.
        checkpoint = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if "ema_shadow" in checkpoint:
            ema.shadow = checkpoint["ema_shadow"]
        start_epoch = checkpoint["epoch"] + 1
        # NOTE: "loss" is the validation loss of the saved epoch (see the
        # save_checkpoint call below), which is only the best-so-far value
        # when resuming from a best-model checkpoint.
        best_val_loss = checkpoint.get("loss", float("inf"))
        last_improvement_epoch = checkpoint.get("last_improvement_epoch", -1)
        if args.resume_refresh_scheduler:
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=args.epochs, last_epoch=start_epoch - 1
            )
            print(
                f"Rebuilt LR scheduler for extended run: T_max={args.epochs}, "
                f"resume at epoch {start_epoch + 1} (last_epoch={start_epoch - 1})"
            )
        elif "scheduler_state_dict" in checkpoint:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    print("\nStarting training...")
    losses = {"train": [], "val": []}

    for epoch in range(start_epoch, args.epochs):
        train_loss = train_epoch(model, train_loader, optimizer, device, epoch, ema, use_wandb, scaler=scaler)

        # Validate with the averaged weights, then put the raw weights back.
        if ema is not None:
            ema.apply_shadow()
        val_loss = validate(model, val_loader, device)
        if ema is not None:
            ema.restore()

        losses["train"].append(train_loss)
        losses["val"].append(val_loss)
        scheduler.step()

        if use_wandb:
            wandb.log(
                {
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                }
            )

        print(
            f"\nEpoch {epoch+1}/{args.epochs} | Train: {train_loss:.6f} | Val: {val_loss:.6f} | "
            f"LR: {optimizer.param_groups[0]['lr']:.6e}"
        )

        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
            last_improvement_epoch = epoch

        save_checkpoint(
            model,
            optimizer,
            ema,
            epoch,
            val_loss,
            checkpoint_dir,
            is_best=is_best,
            last_improvement_epoch=last_improvement_epoch,
            scheduler=scheduler,
        )

        if epoch - last_improvement_epoch >= args.early_stop_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        if (epoch + 1) % args.sample_every == 0:
            sample_path = os.path.join(output_dir, "samples", f"samples_epoch_{epoch+1}.png")
            sample_images(
                model,
                diffusion,
                device,
                sample_path,
                test_labels,
                ema=ema,
                epoch=epoch + 1,
                use_ddim=args.use_ddim,
                ddim_steps=args.ddim_steps,
                use_wandb=use_wandb,
            )

        if (epoch + 1) % 5 == 0:
            _plot_losses(losses, output_dir)

    # Final plot so losses.png covers the last epochs too (early stop, or an
    # epoch count that is not a multiple of 5). Skipped if no epoch ran.
    if losses["train"]:
        _plot_losses(losses, output_dir)

    print(f"\nTraining completed! Best val loss: {best_val_loss:.6f}")
    print(f"Results saved to: {output_dir}")
    if use_wandb:
        wandb.finish()


# Script entry point: only start training when executed directly.
if __name__ == "__main__":
    main()
|