# DDPM-6param / inference_example.py
# Uploaded by collins909: 6-parameter conditional DDPM (HI emulation, CAMELS LH
# params_6, best checkpoint) with full training/eval/posterior toolchain.
# Commit 1f3e7a2 (verified).
#!/usr/bin/env python
# ---------------------------------------------------------------------------
# inference_example.py
#
# Self-contained example that downloads a conditional-DDPM checkpoint from
# the Hugging Face Hub and generates one HI map.
#
# Works for **both** uploaded models -- the script picks which one to load
# from a CLI argument:
#
# python inference_example.py --model 2param
# python inference_example.py --model 6param
# python inference_example.py --model 2param --repo myuser/my-fork
# python inference_example.py --model 6param --device cuda --ddim-steps 50
#
# The script:
# 1. Downloads `model.pt`, `args.json`, and the bundled src/*.py files.
# 2. Imports `ConditionalUNet`, `GaussianDiffusion`, and
#    `ConditionalDiffusionModel` from the downloaded code (no need for a
#    separate pip-installed package).
# 3. Rebuilds the model from `args.json` so weights and architecture
# cannot drift apart.
# 4. Samples one image with DDIM (or DDPM, with `--no-ddim`).
# 5. Saves a `.npy` of the raw [-1, 1] output and a PNG visualisation.
#
# This file is bundled inside each HF repo so users can grab a single script
# and immediately do inference.
# ---------------------------------------------------------------------------
import argparse
import json
import sys
from pathlib import Path
import numpy as np
import torch
# huggingface_hub is the only "extra" dependency; everything else (torch,
# numpy) is already required to run the model.
from huggingface_hub import hf_hub_download
# --------------------------------------------------------------------------
# Defaults -- adjust here or override via CLI flags
# --------------------------------------------------------------------------

# Maps each `--model` choice to the Hub repo used when `--repo` is not given.
DEFAULT_REPOS = {
    "2param": "collins909/DDPM-2param",
    "6param": "collins909/DDPM-6param",
}

# Checkpoint and the training hyper-parameters used to rebuild the network.
_CHECKPOINT_FILES = ["model.pt", "args.json"]

# Model source code shipped alongside the weights.
_SOURCE_FILES = [
    "src/__init__.py",
    "src/unet_conditional.py",
    "src/diffusion_conditional.py",
]

# Every file expected in each uploaded repo. Each one is fetched individually
# with `hf_hub_download` rather than via `snapshot_download`.
REQUIRED_FILES = _CHECKPOINT_FILES + _SOURCE_FILES
def _positive_int(text: str) -> int:
    """Parse *text* as an integer >= 1 (argparse ``type=`` callback).

    Raises `argparse.ArgumentTypeError` on bad input so argparse prints a
    normal usage error instead of a traceback.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the inference example.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what running the script does; an
            explicit list allows calling this from other code.

    Returns:
        Namespace with the attributes ``model``, ``repo``, ``device``,
        ``ddim_steps``, ``no_ddim``, ``seed``, ``labels`` and ``output_dir``.
    """
    p = argparse.ArgumentParser(description="Sample one HI map from the HF-hosted DDPM.")
    p.add_argument(
        "--model",
        choices=sorted(DEFAULT_REPOS.keys()),
        required=True,
        help="Which model to download. Picks the matching default HF repo.",
    )
    p.add_argument(
        "--repo",
        default=None,
        help="Override the HF repo id (default: see DEFAULT_REPOS in this file).",
    )
    p.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Torch device for sampling. Defaults to cuda if available else cpu.",
    )
    p.add_argument(
        "--ddim-steps",
        # Reject zero/negative step counts at parse time, before anything is
        # downloaded or sampled.
        type=_positive_int,
        default=50,
        help="Number of DDIM steps (ignored when --no-ddim).",
    )
    p.add_argument(
        "--no-ddim",
        action="store_true",
        help="Use the full DDPM sampler (slow, all 1500 steps) instead of DDIM.",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=0,
        help="RNG seed for reproducible sampling.",
    )
    p.add_argument(
        "--labels",
        type=float,
        nargs="+",
        default=None,
        help=(
            "Conditioning vector (already z-scored). Length must match label_dim "
            "(2 or 6). If omitted, an all-zeros vector is used (i.e. the training-set mean)."
        ),
    )
    p.add_argument(
        "--output-dir",
        type=Path,
        default=Path("inference_outputs"),
        help="Where to write the generated sample (.npy + .png).",
    )
    return p.parse_args(argv)
def download_repo(repo_id: str) -> Path:
    """Download every file in `REQUIRED_FILES` from `repo_id`.

    Caching is left to `hf_hub_download`, which returns the local path of
    each file. The repo root is recovered from those paths by stripping each
    file's repo-relative components, so the result does not depend on the
    order of `REQUIRED_FILES`.

    Args:
        repo_id: Hugging Face Hub repo id, e.g. ``"user/name"``.

    Returns:
        Local directory that contains `model.pt`, `args.json` and `src/`.

    Raises:
        RuntimeError: If the downloaded files do not share a single root
            directory (or `REQUIRED_FILES` is empty).
        Exception: Any error from `hf_hub_download` is re-raised unchanged
            after the name of the offending file is printed to stderr.
    """
    print(f"[inference] Downloading {len(REQUIRED_FILES)} files from {repo_id}")
    roots = set()
    for rel_name in REQUIRED_FILES:
        try:
            local_path = Path(hf_hub_download(repo_id, rel_name))
        except Exception:
            # Name the file that failed, then let the original error propagate
            # so callers still see the exception type raised by the Hub client.
            print(
                f"[inference] ERROR: could not fetch '{rel_name}' from {repo_id}",
                file=sys.stderr,
            )
            raise
        # "model.pt" sits directly in the root (depth 1); "src/x.py" is one
        # directory deeper (depth 2), so walk up by the relative depth.
        depth = len(Path(rel_name).parts)
        roots.add(local_path.parents[depth - 1])

    if len(roots) != 1:
        # Later code opens files relative to one root; fail here with an
        # explicit message rather than with a confusing missing-file error.
        raise RuntimeError(
            f"Files from {repo_id} did not resolve to a single local directory: "
            f"{sorted(str(r) for r in roots)}"
        )
    repo_root = roots.pop()
    print(f"[inference] Cached at: {repo_root}")
    return repo_root
def build_model(args_json: dict):
    """Construct a `ConditionalDiffusionModel` from the training-args dict.

    The U-Net and the diffusion schedule are configured only from values in
    `args_json` (plus the fixed single input/output channel), so the
    architecture follows whatever hyper-parameters were saved with the
    checkpoint.

    Args:
        args_json: Parsed `args.json`. Must provide `label_dim`,
            `base_channels`, `channel_multipliers`, `attention_levels`,
            `dropout`, `timesteps`, `beta_start`, `beta_end` and
            `schedule_type`; a missing key raises `KeyError`.

    Returns:
        The assembled (untrained) `ConditionalDiffusionModel`.
    """
    # Imported here, not at module top: these modules come from the
    # downloaded repo, whose `src/` directory the caller adds to `sys.path`.
    from unet_conditional import ConditionalUNet
    from diffusion_conditional import ConditionalDiffusionModel, GaussianDiffusion

    unet_kwargs = {
        "in_channels": 1,
        "out_channels": 1,
        "label_dim": args_json["label_dim"],
        "base_channels": args_json["base_channels"],
        # Stored as JSON lists; the constructor is handed tuples.
        "channel_multipliers": tuple(args_json["channel_multipliers"]),
        "attention_levels": tuple(args_json["attention_levels"]),
        "dropout": args_json["dropout"],
    }
    denoiser = ConditionalUNet(**unet_kwargs)

    schedule_keys = ("timesteps", "beta_start", "beta_end", "schedule_type")
    schedule = GaussianDiffusion(**{key: args_json[key] for key in schedule_keys})

    return ConditionalDiffusionModel(denoiser, schedule)
def load_weights(model: torch.nn.Module, ckpt_path: Path, device: str) -> None:
    """Copy the `model_state_dict` stored in a checkpoint file into `model`.

    Only the `model_state_dict` entry is used; `epoch` and `loss` are read
    just for the log line and default to "?" when absent.

    Args:
        model: Module whose parameters are overwritten in place.
        ckpt_path: Path of the checkpoint file to read.
        device: Passed to `torch.load` as `map_location`.

    Raises:
        KeyError: If the checkpoint has no `model_state_dict` entry.
    """
    # weights_only=False performs a full unpickle, which can execute code
    # embedded in the file.
    # NOTE(review): that is only acceptable for checkpoints from a trusted
    # repo. The script's --repo option can point at any Hub repo, so do not
    # use it with repos you do not trust.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    if "model_state_dict" not in checkpoint:
        found_keys = list(checkpoint)
        raise KeyError(
            f"{ckpt_path} doesn't contain 'model_state_dict' -- got keys: {found_keys}"
        )

    model.load_state_dict(checkpoint["model_state_dict"])
    print(
        "[inference] Loaded weights "
        f"(epoch={checkpoint.get('epoch', '?')}, loss={checkpoint.get('loss', '?')})"
    )
def save_outputs(sample: torch.Tensor, output_dir: Path, model_name: str) -> None:
    """Save `sample` as `sample_<model_name>.npy` and, if possible, a PNG.

    Args:
        sample: Generated tensor; all size-1 dimensions are squeezed away
            before saving (the caller passes a single one-channel image).
        output_dir: Destination directory, created if it does not exist.
        model_name: Inserted into the file names and the plot title.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Drop singleton dims, detach from autograd and move to host memory.
    image = sample.squeeze().detach().cpu().numpy()

    npy_path = output_dir / f"sample_{model_name}.npy"
    np.save(npy_path, image)
    lo, hi = image.min(), image.max()
    print(f"[inference] Wrote {npy_path} shape={image.shape} range=[{lo:.3f}, {hi:.3f}]")

    # The PNG preview is best-effort: matplotlib is an optional dependency,
    # so its absence only skips the picture.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("[inference] matplotlib not installed -- skipping PNG preview.")
        return

    png_path = output_dir / f"sample_{model_name}.png"
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(image, cmap="inferno", origin="lower")
    ax.axis("off")
    ax.set_title(f"DDPM {model_name} sample")
    fig.tight_layout()
    fig.savefig(png_path, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"[inference] Wrote {png_path}")
def main() -> None:
    """Run the example end to end: download, rebuild, sample, save.

    Steps:
        1. Download the repo files and put its `src/` directory on `sys.path`.
        2. Read `args.json`, check `label_dim` against `--model`, build the
           model and load the checkpoint weights.
        3. Build the conditioning vector (zeros unless `--labels` is given).
        4. Seed the RNGs and draw one sample with `model.sample`.
        5. Write the sample to `--output-dir`.

    Raises:
        ValueError: If `label_dim` in `args.json` does not match `--model`,
            or `--labels` has the wrong number of entries.
    """
    args = parse_args()
    repo_id = args.repo or DEFAULT_REPOS[args.model]

    # ----------------------------------------------------------------------
    # 1. Pull files from the Hub and make src/ importable
    # ----------------------------------------------------------------------
    repo_root = download_repo(repo_id)
    sys.path.insert(0, str(repo_root / "src"))

    # ----------------------------------------------------------------------
    # 2. Rebuild the model from args.json
    # ----------------------------------------------------------------------
    # JSON is UTF-8 by specification; do not rely on the platform's default
    # text encoding.
    with open(repo_root / "args.json", encoding="utf-8") as f:
        train_args = json.load(f)

    expected_dim = train_args["label_dim"]
    if expected_dim != (2 if args.model == "2param" else 6):
        raise ValueError(
            f"args.json says label_dim={expected_dim} but --model={args.model}; "
            "did you point --repo at the wrong checkpoint?"
        )

    model = build_model(train_args).to(args.device)
    load_weights(model, repo_root / "model.pt", args.device)
    model.eval()

    # ----------------------------------------------------------------------
    # 3. Build the conditioning vector
    # ----------------------------------------------------------------------
    # By default we feed zeros, i.e. the training-set mean in the normalised
    # space. To condition on physical (Ωm, σ8, ...) values, z-score them
    # using the train-split statistics produced by `dataset_conditional.py`
    # and pass the result via --labels.
    if args.labels is None:
        labels = torch.zeros((1, expected_dim), device=args.device)
        print(f"[inference] Using zero (training-mean) conditioning, label_dim={expected_dim}")
    else:
        if len(args.labels) != expected_dim:
            raise ValueError(
                f"--labels has {len(args.labels)} entries but model expects {expected_dim}"
            )
        # Wrap in a list to get a batch dimension of 1: shape (1, label_dim).
        labels = torch.tensor([args.labels], dtype=torch.float32, device=args.device)
        print(f"[inference] Using user-supplied labels: {args.labels}")

    # ----------------------------------------------------------------------
    # 4. Sample
    # ----------------------------------------------------------------------
    # Fix the RNG seed for reproducibility -- diffusion sampling is very
    # sensitive to the initial Gaussian noise.
    torch.manual_seed(args.seed)
    if args.device.startswith("cuda"):
        torch.cuda.manual_seed_all(args.seed)

    use_ddim = not args.no_ddim
    sampler_desc = (
        f"DDIM {args.ddim_steps} steps"
        if use_ddim
        else f"DDPM {train_args['timesteps']} steps"
    )
    print(f"[inference] Sampling 1 image with {sampler_desc} on {args.device} ...")

    # The 256x256 single-channel output size is hard-coded here, not read
    # from args.json.
    with torch.no_grad():
        sample = model.sample(
            labels=labels,
            channels=1,
            height=256,
            width=256,
            device=args.device,
            progress=True,
            use_ddim=use_ddim,
            ddim_steps=args.ddim_steps,
            eta=0.0,
        )

    # ----------------------------------------------------------------------
    # 5. Save outputs
    # ----------------------------------------------------------------------
    save_outputs(sample, args.output_dir, args.model)
    print("[inference] Done.")


# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()