""" modeling_ddpm_camels.py ======================= Self-contained loader for the conditional DDPM checkpoint hosted on the Hub. Users only need this file + diffusion_conditional.py + unet_conditional.py + config.json + model.safetensors to run inference. """ from __future__ import annotations import json from pathlib import Path from typing import Dict, Tuple, Union import torch from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel from unet_conditional import ConditionalUNet def build_model(config: Dict) -> ConditionalDiffusionModel: """Instantiate the architecture from a config dict.""" unet = ConditionalUNet( in_channels=int(config["in_channels"]), out_channels=int(config["out_channels"]), label_dim=int(config["label_dim"]), base_channels=int(config["base_channels"]), channel_multipliers=list(config["channel_multipliers"]), attention_levels=list(config["attention_levels"]), dropout=float(config["dropout"]), ) diffusion = GaussianDiffusion( timesteps=int(config["timesteps"]), beta_start=float(config["beta_start"]), beta_end=float(config["beta_end"]), schedule_type=str(config["schedule_type"]), ) return ConditionalDiffusionModel(unet, diffusion) def load_pretrained( model_dir: Union[str, Path], device: str = "cuda", ) -> Tuple[ConditionalDiffusionModel, Dict]: """ Load the model and its config from a directory containing: - config.json - model.safetensors (or pytorch_model.bin as fallback) """ model_dir = Path(model_dir) config = json.loads((model_dir / "config.json").read_text()) model = build_model(config).to(device) safetensors_path = model_dir / "model.safetensors" bin_path = model_dir / "pytorch_model.bin" if safetensors_path.exists(): from safetensors.torch import load_file state_dict = load_file(str(safetensors_path), device=device) elif bin_path.exists(): state_dict = torch.load(bin_path, map_location=device, weights_only=True) else: raise FileNotFoundError(f"No model weights in {model_dir}") # Allow partial-match loading for backward compatibility missing, unexpected = model.load_state_dict(state_dict, strict=False) if missing: print(f" Warning: missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}") if unexpected: print(f" Warning: unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") model.eval() for p in model.parameters(): p.requires_grad_(False) return model, config # Convenience for one-shot inference def generate( model: ConditionalDiffusionModel, config: Dict, raw_labels: torch.Tensor, # (B, label_dim) — un-normalised cosmological params n_samples: int = 1, use_ddim: bool = True, ddim_steps: int = None, device: str = "cuda", ) -> torch.Tensor: """ Generate samples conditioned on raw (un-normalised) parameter values. Returns: tensor of shape (B*n_samples, 1, H, W) in [-1, 1] model space. """ if ddim_steps is None: ddim_steps = config["ddim_steps_default"] label_mu = torch.tensor(config["label_mu"], dtype=torch.float32, device=device) label_std = torch.tensor(config["label_std"], dtype=torch.float32, device=device) raw_labels = raw_labels.to(device) norm_labels = (raw_labels - label_mu) / label_std norm_labels = norm_labels.repeat_interleave(n_samples, dim=0) H = W = config["image_size"] return model.sample( labels=norm_labels, channels=1, height=H, width=W, use_ddim=use_ddim, ddim_steps=ddim_steps, progress=False, device=device, )