""" inference_example.py ==================== Standalone script demonstrating how to use the deployed DDPM model. After downloading from the Hub, run: python inference_example.py """ import json import sys from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np import torch # Ensure local imports resolve sys.path.insert(0, str(Path(__file__).parent)) from modeling_ddpm_camels import load_pretrained, generate # ── Configuration ────────────────────────────────────────────────────────── MODEL_DIR = Path(__file__).parent DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # ── Load ─────────────────────────────────────────────────────────────────── print(f"Loading model from {MODEL_DIR} on {DEVICE} ...") model, config = load_pretrained(MODEL_DIR, device=DEVICE) print(f" Image size: {config[\"image_size\"]}") print(f" Label dim: {config[\"label_dim\"]} ({config[\"label_names\"]})") # ── Generate at 4 cosmologies ────────────────────────────────────────────── raw_labels = torch.tensor([ [0.20, 0.95], [0.30, 0.80], [0.40, 0.70], [0.50, 0.65], ], dtype=torch.float32) if config["label_dim"] > 2: # Pad with fiducial astrophysics (label_mu values of those dims) pad = torch.tensor(config["label_mu"][2:], dtype=torch.float32).unsqueeze(0) raw_labels = torch.cat([raw_labels, pad.expand(4, -1)], dim=1) print(f"\nGenerating samples ...") with torch.no_grad(): out = generate(model, config, raw_labels, device=DEVICE, ddim_steps=50) # Map [-1, 1] -> [0, 1] for visualisation imgs = ((out.cpu().numpy() + 1) / 2).clip(0, 1)[:, 0] # ── Display ──────────────────────────────────────────────────────────────── fig, axes = plt.subplots(1, len(imgs), figsize=(3 * len(imgs), 3.5)) for ax, img, lbl in zip(axes, imgs, raw_labels): ax.imshow(img, cmap="magma", origin="lower", vmin=0, vmax=1) ax.set_title(f"$\\Omega_m={lbl[0]:.2f}$, $\\sigma_8={lbl[1]:.2f}$", fontsize=10) ax.set_xticks([]); ax.set_yticks([]) plt.suptitle("Conditional DDPM samples — CAMELS HI fields", fontweight="bold") plt.tight_layout() plt.savefig("inference_example.png", dpi=150, bbox_inches="tight") print(f"\nSaved -> inference_example.png")