#!/usr/bin/env python
# ---------------------------------------------------------------------------
# inference_example.py
#
# Self-contained example that downloads a conditional-DDPM checkpoint from
# the Hugging Face Hub and generates one HI map.
#
# Works for **both** uploaded models -- the script picks which one to load
# from a CLI argument:
#
#   python inference_example.py --model 2param
#   python inference_example.py --model 6param
#   python inference_example.py --model 2param --repo myuser/my-fork
#   python inference_example.py --model 6param --device cuda --ddim-steps 50
#
# The script:
#   1. Downloads `model.pt`, `args.json`, and the bundled src/*.py files.
#   2. Imports `ConditionalUNet`, `GaussianDiffusion` and
#      `ConditionalDiffusionModel` from the downloaded code (no need for a
#      separate pip-installed package).
#   3. Rebuilds the model from `args.json` so weights and architecture
#      cannot drift apart.
#   4. Samples one image with DDIM (or DDPM, with `--no-ddim`).
#   5. Saves a `.npy` of the raw [-1, 1] output and a PNG visualisation.
#
# This file is bundled inside each HF repo so users can grab a single script
# and immediately do inference.
# ---------------------------------------------------------------------------

import argparse
import json
import sys
from pathlib import Path

import numpy as np
import torch

# huggingface_hub is the only "extra" dependency; everything else (torch,
# numpy) is already required to run the model.
from huggingface_hub import hf_hub_download

# --------------------------------------------------------------------------
# Defaults -- adjust here or override via CLI flags
# --------------------------------------------------------------------------
# Maps each `--model` choice to the HF repo it downloads from by default.
DEFAULT_REPOS = {
    "2param": "collins909/DDPM-2param",
    "6param": "collins909/DDPM-6param",
}

# All files we expect to find in every uploaded repo. We download each one
# explicitly (rather than `snapshot_download`) so we can give a clear error
# message if anything is missing.
REQUIRED_FILES = [
    "model.pt",
    "args.json",
    "src/__init__.py",
    "src/unet_conditional.py",
    "src/diffusion_conditional.py",
]

# label_dim that each `--model` choice must have. Checked against args.json so
# a `--repo` override pointing at the wrong checkpoint family fails loudly.
# Keys must stay in sync with DEFAULT_REPOS.
EXPECTED_LABEL_DIM = {
    "2param": 2,
    "6param": 6,
}


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is what the script does when run directly.

    Returns:
        The parsed options namespace.
    """
    p = argparse.ArgumentParser(description="Sample one HI map from the HF-hosted DDPM.")
    p.add_argument(
        "--model",
        choices=sorted(DEFAULT_REPOS.keys()),
        required=True,
        help="Which model to download. Picks the matching default HF repo.",
    )
    p.add_argument(
        "--repo",
        default=None,
        help="Override the HF repo id (default: see DEFAULT_REPOS in this file).",
    )
    p.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Torch device for sampling. Defaults to cuda if available else cpu.",
    )
    p.add_argument(
        "--ddim-steps",
        type=int,
        default=50,
        help="Number of DDIM steps (ignored when --no-ddim).",
    )
    p.add_argument(
        "--no-ddim",
        action="store_true",
        help="Use the full DDPM sampler (slow, all 1500 steps) instead of DDIM.",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=0,
        help="RNG seed for reproducible sampling.",
    )
    p.add_argument(
        "--labels",
        type=float,
        nargs="+",
        default=None,
        help=(
            "Conditioning vector (already z-scored). Length must match label_dim "
            "(2 or 6). If omitted, an all-zeros vector is used (i.e. the training-set mean)."
        ),
    )
    p.add_argument(
        "--output-dir",
        type=Path,
        default=Path("inference_outputs"),
        help="Where to write the generated sample (.npy + .png).",
    )
    return p.parse_args(argv)


def download_repo(repo_id: str) -> Path:
    """Download every file in REQUIRED_FILES from `repo_id`; return the local root.

    `hf_hub_download` manages caching under `~/.cache/huggingface/hub/` and
    returns the local path of each file. The repo root is taken to be the
    parent directory of `model.pt`. The other files (including the `src/`
    subfolder) are then checked to be present relative to that root.

    Raises:
        RuntimeError: if any file cannot be downloaded. The message names the
            file and repo, and the underlying Hub/network error is chained
            as `__cause__`.
        FileNotFoundError: if the downloaded files are not all reachable from
            the same root directory.
    """
    print(f"[inference] Downloading {len(REQUIRED_FILES)} files from {repo_id}")
    local_paths = []
    for filename in REQUIRED_FILES:
        try:
            local_paths.append(Path(hf_hub_download(repo_id, filename)))
        except Exception as err:  # boundary: add context, keep the cause chained
            raise RuntimeError(
                f"Could not download {filename!r} from HF repo {repo_id!r}: {err}"
            ) from err

    # The repo root in the local cache is the parent of `model.pt`.
    repo_root = local_paths[0].parent

    # Everything downstream addresses files relative to repo_root (sys.path,
    # args.json, model.pt). Check that here rather than failing later with a
    # confusing ImportError / FileNotFoundError.
    missing = [f for f in REQUIRED_FILES if not (repo_root / f).is_file()]
    if missing:
        raise FileNotFoundError(
            f"Files {missing} were downloaded but are not under {repo_root}; "
            "they may come from different repo revisions -- try again."
        )

    print(f"[inference] Cached at: {repo_root}")
    return repo_root


def build_model(args_json: dict):
    """Re-create `ConditionalDiffusionModel` from the training args dict.

    Importing the model classes from the just-downloaded `src/` package is
    the safest way to avoid drift between weights and architecture: if the
    repo ships a particular version of the U-Net code, that's the version we
    use. The downloaded `src/` directory must already be on `sys.path`
    (`main` arranges this).

    Args:
        args_json: Parsed `args.json`. Must contain label_dim, base_channels,
            channel_multipliers, attention_levels, dropout, timesteps,
            beta_start, beta_end and schedule_type.

    Returns:
        An untrained `ConditionalDiffusionModel` wrapping the U-Net and the
        diffusion schedule.
    """
    from unet_conditional import ConditionalUNet
    from diffusion_conditional import ConditionalDiffusionModel, GaussianDiffusion

    unet = ConditionalUNet(
        in_channels=1,
        out_channels=1,
        label_dim=args_json["label_dim"],
        base_channels=args_json["base_channels"],
        channel_multipliers=tuple(args_json["channel_multipliers"]),
        attention_levels=tuple(args_json["attention_levels"]),
        dropout=args_json["dropout"],
    )
    diffusion = GaussianDiffusion(
        timesteps=args_json["timesteps"],
        beta_start=args_json["beta_start"],
        beta_end=args_json["beta_end"],
        schedule_type=args_json["schedule_type"],
    )
    return ConditionalDiffusionModel(unet, diffusion)


def load_weights(model: torch.nn.Module, ckpt_path: Path, device: str) -> None:
    """Load the state-dict produced by `train_conditional.py` into `model`.

    The checkpoint is a dict with keys: model_state_dict,
    optimizer_state_dict, ema_shadow, epoch, loss, ...
    We only need `model_state_dict` for inference. `epoch` and `loss` are
    printed if present.

    Raises:
        KeyError: if the checkpoint is not a dict containing
            'model_state_dict'.
    """
    # weights_only=False because the checkpoint also serialises optimizer
    # state, EMA shadows, scheduler, etc.
    # SECURITY NOTE(review): weights_only=False runs the full unpickler, which
    # can execute arbitrary code embedded in the file. That is acceptable for
    # the default repos (our own training runs on the cluster), but `--repo`
    # lets a user point this at ANY Hub repo -- only load checkpoints from
    # sources you trust.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    if not isinstance(ckpt, dict):
        raise KeyError(
            f"{ckpt_path} doesn't contain 'model_state_dict' -- "
            f"checkpoint is a {type(ckpt).__name__}, not a dict"
        )
    if "model_state_dict" not in ckpt:
        raise KeyError(
            f"{ckpt_path} doesn't contain 'model_state_dict' -- got keys: {list(ckpt)}"
        )
    model.load_state_dict(ckpt["model_state_dict"])

    epoch = ckpt.get("epoch", "?")
    loss = ckpt.get("loss", "?")
    print(f"[inference] Loaded weights (epoch={epoch}, loss={loss})")


def save_outputs(sample: torch.Tensor, output_dir: Path, model_name: str) -> None:
    """Write the generated map as raw .npy and, if matplotlib exists, a PNG preview.

    Files are named `sample_<model_name>.npy` / `sample_<model_name>.png`
    inside `output_dir`, which is created if it does not exist.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    # `sample` is shape (1, 1, 256, 256) in [-1, 1]; squeeze and bring to CPU.
    arr = sample.squeeze().detach().cpu().numpy()
    npy_path = output_dir / f"sample_{model_name}.npy"
    np.save(npy_path, arr)
    print(f"[inference] Wrote {npy_path} shape={arr.shape} range=[{arr.min():.3f}, {arr.max():.3f}]")

    # Optional PNG -- only if matplotlib is around. Keeps the hard dependency
    # list short (matplotlib isn't strictly needed for the science workflow).
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("[inference] matplotlib not installed -- skipping PNG preview.")
        return

    png_path = output_dir / f"sample_{model_name}.png"
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.imshow(arr, cmap="inferno", origin="lower")
        ax.axis("off")
        ax.set_title(f"DDPM {model_name} sample")
        fig.tight_layout()
        fig.savefig(png_path, dpi=120, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)
    print(f"[inference] Wrote {png_path}")


def _make_labels(values, label_dim: int, device: str) -> torch.Tensor:
    """Build the (1, label_dim) conditioning tensor from --labels (zeros if None)."""
    # Zeros are the training-set mean in the normalised label space. To
    # condition on physical (Ωm, σ8, ...) values, z-score them using the
    # train-split statistics produced by `dataset_conditional.py` and pass
    # the result via --labels.
    if values is None:
        labels = torch.zeros((1, label_dim), device=device)
        print(f"[inference] Using zero (training-mean) conditioning, label_dim={label_dim}")
        return labels
    if len(values) != label_dim:
        raise ValueError(
            f"--labels has {len(values)} entries but model expects {label_dim}"
        )
    labels = torch.tensor([values], dtype=torch.float32, device=device)
    print(f"[inference] Using user-supplied labels: {values}")
    return labels


def main() -> None:
    """CLI entry point: download, rebuild, load, sample once, save."""
    args = parse_args()
    repo_id = args.repo or DEFAULT_REPOS[args.model]

    # ----------------------------------------------------------------------
    # 1. Pull files from the Hub and make src/ importable
    # ----------------------------------------------------------------------
    repo_root = download_repo(repo_id)
    sys.path.insert(0, str(repo_root / "src"))

    # ----------------------------------------------------------------------
    # 2. Validate the CLI against args.json, then rebuild the model.
    #    Validation comes first so bad flags fail before the (slow) model
    #    build and weight load.
    # ----------------------------------------------------------------------
    with open(repo_root / "args.json", encoding="utf-8") as f:
        train_args = json.load(f)

    expected_dim = train_args["label_dim"]
    if expected_dim != EXPECTED_LABEL_DIM[args.model]:
        raise ValueError(
            f"args.json says label_dim={expected_dim} but --model={args.model}; "
            "did you point --repo at the wrong checkpoint?"
        )

    use_ddim = not args.no_ddim
    total_steps = train_args["timesteps"]
    if use_ddim and not 1 <= args.ddim_steps <= total_steps:
        raise ValueError(
            f"--ddim-steps must be between 1 and {total_steps} "
            f"(the training timesteps), got {args.ddim_steps}"
        )

    # 3. Build the conditioning vector (also validates --labels length).
    labels = _make_labels(args.labels, expected_dim, args.device)

    model = build_model(train_args).to(args.device)
    load_weights(model, repo_root / "model.pt", args.device)
    model.eval()

    # ----------------------------------------------------------------------
    # 4. Sample
    # ----------------------------------------------------------------------
    # Fix the RNG seed for reproducibility -- diffusion sampling is very
    # sensitive to the initial Gaussian noise.
    torch.manual_seed(args.seed)
    if args.device.startswith("cuda"):
        torch.cuda.manual_seed_all(args.seed)

    sampler_desc = (
        f"DDIM {args.ddim_steps} steps" if use_ddim else f"DDPM {total_steps} steps"
    )
    print(f"[inference] Sampling 1 image with {sampler_desc} on {args.device} ...")
    with torch.no_grad():
        sample = model.sample(
            labels=labels,
            channels=1,
            height=256,
            width=256,
            device=args.device,
            progress=True,
            use_ddim=use_ddim,
            ddim_steps=args.ddim_steps,
            eta=0.0,
        )

    # ----------------------------------------------------------------------
    # 5. Save outputs
    # ----------------------------------------------------------------------
    save_outputs(sample, args.output_dir, args.model)
    print("[inference] Done.")


if __name__ == "__main__":
    main()