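# NOTE(review): the next four lines look like repository-page metadata that was
# captured together with the file; kept as comments so the module still parses.
# my-diffusion-model / sample_inference_example.py
# collins909's picture
# Upload 4 files
# e00756d verified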
"""
inference_example.py
====================
Standalone script demonstrating how to use the deployed DDPM model.

After downloading from the Hub, run:

    python inference_example.py
"""
import json
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
# Ensure local imports resolve
sys.path.insert(0, str(Path(__file__).parent))
from modeling_ddpm_camels import load_pretrained, generate
# ── Configuration ──────────────────────────────────────────────────────────
# The model is loaded from the directory that contains this script.
MODEL_DIR = Path(__file__).parent
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ── Load ───────────────────────────────────────────────────────────────────
print(f"Loading model from {MODEL_DIR} on {DEVICE} ...")
model, config = load_pretrained(MODEL_DIR, device=DEVICE)
# Keys are single-quoted on purpose: backslash-escaped double quotes inside an
# f-string replacement field are a syntax error.
print(f" Image size: {config['image_size']}")
print(f" Label dim: {config['label_dim']} ({config['label_names']})")
# ── Generate at 4 cosmologies ──────────────────────────────────────────────
# One row per sample to generate. The plotting code later in this script labels
# column 0 as Omega_m and column 1 as sigma_8.
raw_labels = torch.tensor(
    [
        [0.20, 0.95],
        [0.30, 0.80],
        [0.40, 0.70],
        [0.50, 0.65],
    ],
    dtype=torch.float32,
)

if config["label_dim"] > 2:
    # Pad with fiducial astrophysics (label_mu values of those dims)
    pad = torch.tensor(config["label_mu"][2:], dtype=torch.float32).unsqueeze(0)
    # Repeat the single pad row once per sample; the row count is read from
    # raw_labels so the table above can grow or shrink without touching this.
    raw_labels = torch.cat(
        [raw_labels, pad.expand(raw_labels.shape[0], -1)],
        dim=1,
    )

print("\nGenerating samples ...")
# Sampling only: no gradients are needed.
with torch.no_grad():
    out = generate(model, config, raw_labels, device=DEVICE, ddim_steps=50)

# Map [-1, 1] -> [0, 1] for visualisation
# NOTE(review): [:, 0] keeps index 0 along the second axis, presumably the only
# image channel; confirm against the output layout of generate().
imgs = ((out.cpu().numpy() + 1) / 2).clip(0, 1)[:, 0]
# ── Display ────────────────────────────────────────────────────────────────
# One panel per generated image, laid out in a single row.
# squeeze=False makes subplots() always return a 2-D array of Axes, so the loop
# below also works when only one image was generated.
fig, axes = plt.subplots(
    1,
    len(imgs),
    figsize=(3 * len(imgs), 3.5),
    squeeze=False,
)
for ax, img, lbl in zip(axes[0], imgs, raw_labels):
    # Images were already rescaled to [0, 1], so pin the colour range to match.
    ax.imshow(img, cmap="magma", origin="lower", vmin=0, vmax=1)
    ax.set_title(f"$\\Omega_m={lbl[0]:.2f}$, $\\sigma_8={lbl[1]:.2f}$", fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])

# The title separator is an em dash (the original text was mis-decoded).
plt.suptitle("Conditional DDPM samples — CAMELS HI fields", fontweight="bold")
plt.tight_layout()
plt.savefig("inference_example.png", dpi=150, bbox_inches="tight")
print("\nSaved -> inference_example.png")