Upload 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint) with full training/eval/posterior toolchain
1f3e7a2 verified | #!/usr/bin/env python | |
# ---------------------------------------------------------------------------
# inference_example.py
#
# Self-contained example that downloads a conditional-DDPM checkpoint from
# the Hugging Face Hub and generates one HI map.
#
# Works for **both** uploaded models -- the script picks which one to load
# from a CLI argument:
#
#     python inference_example.py --model 2param
#     python inference_example.py --model 6param
#     python inference_example.py --model 2param --repo myuser/my-fork
#     python inference_example.py --model 6param --device cuda --ddim-steps 50
#
# The script:
#   1. Downloads `model.pt`, `args.json`, and the bundled src/*.py files.
#   2. Imports `ConditionalUNet`, `GaussianDiffusion`, and
#      `ConditionalDiffusionModel` from the downloaded code (no need for a
#      separate pip-installed package).
#   3. Rebuilds the model from `args.json` so weights and architecture
#      cannot drift apart.
#   4. Samples one image with DDIM (or DDPM, with `--no-ddim`).
#   5. Saves a `.npy` of the raw [-1, 1] output and, if matplotlib is
#      installed, a PNG visualisation.
#
# This file is bundled inside each HF repo so users can grab a single script
# and immediately do inference.
# ---------------------------------------------------------------------------
| import argparse | |
| import json | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| # huggingface_hub is the only "extra" dependency; everything else (torch, | |
| # numpy) is already required to run the model. | |
| from huggingface_hub import hf_hub_download | |
# --------------------------------------------------------------------------
# Defaults -- adjust here or override via CLI flags
# --------------------------------------------------------------------------
# Maps each `--model` choice to the Hugging Face repo id it is downloaded
# from when `--repo` is not given.
DEFAULT_REPOS = {
    variant: f"collins909/DDPM-{variant}" for variant in ("2param", "6param")
}
# Every file that must be present in an uploaded repo: the checkpoint, the
# training arguments, and the bundled model source.  Each one is fetched
# individually (rather than with `snapshot_download`) so that a missing file
# can be attributed to a specific name.
REQUIRED_FILES = ["model.pt", "args.json"] + [
    f"src/{stem}.py"
    for stem in ("__init__", "unet_conditional", "diffusion_conditional")
]
def parse_args() -> argparse.Namespace:
    """Parse the command-line options that drive download and sampling.

    Arguments are read from ``sys.argv``.  ``--model`` is required; every
    other option has a default.

    Returns:
        The parsed namespace, with attributes ``model``, ``repo``,
        ``device``, ``ddim_steps``, ``no_ddim``, ``seed``, ``labels`` and
        ``output_dir``.
    """
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    labels_help = (
        "Conditioning vector (already z-scored). Length must match label_dim "
        "(2 or 6). If omitted, an all-zeros vector is used (i.e. the training-set mean)."
    )

    # One (flag, add_argument keyword arguments) entry per option, listed in
    # the order the options should show up in ``--help``.
    option_table = [
        (
            "--model",
            {
                "choices": sorted(DEFAULT_REPOS.keys()),
                "required": True,
                "help": "Which model to download. Picks the matching default HF repo.",
            },
        ),
        (
            "--repo",
            {
                "default": None,
                "help": "Override the HF repo id (default: see DEFAULT_REPOS in this file).",
            },
        ),
        (
            "--device",
            {
                "default": fallback_device,
                "help": "Torch device for sampling. Defaults to cuda if available else cpu.",
            },
        ),
        (
            "--ddim-steps",
            {
                "type": int,
                "default": 50,
                "help": "Number of DDIM steps (ignored when --no-ddim).",
            },
        ),
        (
            "--no-ddim",
            {
                "action": "store_true",
                "help": "Use the full DDPM sampler (slow, all 1500 steps) instead of DDIM.",
            },
        ),
        (
            "--seed",
            {
                "type": int,
                "default": 0,
                "help": "RNG seed for reproducible sampling.",
            },
        ),
        (
            "--labels",
            {
                "type": float,
                "nargs": "+",
                "default": None,
                "help": labels_help,
            },
        ),
        (
            "--output-dir",
            {
                "type": Path,
                "default": Path("inference_outputs"),
                "help": "Where to write the generated sample (.npy + .png).",
            },
        ),
    ]

    parser = argparse.ArgumentParser(
        description="Sample one HI map from the HF-hosted DDPM."
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    return parser.parse_args()
def download_repo(repo_id: str) -> Path:
    """Download every file in ``REQUIRED_FILES`` and return the local repo root.

    Caching is delegated to ``hf_hub_download``, which returns the local path
    of each fetched file.  The repo root is taken to be the directory that
    holds ``model.pt`` (the first entry of ``REQUIRED_FILES``); the remaining
    files are expected below that same directory, with the source modules in
    its ``src/`` subfolder.

    Args:
        repo_id: Hugging Face repo id, e.g. ``"user/repo"``.

    Returns:
        Path of the local directory that contains the downloaded files.

    Raises:
        Exception: whatever ``hf_hub_download`` raises is re-raised unchanged,
            after the name of the file that failed has been printed to stderr.
        FileNotFoundError: if, after downloading, some required file is not
            present under the returned root (i.e. the single-directory
            assumption above does not hold).
    """
    print(f"[inference] Downloading {len(REQUIRED_FILES)} files from {repo_id}")

    local_paths = []
    for filename in REQUIRED_FILES:
        try:
            local_paths.append(Path(hf_hub_download(repo_id, filename)))
        except Exception:
            # Say which file was the problem before the original exception
            # propagates; the exception type itself is left untouched.
            print(
                f"[inference] ERROR: could not download '{filename}' from {repo_id}",
                file=sys.stderr,
            )
            raise

    # The repo root in the local cache is the parent of `model.pt`.
    repo_root = local_paths[0].parent

    # Check the "everything lives under one root" assumption up front, so a
    # violation surfaces here instead of as a confusing ImportError or
    # missing-file error later on.
    missing = [name for name in REQUIRED_FILES if not (repo_root / name).is_file()]
    if missing:
        raise FileNotFoundError(
            f"Downloaded files are not all located under {repo_root}; missing: {missing}"
        )

    print(f"[inference] Cached at: {repo_root}")
    return repo_root
def build_model(args_json: dict):
    """Instantiate the conditional diffusion model described by ``args_json``.

    The model classes are imported inside the function, by their top-level
    module names (``unet_conditional`` and ``diffusion_conditional``), so the
    directory holding those modules must already be on ``sys.path`` when this
    is called.  Using the code that ships next to the checkpoint keeps the
    architecture in step with the weights.

    Args:
        args_json: Training arguments.  The keys read are ``label_dim``,
            ``base_channels``, ``channel_multipliers``, ``attention_levels``,
            ``dropout``, ``timesteps``, ``beta_start``, ``beta_end`` and
            ``schedule_type``.

    Returns:
        A ``ConditionalDiffusionModel`` wrapping a freshly built
        ``ConditionalUNet`` and ``GaussianDiffusion`` (no weights loaded).

    Raises:
        KeyError: if one of the keys listed above is absent.
    """
    from unet_conditional import ConditionalUNet
    from diffusion_conditional import ConditionalDiffusionModel, GaussianDiffusion

    # Single-channel input and output; the list-valued settings are turned
    # into tuples before being handed to the U-Net.
    unet_kwargs = {
        "in_channels": 1,
        "out_channels": 1,
        "label_dim": args_json["label_dim"],
        "base_channels": args_json["base_channels"],
        "channel_multipliers": tuple(args_json["channel_multipliers"]),
        "attention_levels": tuple(args_json["attention_levels"]),
        "dropout": args_json["dropout"],
    }
    denoiser = ConditionalUNet(**unet_kwargs)

    # The noise-schedule settings are forwarded under their own names.
    schedule_keys = ("timesteps", "beta_start", "beta_end", "schedule_type")
    schedule = GaussianDiffusion(**{key: args_json[key] for key in schedule_keys})

    return ConditionalDiffusionModel(denoiser, schedule)
def load_weights(model: torch.nn.Module, ckpt_path: Path, device: str) -> None:
    """Copy the ``model_state_dict`` stored in a checkpoint into ``model``.

    The checkpoint is expected to be a dict written by
    ``train_conditional.py`` (keys such as ``model_state_dict``,
    ``optimizer_state_dict``, ``ema_shadow``, ``epoch``, ``loss``).  Only
    ``model_state_dict`` is used; ``epoch`` and ``loss`` are echoed in the log
    line when present.

    Args:
        model: Module whose parameters are overwritten in place.
        ckpt_path: Location of the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.

    Raises:
        KeyError: if the checkpoint has no ``model_state_dict`` entry.
    """
    # NOTE: weights_only=False performs a full unpickle, which is needed
    # because the file also carries optimizer state, EMA shadows, scheduler,
    # etc.  That is only acceptable for checkpoints from a trusted source
    # (here: our own training run on the cluster).
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    has_weights = "model_state_dict" in checkpoint
    if not has_weights:
        raise KeyError(
            f"{ckpt_path} doesn't contain 'model_state_dict' -- got keys: {list(checkpoint)}"
        )

    model.load_state_dict(checkpoint["model_state_dict"])
    print(
        f"[inference] Loaded weights (epoch={checkpoint.get('epoch', '?')}, "
        f"loss={checkpoint.get('loss', '?')})"
    )
def save_outputs(sample: torch.Tensor, output_dir: Path, model_name: str) -> None:
    """Save a generated map as ``sample_<model_name>.npy`` plus a PNG preview.

    ``output_dir`` is created if necessary.  Singleton dimensions of
    ``sample`` are squeezed away and the values are written unmodified to the
    ``.npy`` file.  The PNG is produced only when matplotlib can be imported;
    otherwise a notice is printed and the function returns after the ``.npy``
    has been written.

    Args:
        sample: Generated tensor (the caller passes a single one-channel
            image, so squeezing leaves a 2-D array).
        output_dir: Directory that receives both files.
        model_name: Tag embedded in the file names and in the plot title.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # Drop the singleton axes and move the data to host memory as NumPy.
    image = sample.detach().squeeze().cpu().numpy()

    stem = f"sample_{model_name}"
    npy_path = output_dir / f"{stem}.npy"
    np.save(npy_path, image)
    lowest, highest = image.min(), image.max()
    print(
        f"[inference] Wrote {npy_path} shape={image.shape} "
        f"range=[{lowest:.3f}, {highest:.3f}]"
    )

    # matplotlib is deliberately optional: it is imported lazily so that the
    # hard dependency list stays short.
    try:
        from matplotlib import pyplot as plt
    except ImportError:
        print("[inference] matplotlib not installed -- skipping PNG preview.")
        return

    png_path = output_dir / f"{stem}.png"
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(image, cmap="inferno", origin="lower")
    ax.axis("off")
    ax.set_title(f"DDPM {model_name} sample")
    fig.tight_layout()
    fig.savefig(png_path, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"[inference] Wrote {png_path}")
def main() -> None:
    """Run the example end to end: download, rebuild, sample once, save.

    Steps, in order:
      1. Resolve the repo id (``--repo`` wins over the per-model default),
         download the files and put the downloaded ``src/`` directory at the
         front of ``sys.path``.
      2. Read ``args.json``, check that its ``label_dim`` agrees with
         ``--model``, build the model, load the weights, switch to eval mode.
      3. Build the conditioning tensor from ``--labels`` or fall back to
         zeros.
      4. Seed the RNG and draw one 256x256 single-channel sample.
      5. Write the sample to ``--output-dir``.

    Raises:
        ValueError: if ``label_dim`` in ``args.json`` contradicts ``--model``,
            or if ``--labels`` has the wrong number of entries.
    """
    cli = parse_args()
    repo_id = cli.repo or DEFAULT_REPOS[cli.model]

    # ------------------------------------------------------------------
    # 1. Pull files from the Hub and make src/ importable
    # ------------------------------------------------------------------
    repo_root = download_repo(repo_id)
    # Must happen before build_model(), which imports from this directory.
    sys.path.insert(0, str(repo_root / "src"))

    # ------------------------------------------------------------------
    # 2. Rebuild the model from args.json
    # ------------------------------------------------------------------
    with open(repo_root / "args.json") as handle:
        train_args = json.load(handle)

    label_dim = train_args["label_dim"]
    dim_implied_by_model = 2 if cli.model == "2param" else 6
    if label_dim != dim_implied_by_model:
        raise ValueError(
            f"args.json says label_dim={label_dim} but --model={cli.model}; "
            "did you point --repo at the wrong checkpoint?"
        )

    model = build_model(train_args).to(cli.device)
    load_weights(model, repo_root / "model.pt", cli.device)
    model.eval()

    # ------------------------------------------------------------------
    # 3. Build the conditioning vector
    # ------------------------------------------------------------------
    # Labels live in the normalised (z-scored) space, so an all-zeros vector
    # stands for the training-set mean.  To condition on physical
    # (Ωm, σ8, ...) values, z-score them with the train-split statistics
    # produced by `dataset_conditional.py` and pass the result via --labels.
    if cli.labels is not None:
        if len(cli.labels) != label_dim:
            raise ValueError(
                f"--labels has {len(cli.labels)} entries but model expects {label_dim}"
            )
        labels = torch.tensor([cli.labels], dtype=torch.float32, device=cli.device)
        print(f"[inference] Using user-supplied labels: {cli.labels}")
    else:
        labels = torch.zeros((1, label_dim), device=cli.device)
        print(f"[inference] Using zero (training-mean) conditioning, label_dim={label_dim}")

    # ------------------------------------------------------------------
    # 4. Sample
    # ------------------------------------------------------------------
    # Seed first: the result depends strongly on the initial Gaussian noise,
    # so a fixed seed is what makes a run reproducible.
    torch.manual_seed(cli.seed)
    if cli.device.startswith("cuda"):
        torch.cuda.manual_seed_all(cli.seed)

    use_ddim = not cli.no_ddim
    if use_ddim:
        sampler_desc = f"DDIM {cli.ddim_steps} steps"
    else:
        sampler_desc = f"DDPM {train_args['timesteps']} steps"
    print(f"[inference] Sampling 1 image with {sampler_desc} on {cli.device} ...")

    with torch.no_grad():
        sample = model.sample(
            labels=labels,
            channels=1,
            height=256,
            width=256,
            device=cli.device,
            progress=True,
            use_ddim=use_ddim,
            ddim_steps=cli.ddim_steps,
            eta=0.0,
        )

    # ------------------------------------------------------------------
    # 5. Save outputs
    # ------------------------------------------------------------------
    save_outputs(sample, cli.output_dir, cli.model)
    print("[inference] Done.")
# Script entry point: run the example only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()