File size: 2,734 Bytes
e00756d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | """
inference_example.py
====================
Standalone script demonstrating how to use the deployed DDPM model.
After downloading from the Hub, run:
python inference_example.py
"""
import json
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
# Ensure local imports resolve
sys.path.insert(0, str(Path(__file__).parent))
from modeling_ddpm_camels import load_pretrained, generate
# ββ Configuration ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
MODEL_DIR = Path(__file__).parent
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ββ Load βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
print(f"Loading model from {MODEL_DIR} on {DEVICE} ...")
model, config = load_pretrained(MODEL_DIR, device=DEVICE)
print(f" Image size: {config[\"image_size\"]}")
print(f" Label dim: {config[\"label_dim\"]} ({config[\"label_names\"]})")
# ββ Generate at 4 cosmologies ββββββββββββββββββββββββββββββββββββββββββββββ
raw_labels = torch.tensor([
[0.20, 0.95],
[0.30, 0.80],
[0.40, 0.70],
[0.50, 0.65],
], dtype=torch.float32)
if config["label_dim"] > 2:
# Pad with fiducial astrophysics (label_mu values of those dims)
pad = torch.tensor(config["label_mu"][2:], dtype=torch.float32).unsqueeze(0)
raw_labels = torch.cat([raw_labels, pad.expand(4, -1)], dim=1)
print(f"\nGenerating samples ...")
with torch.no_grad():
out = generate(model, config, raw_labels, device=DEVICE, ddim_steps=50)
# Map [-1, 1] -> [0, 1] for visualisation
imgs = ((out.cpu().numpy() + 1) / 2).clip(0, 1)[:, 0]
# ββ Display ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
fig, axes = plt.subplots(1, len(imgs), figsize=(3 * len(imgs), 3.5))
for ax, img, lbl in zip(axes, imgs, raw_labels):
ax.imshow(img, cmap="magma", origin="lower", vmin=0, vmax=1)
ax.set_title(f"$\\Omega_m={lbl[0]:.2f}$, $\\sigma_8={lbl[1]:.2f}$", fontsize=10)
ax.set_xticks([]); ax.set_yticks([])
plt.suptitle("Conditional DDPM samples β CAMELS HI fields", fontweight="bold")
plt.tight_layout()
plt.savefig("inference_example.png", dpi=150, bbox_inches="tight")
print(f"\nSaved -> inference_example.png")
|