---
license: mit
library_name: pytorch
tags:
  - diffusion
  - ddpm
  - ddim
  - cosmology
  - astrophysics
  - camels
  - emulator
  - conditional-generation
pipeline_tag: unconditional-image-generation
---

# DDPM HI Emulator — 2 Parameter (CAMELS LH)

A conditional Denoising Diffusion Probabilistic Model (DDPM) that emulates
**neutral-hydrogen (HI) 2D maps** from the CAMELS Latin-Hypercube (LH)
simulation suite, conditioned on **two cosmological parameters**
(e.g. Ωm, σ8). Sampling supports both full DDPM and accelerated DDIM.

This checkpoint is **epoch 200** of the training run carried out under
`DDPM_HI_Emulation_improved/outputs_conditional_2label_20260408_125646/`.

## Files in this repo

**Top level**

| File | Purpose |
|------|---------|
| `model.pt` | PyTorch checkpoint (state-dict for `ConditionalDiffusionModel`) |
| `args.json` / `args.txt` | Training hyper-parameters and U-Net configuration |
| `config.json` | Architecture summary (for Hub discoverability) |
| `inference_example.py` | Runnable example: downloads weights and generates a sample |

**`src/` — per-model Python**

| File | Purpose |
|------|---------|
| `train_conditional.py` | Training entry point (`label_dim=2`) |
| `evaluate_conditional.py` | Held-out evaluation: samples + metrics |
| `ddim_investigation_2param.py` | DDIM-vs-DDPM sampler comparison study |
| `unet_conditional.py` | `ConditionalUNet` module |
| `diffusion_conditional.py` | `GaussianDiffusion` (DDPM + DDIM) and the wrapping `ConditionalDiffusionModel` |
| `dataset_conditional.py` | CAMELS LH dataset loader + label normalisation |

**`scripts/shell/` — SLURM launchers**

| File | Purpose |
|------|---------|
| `train_conditional.sh` | Submit a training job (`label_dim=2`) |
| `evaluate_conditional.sh` | Submit evaluation against the held-out test split |
| `run_ddim_investigation_2param.sh` | Launch the DDIM sampler study |

**`cross_model/` — posterior + comparison scripts that use BOTH models**

| File | Purpose |
|------|---------|
| `compare_posterior_inference.py` (+ `run_compare_posterior.sh`) | End-to-end posterior comparison between 2-param and 6-param emulators |
| `ddpm_posterior_corrected.py` (+ `scripts/run_ddpm_posterior_corrected.sh`) | Corrected DDPM posterior inference |
| `poster.py` / `check_poster_env.py` (+ `scripts/run_poster.sh`) | Posterior orchestration and environment check |
| `submit_vlb_1000grid.py` / `run_vlb_inference_*.sh` | Variational-lower-bound grid inference (200 / 1000 grid) |
| `scripts/compare_ddpm_models.py` (+ `run_ddpm_comparison.sh`) | DDPM-2 vs DDPM-6 comparison figures |
| `scripts/ddpm_posterior_six_anchors.py` (+ `run_ddpm_posterior_six_anchors.sh`) | Six-anchor posterior visualisation |
| `scripts/ddpm_figure6_integration.py`, `figure6_2409_style.py`, `run_ddpm_figure6_suite.py` (+ `run_ddpm_figure6.sh`) | Figure 6 generation pipeline |
| `scripts/ddpm_triangle_integration.py`, `triangle_plot_posterior.py` (+ `run_triangle_ddpm_both.sh`) | Triangle-plot posterior figures |
| `scripts/sigma_contour_utils.py` | Confidence-contour helper used by the figure scripts |
| `scripts/compare_ddpm_training_curves.py` | Parses SLURM logs for combined train/val loss plots |
| `cross_model/README.md` | How to point these scripts at locally-downloaded weights/data |

These cross-model scripts default to the original cluster paths (e.g.
`<CAMELS_LH_DATA_DIR>/params_2`). After downloading this repo, pass
`--bundle-2param`, `--bundle-6param`, `--data-2param`, and `--data-6param`
to override them.

## Architecture

Conditional U-Net + Gaussian diffusion process. Hyper-parameters (taken from
`args.json`):

| Field | Value |
|-------|-------|
| `label_dim` | 2 |
| `base_channels` | 64 |
| `channel_multipliers` | [1, 2, 4, 8] |
| `attention_levels` | [2, 3] |
| `dropout` | 0.1 |
| `timesteps` | 1500 (linear β schedule: 1e-4 → 0.02) |
| EMA decay | 0.9999 |
| Sampler | DDIM, 50 steps (DDPM also supported) |
| Image size | 256 × 256, single channel |
| Image range | [-1, 1] (training data is rescaled by `x * 2 - 1`) |
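
To make the `timesteps` and sampler rows concrete, here is a minimal pure-Python sketch of a linear β schedule with the values from the table, plus one common way to pick a 50-step DDIM subsequence. The function names are hypothetical and the evenly strided timestep selection is an assumption; the actual logic lives in `GaussianDiffusion` in `src/diffusion_conditional.py` and may differ.

```python
def linear_beta_schedule(timesteps=1500, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_1..beta_T (values from the table)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def cumulative_alphas(betas):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s): signal fraction left after t steps."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def ddim_timesteps(timesteps=1500, ddim_steps=50):
    """Evenly strided, descending subsequence (assumed; the repo may choose differently)."""
    stride = timesteps // ddim_steps
    return list(range(0, timesteps, stride))[::-1]


betas = linear_beta_schedule()
alpha_bars = cumulative_alphas(betas)
steps = ddim_timesteps()

print(f"{len(betas)} betas from {betas[0]:.1e} to {betas[-1]:.1e}")
print(f"alpha_bar at the final step: {alpha_bars[-1]:.2e}")
print(f"{len(steps)} DDIM steps, from t={steps[0]} down to t={steps[-1]}")
```

With 1500 steps the cumulative product `alpha_bar` ends very close to zero, so the final forward-process state is essentially pure Gaussian noise, which is what the sampler starts from.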

Labels are z-scored using the **training-split** mean / std.
`inference_example.py` shows how to recover this normalisation from the
CAMELS LH `params_2` dataset; alternatively, pass already-normalised
conditioning values directly.
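
As a rough illustration of the z-scoring step, the sketch below fits per-parameter statistics on a training label array and applies them to a query point. The helper names are hypothetical, the label array is synthetic (a stand-in for loading `train_labels_LH.npy`), and the use of the population standard deviation is an assumption; `inference_example.py` remains the reference.

```python
import numpy as np


def fit_label_stats(train_labels):
    """Per-parameter mean and std over the training split (shape: (N, label_dim))."""
    return train_labels.mean(axis=0), train_labels.std(axis=0)


def normalise_labels(labels, mean, std):
    """Z-score raw parameter values with the training-split statistics."""
    return (labels - mean) / std


# Synthetic stand-in for np.load("<data_dir>/train_labels_LH.npy"); the two
# columns and their ranges are illustrative assumptions, not the real data.
rng = np.random.default_rng(0)
train_labels = rng.uniform([0.1, 0.6], [0.5, 1.0], size=(800, 2))

mean, std = fit_label_stats(train_labels)
query = np.array([[0.3, 0.8]])  # raw parameter values to condition on
cond = normalise_labels(query, mean, std).astype(np.float32)
print(cond)
```

The resulting array can be wrapped with `torch.from_numpy(cond)` and passed as the conditioning vector in place of a placeholder of zeros.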

## Quick start

```python
from huggingface_hub import hf_hub_download
import sys, torch, json
from pathlib import Path

# 1) Download the checkpoint, its config, and the bundled model source
repo = "collins909/DDPM-2param"
ckpt_path  = hf_hub_download(repo, "model.pt")
args_path  = hf_hub_download(repo, "args.json")
# Pull the bundled source files so we can import the model classes. Each call
# returns the local file path, so derive the src/ directory from that.
for name in ("unet_conditional.py", "diffusion_conditional.py", "__init__.py"):
    src_file = hf_hub_download(repo, f"src/{name}")
sys.path.insert(0, str(Path(src_file).parent))

from unet_conditional import ConditionalUNet
from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel

# 2) Rebuild the model from args.json
args = json.loads(Path(args_path).read_text())
unet = ConditionalUNet(
    in_channels=1, out_channels=1,
    label_dim=args["label_dim"],
    base_channels=args["base_channels"],
    channel_multipliers=tuple(args["channel_multipliers"]),
    attention_levels=tuple(args["attention_levels"]),
    dropout=args["dropout"],
)
diffusion = GaussianDiffusion(
    timesteps=args["timesteps"],
    beta_start=args["beta_start"],
    beta_end=args["beta_end"],
    schedule_type=args["schedule_type"],
)
model = ConditionalDiffusionModel(unet, diffusion)

# 3) Load the checkpoint and sample
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Conditioning vector must be z-scored using training-split label statistics.
labels = torch.tensor([[0.0, 0.0]])  # placeholder; see inference_example.py
sample = model.sample(labels, channels=1, height=256, width=256,
                      device="cpu", use_ddim=True, ddim_steps=50)
# sample is in [-1, 1]; rescale to physical HI units as needed.
```

For an end-to-end runnable example (including label normalisation, GPU usage,
and image saving), see `inference_example.py` in this repo.

## Training data

Trained on **CAMELS LH** HI maps with 2-label conditioning. The exact data
layout used by `src/dataset_conditional.py` is:

```
<data_dir>/
  train_LH_2.npy, val_LH_2.npy, test_LH_2.npy
  train_labels_LH.npy, val_labels_LH.npy, test_labels_LH.npy
```

Images are rescaled to `[-1, 1]`; labels are z-scored using train-split
statistics. Point your training/eval scripts at the local directory that contains those
files (e.g. via `--data_dir <CAMELS_LH_DATA_DIR>/params_2`).
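
Before launching a job, it can help to confirm that a directory matches this layout. The following is a small sketch using only the standard library; `missing_data_files` is a hypothetical helper, not part of the repo, and the demo runs against a temporary directory rather than real CAMELS data.

```python
import tempfile
from pathlib import Path

# File names taken from the layout listed above.
EXPECTED_FILES = [
    "train_LH_2.npy", "val_LH_2.npy", "test_LH_2.npy",
    "train_labels_LH.npy", "val_labels_LH.npy", "test_labels_LH.npy",
]


def missing_data_files(data_dir):
    """Return the expected file names that are absent from data_dir."""
    root = Path(data_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


# Demo: create every file except the last one, then report what is missing.
with tempfile.TemporaryDirectory() as tmp:
    for name in EXPECTED_FILES[:-1]:
        (Path(tmp) / name).touch()
    demo_missing = missing_data_files(tmp)

print("missing:", demo_missing)
```

An empty list means the directory is ready to pass as `--data_dir`.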

## Intended use & limitations

- Intended for **research** on diffusion emulators for cosmological fields.
- The 2-label setup is a simplified subset of the full CAMELS LH parameter
  space; see the companion **6-parameter** model
  (`collins909/DDPM-6param`) for the full conditioning.
- Outputs are 256 × 256 single-channel maps in the model's normalised range.
  Apply the inverse of any data-pipeline preprocessing before physical
  interpretation.
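
As a sketch of the first inversion step, the snippet below undoes the `x * 2 - 1` rescaling stated in the architecture table. The function names are hypothetical, and clipping to the model range before inverting is an added choice, not something the repo is known to do. Any further transform applied before that rescaling (for example a log or min-max scaling) is not described here and would need to be inverted separately.

```python
import numpy as np


def to_model_range(x01):
    """Forward rescaling used for training data: [0, 1] -> [-1, 1]."""
    return x01 * 2.0 - 1.0


def from_model_range(x):
    """Inverse rescaling: [-1, 1] -> [0, 1], clipping small sampler overshoots."""
    return (np.clip(x, -1.0, 1.0) + 1.0) / 2.0


# Synthetic maps in [0, 1) standing in for preprocessed HI maps.
maps01 = np.random.default_rng(0).random((2, 1, 256, 256))
roundtrip = from_model_range(to_model_range(maps01))
print(np.abs(roundtrip - maps01).max())
```

For a sample tensor from the model, `from_model_range(sample.cpu().numpy())` gives the map in the pre-rescaling `[0, 1]` range.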

## Citation

If you use this checkpoint, please cite the CAMELS project and the upstream
DDPM HI emulation work. (Citation block to be filled in once the
accompanying paper is published.)