| """ |
| modeling_ddpm_camels.py |
| ======================= |
| Self-contained loader for the conditional DDPM checkpoint hosted on the Hub. |
| Users only need this file + diffusion_conditional.py + unet_conditional.py |
| + config.json + model.safetensors to run inference. |
| """ |
| from __future__ import annotations |
| import json |
| from pathlib import Path |
| from typing import Dict, Tuple, Union |
|
|
| import torch |
|
|
| from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel |
| from unet_conditional import ConditionalUNet |
|
|
|
|
def build_model(config: Dict) -> ConditionalDiffusionModel:
    """Build a ConditionalDiffusionModel from a config mapping (no weights loaded).

    Args:
        config: Mapping that must provide the UNet keys ``in_channels``,
            ``out_channels``, ``label_dim``, ``base_channels``,
            ``channel_multipliers``, ``attention_levels`` and ``dropout``, plus
            the diffusion-schedule keys ``timesteps``, ``beta_start``,
            ``beta_end`` and ``schedule_type``. Each value is coerced with
            ``int`` / ``float`` / ``list`` / ``str`` before being passed on.

    Returns:
        A ``ConditionalDiffusionModel`` wrapping the newly constructed UNet and
        ``GaussianDiffusion``. No checkpoint weights are loaded here.

    Raises:
        KeyError: if any of the keys listed above is missing from ``config``.
    """

    def _select(spec):
        # Read each key from ``config`` in the listed order and coerce it.
        return {key: coerce(config[key]) for key, coerce in spec}

    denoiser = ConditionalUNet(**_select((
        ("in_channels", int),
        ("out_channels", int),
        ("label_dim", int),
        ("base_channels", int),
        ("channel_multipliers", list),
        ("attention_levels", list),
        ("dropout", float),
    )))
    schedule = GaussianDiffusion(**_select((
        ("timesteps", int),
        ("beta_start", float),
        ("beta_end", float),
        ("schedule_type", str),
    )))
    return ConditionalDiffusionModel(denoiser, schedule)
|
|
|
|
def _read_state_dict(model_dir: Path, device: str) -> Dict[str, torch.Tensor]:
    """Read the weight file in *model_dir*, preferring safetensors over ``pytorch_model.bin``."""
    safetensors_path = model_dir / "model.safetensors"
    if safetensors_path.exists():
        from safetensors.torch import load_file
        return load_file(str(safetensors_path), device=device)

    bin_path = model_dir / "pytorch_model.bin"
    if bin_path.exists():
        # weights_only=True restricts unpickling to tensors and plain containers.
        return torch.load(bin_path, map_location=device, weights_only=True)

    raise FileNotFoundError(f"No model weights in {model_dir}")


def load_pretrained(
    model_dir: Union[str, Path],
    device: str = "cuda",
) -> Tuple[ConditionalDiffusionModel, Dict]:
    """
    Load the model and its config from a directory containing:
      - config.json
      - model.safetensors (or pytorch_model.bin as fallback)

    Args:
        model_dir: Checkpoint directory.
        device: Torch device string; both the model and the weights are placed on it.

    Returns:
        ``(model, config)``: the model in eval mode with gradients disabled on
        every parameter, and the parsed ``config.json`` dict.

    Raises:
        FileNotFoundError: if ``config.json`` is absent, or neither weight file exists.
        RuntimeError: if the checkpoint contains none of the model's state-dict keys.
    """
    model_dir = Path(model_dir)
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))

    model = build_model(config).to(device)
    state_dict = _read_state_dict(model_dir, device)

    # strict=False tolerates partial key mismatches (reported below) instead of
    # failing outright ...
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    # ... but if nothing matched at all, the model is effectively untrained and
    # returning it would silently produce garbage samples.
    n_expected = len(model.state_dict())
    if missing and len(missing) == n_expected:
        raise RuntimeError(
            f"None of the model's {n_expected} state-dict keys were found in the "
            f"checkpoint in {model_dir}; first unexpected keys: {unexpected[:5]}"
        )
    if missing:
        print(f" Warning: missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if unexpected:
        print(f" Warning: unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    model.eval()
    model.requires_grad_(False)

    return model, config
|
|
|
|
| |
def generate(
    model: ConditionalDiffusionModel,
    config: Dict,
    raw_labels: torch.Tensor,
    n_samples: int = 1,
    use_ddim: bool = True,
    ddim_steps: Union[int, None] = None,
    device: str = "cuda",
) -> torch.Tensor:
    """
    Generate samples conditioned on raw (un-normalised) parameter values.

    Args:
        model: Diffusion model exposing ``sample(...)``.
        config: Config dict providing ``label_mu``, ``label_std``, ``image_size``
            and ``ddim_steps_default``.
        raw_labels: Raw parameter values, shape ``(B, label_dim)``. Lists and
            arrays are also accepted. A single 1-D vector with the same shape
            as ``label_mu`` is treated as a batch of one.
        n_samples: Number of samples drawn per label row (must be >= 1).
        use_ddim: Forwarded to ``model.sample``.
        ddim_steps: Sampling steps; defaults to ``config["ddim_steps_default"]``.
        device: Torch device string used for labels and sampling.

    Returns: tensor of shape (B*n_samples, 1, H, W) in [-1, 1] model space.

    Raises:
        ValueError: if ``n_samples`` < 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if ddim_steps is None:
        ddim_steps = config["ddim_steps_default"]

    label_mu = torch.tensor(config["label_mu"], dtype=torch.float32, device=device)
    label_std = torch.tensor(config["label_std"], dtype=torch.float32, device=device)

    # Cast to float32 so that float64 (e.g. numpy-derived) or integer labels
    # don't reach the model with a different dtype than the float32 stats.
    labels = torch.as_tensor(raw_labels, dtype=torch.float32, device=device)
    if labels.dim() == 1 and label_mu.numel() > 1 and labels.shape == label_mu.shape:
        # One parameter vector -> batch of one; otherwise repeat_interleave
        # below would interleave its individual elements.
        labels = labels.unsqueeze(0)

    norm_labels = (labels - label_mu) / label_std
    # Each label row is repeated n_samples times, consecutively.
    norm_labels = norm_labels.repeat_interleave(n_samples, dim=0)

    H = W = int(config["image_size"])
    # Inference only: avoid building an autograd graph across all sampling steps.
    with torch.no_grad():
        return model.sample(
            labels=norm_labels, channels=1, height=H, width=W,
            use_ddim=use_ddim, ddim_steps=ddim_steps,
            progress=False, device=device,
        )
|
|