"""
inference_example.py
====================
Standalone script demonstrating how to use the deployed DDPM model.

After downloading from the Hub, run:

    python inference_example.py

The script loads the model from the directory containing this file,
generates one sample per cosmology listed in ``_COSMOLOGIES`` and writes
the resulting panel to ``inference_example.png`` in the current working
directory.
"""
import json
import sys
from pathlib import Path

import matplotlib

# Select the non-interactive Agg backend before pyplot is imported so the
# script also runs on machines without a display.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch

# Make modules that sit next to this script importable regardless of the
# directory the script is launched from.
sys.path.insert(0, str(Path(__file__).parent))

from modeling_ddpm_camels import load_pretrained, generate

# -- Config -------------------------------------------------------------------
MODEL_DIR = Path(__file__).parent
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# One row per sample; the columns are (Omega_m, sigma_8), as used in the
# subplot titles below.
_COSMOLOGIES = [
    [0.20, 0.95],
    [0.30, 0.80],
    [0.40, 0.70],
    [0.50, 0.65],
]


def _build_labels(config):
    """Return the raw label tensor for every row of ``_COSMOLOGIES``.

    When ``config["label_dim"]`` is larger than 2, each row is padded with
    ``config["label_mu"][2:]`` (the fiducial values of the remaining label
    dimensions) so that the tensor has one column per label.
    """
    raw_labels = torch.tensor(_COSMOLOGIES, dtype=torch.float32)
    if config["label_dim"] > 2:
        # Pad with fiducial astrophysics (label_mu values of those dims).
        pad = torch.tensor(config["label_mu"][2:], dtype=torch.float32).unsqueeze(0)
        # Expand to the actual number of rows instead of a hard-coded 4 so the
        # list above can be edited freely.
        raw_labels = torch.cat(
            [raw_labels, pad.expand(raw_labels.shape[0], -1)], dim=1
        )
    return raw_labels


def _save_panel(imgs, raw_labels, path):
    """Plot one image per column, titled with its labels, and save to *path*."""
    n_imgs = len(imgs)
    # squeeze=False keeps ``axes`` two-dimensional, so the loop below also
    # works when there is only a single image.
    fig, axes = plt.subplots(
        1, n_imgs, figsize=(3 * n_imgs, 3.5), squeeze=False
    )
    for ax, img, lbl in zip(axes[0], imgs, raw_labels):
        ax.imshow(img, cmap="magma", origin="lower", vmin=0, vmax=1)
        ax.set_title(
            f"$\\Omega_m={lbl[0].item():.2f}$, $\\sigma_8={lbl[1].item():.2f}$",
            fontsize=10,
        )
        ax.set_xticks([])
        ax.set_yticks([])
    # \u2014 is an em dash; written as an escape so the title survives
    # re-encoding of this file.
    fig.suptitle(
        "Conditional DDPM samples \u2014 CAMELS HI fields", fontweight="bold"
    )
    fig.tight_layout()
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)  # release the figure once it is on disk


def main():
    """Load the model, generate the samples and save the figure."""
    # -- Load model -----------------------------------------------------------
    print(f"Loading model from {MODEL_DIR} on {DEVICE} ...")
    model, config = load_pretrained(MODEL_DIR, device=DEVICE)
    print(f" Image size: {config['image_size']}")
    print(f" Label dim: {config['label_dim']} ({config['label_names']})")

    # -- Generate at 4 cosmologies --------------------------------------------
    raw_labels = _build_labels(config)

    print("\nGenerating samples ...")
    with torch.no_grad():
        out = generate(model, config, raw_labels, device=DEVICE, ddim_steps=50)

    # Map [-1, 1] -> [0, 1] for visualisation, then keep index 0 along axis 1
    # (presumably the channel axis -- confirm against `generate`).
    imgs = ((out.cpu().numpy() + 1) / 2).clip(0, 1)[:, 0]

    # -- Display --------------------------------------------------------------
    _save_panel(imgs, raw_labels, "inference_example.png")
    print("\nSaved -> inference_example.png")


if __name__ == "__main__":
    main()