# Source: my-diffusion-model / sample_modeling_ddpm_camels.py
# Uploaded by collins909 ("Upload 4 files"), commit e00756d (verified)
"""
modeling_ddpm_camels.py
=======================
Self-contained loader for the conditional DDPM checkpoint hosted on the Hub.
Users only need this file + diffusion_conditional.py + unet_conditional.py
+ config.json + model.safetensors to run inference.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Dict, Tuple, Union
import torch
from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel
from unet_conditional import ConditionalUNet
def build_model(config: Dict) -> ConditionalDiffusionModel:
    """Assemble the conditional diffusion model described by ``config``.

    Every value is coerced to the type its constructor argument is given
    here (``int``, ``float``, ``list`` or ``str``) before being passed on.

    Args:
        config: Mapping that provides ``in_channels``, ``out_channels``,
            ``label_dim``, ``base_channels``, ``channel_multipliers``,
            ``attention_levels`` and ``dropout`` for the U-Net, plus
            ``timesteps``, ``beta_start``, ``beta_end`` and ``schedule_type``
            for the diffusion process.

    Returns:
        A ``ConditionalDiffusionModel`` wrapping the freshly constructed
        ``ConditionalUNet`` and ``GaussianDiffusion`` (no weights are loaded).

    Raises:
        KeyError: If a required key is absent from ``config``.
    """
    # U-Net settings are read (and the network built) first, so a bad U-Net
    # key is reported before any diffusion key is touched.
    unet_kwargs = {
        "in_channels": int(config["in_channels"]),
        "out_channels": int(config["out_channels"]),
        "label_dim": int(config["label_dim"]),
        "base_channels": int(config["base_channels"]),
        "channel_multipliers": list(config["channel_multipliers"]),
        "attention_levels": list(config["attention_levels"]),
        "dropout": float(config["dropout"]),
    }
    denoiser = ConditionalUNet(**unet_kwargs)

    process_kwargs = {
        "timesteps": int(config["timesteps"]),
        "beta_start": float(config["beta_start"]),
        "beta_end": float(config["beta_end"]),
        "schedule_type": str(config["schedule_type"]),
    }
    process = GaussianDiffusion(**process_kwargs)

    return ConditionalDiffusionModel(denoiser, process)
def load_pretrained(
    model_dir: Union[str, Path],
    device: str = "cuda",
) -> Tuple[ConditionalDiffusionModel, Dict]:
    """Load the model and its config from a local directory.

    The directory must contain:
      - config.json
      - model.safetensors (or pytorch_model.bin as fallback)

    Args:
        model_dir: Directory holding the config and the weight file.
        device: Device the model is moved to and the weights are loaded on.

    Returns:
        ``(model, config)``: the model in eval mode with ``requires_grad``
        disabled on every parameter, and the parsed config dict.

    Raises:
        FileNotFoundError: If ``config.json`` is missing, or if neither
            weight file exists in ``model_dir``.
    """
    model_dir = Path(model_dir)
    # JSON is UTF-8 by specification; do not rely on the platform's locale
    # encoding (which is not UTF-8 on every system).
    config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))

    safetensors_path = model_dir / "model.safetensors"
    bin_path = model_dir / "pytorch_model.bin"

    # Fail fast: verify that weights exist before building the network and
    # allocating it on the target device.
    use_safetensors = safetensors_path.exists()
    if not use_safetensors and not bin_path.exists():
        raise FileNotFoundError(f"No model weights in {model_dir}")

    model = build_model(config).to(device)

    if use_safetensors:
        # Imported lazily so the .bin fallback works without safetensors.
        from safetensors.torch import load_file

        state_dict = load_file(str(safetensors_path), device=device)
    else:
        state_dict = torch.load(bin_path, map_location=device, weights_only=True)

    # Allow partial-match loading for backward compatibility; mismatches are
    # reported (first five keys of each kind) instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f" Warning: missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if unexpected:
        print(f" Warning: unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    # Inference-only: eval mode and no gradients on any parameter.
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)
    return model, config
# Convenience for one-shot inference
def generate(
    model: ConditionalDiffusionModel,
    config: Dict,
    raw_labels: torch.Tensor,  # (B, label_dim) — un-normalised cosmological params
    n_samples: int = 1,
    use_ddim: bool = True,
    ddim_steps: Union[int, None] = None,
    device: str = "cuda",
) -> torch.Tensor:
    """Generate samples conditioned on raw (un-normalised) parameter values.

    Labels are standardised with ``config["label_mu"]`` and
    ``config["label_std"]`` and each row is repeated ``n_samples`` times
    (consecutively) before being handed to ``model.sample``.

    Args:
        model: Object exposing ``sample(...)``, e.g. the model returned by
            ``load_pretrained``.
        config: Config dict; ``label_mu``, ``label_std`` and ``image_size``
            are read, plus ``ddim_steps_default`` when ``ddim_steps`` is None.
        raw_labels: ``(B, label_dim)`` tensor of un-normalised labels. A
            single 1-D vector whose length matches ``label_mu`` is also
            accepted and treated as a batch of one.
        n_samples: Number of samples to draw per label row.
        use_ddim: Forwarded to ``model.sample``.
        ddim_steps: Forwarded to ``model.sample``; defaults to
            ``config["ddim_steps_default"]``.
        device: Device for the normalisation tensors and the labels; also
            forwarded to ``model.sample``.

    Returns:
        The result of ``model.sample``: tensor of shape
        (B*n_samples, 1, H, W) in [-1, 1] model space.
    """
    if ddim_steps is None:
        ddim_steps = config["ddim_steps_default"]

    label_mu = torch.tensor(config["label_mu"], dtype=torch.float32, device=device)
    label_std = torch.tensor(config["label_std"], dtype=torch.float32, device=device)

    # Cast to float32 to match the normalisation constants: float64 labels
    # (the NumPy default) would otherwise promote the normalised labels to
    # float64 before they reach the model.
    raw_labels = raw_labels.to(device=device, dtype=torch.float32)

    # A lone label vector would be repeated element-wise by
    # repeat_interleave(dim=0); promote it to a batch of one instead.
    if raw_labels.dim() == 1 and raw_labels.shape[0] == label_mu.numel():
        raw_labels = raw_labels.unsqueeze(0)

    norm_labels = (raw_labels - label_mu) / label_std
    norm_labels = norm_labels.repeat_interleave(n_samples, dim=0)

    height = width = config["image_size"]
    return model.sample(
        labels=norm_labels,
        channels=1,
        height=height,
        width=width,
        use_ddim=use_ddim,
        ddim_steps=ddim_steps,
        progress=False,
        device=device,
    )