File size: 2,734 Bytes
e00756d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""
inference_example.py
====================
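Standalone script demonstrating how to use the deployed conditional DDPM model:
it loads the checkpoint stored next to this file, samples one CAMELS HI field
for each of four (Omega_m, sigma_8) cosmologies, and saves the panel figure to
``inference_example.png`` in the current working directory.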
After downloading from the Hub, run:
    python inference_example.py
"""
import json
import sys
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch

# Ensure local imports resolve
sys.path.insert(0, str(Path(__file__).parent))

from modeling_ddpm_camels import load_pretrained, generate

# ── Configuration ──────────────────────────────────────────────────────────
# The checkpoint files are expected to sit in the same directory as this script.
MODEL_DIR = Path(__file__).parent
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE    = "cuda" if torch.cuda.is_available() else "cpu"

# ── Load ───────────────────────────────────────────────────────────────────
print(f"Loading model from {MODEL_DIR} on {DEVICE} ...")
# load_pretrained returns the network plus a config mapping; the keys read
# below ("image_size", "label_dim", "label_names") must be present in it.
model, config = load_pretrained(MODEL_DIR, device=DEVICE)
# Keys are single-quoted on purpose: a backslash-escaped quote inside an
# f-string replacement field is a SyntaxError.
print(f"  Image size: {config['image_size']}")
print(f"  Label dim:  {config['label_dim']} ({config['label_names']})")

# ── Generate at 4 cosmologies ──────────────────────────────────────────────
# One row per sample. The first two columns are (Omega_m, sigma_8), matching
# the plot titles; the row count decides how many samples are drawn.
raw_labels = torch.tensor([
    [0.20, 0.95],
    [0.30, 0.80],
    [0.40, 0.70],
    [0.50, 0.65],
], dtype=torch.float32)

if config["label_dim"] > 2:
    # Pad with fiducial astrophysics (label_mu values of those dims)
    pad = torch.tensor(config["label_mu"][2:], dtype=torch.float32).unsqueeze(0)
    # Broadcast the single fiducial row to however many cosmologies are listed
    # above, instead of a hard-coded count, so editing the table stays safe.
    raw_labels = torch.cat([raw_labels, pad.expand(raw_labels.shape[0], -1)], dim=1)

print("\nGenerating samples ...")
# Inference only: no_grad avoids building an autograd graph during sampling.
with torch.no_grad():
    out = generate(model, config, raw_labels, device=DEVICE, ddim_steps=50)

# Map [-1, 1] -> [0, 1] for visualisation and keep channel 0 of each sample.
# NOTE(review): assumes `out` is (batch, channels, H, W) — confirm against generate().
imgs = ((out.cpu().numpy() + 1) / 2).clip(0, 1)[:, 0]

# ── Display ────────────────────────────────────────────────────────────────
# squeeze=False always returns a 2-D array of Axes, so the zip below also
# works when only one sample is generated (subplots(1, 1) returns a bare Axes).
fig, axes = plt.subplots(1, len(imgs), figsize=(3 * len(imgs), 3.5), squeeze=False)
for ax, img, lbl in zip(axes[0], imgs, raw_labels):
    ax.imshow(img, cmap="magma", origin="lower", vmin=0, vmax=1)
    ax.set_title(f"$\\Omega_m={lbl[0]:.2f}$, $\\sigma_8={lbl[1]:.2f}$", fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])
plt.suptitle("Conditional DDPM samples — CAMELS HI fields", fontweight="bold")
plt.tight_layout()
# Relative path: the figure is written to the current working directory.
plt.savefig("inference_example.png", dpi=150, bbox_inches="tight")
plt.close(fig)
print("\nSaved -> inference_example.png")