File size: 3,775 Bytes
e00756d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
modeling_ddpm_camels.py
=======================
Self-contained loader for the conditional DDPM checkpoint hosted on the Hub.
To run inference, users need only this file, diffusion_conditional.py,
unet_conditional.py, config.json, and model.safetensors (or
pytorch_model.bin as a fallback).
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

import torch

from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel
from unet_conditional import ConditionalUNet


def build_model(config: Dict) -> ConditionalDiffusionModel:
    """
    Construct a fresh (untrained) ConditionalDiffusionModel from a config dict.

    Denoiser (ConditionalUNet) keys: ``in_channels``, ``out_channels``,
    ``label_dim``, ``base_channels``, ``channel_multipliers``,
    ``attention_levels``, ``dropout``.
    Noise-schedule (GaussianDiffusion) keys: ``timesteps``, ``beta_start``,
    ``beta_end``, ``schedule_type``.

    Each value is coerced with int/float/list/str before being handed to the
    constructor, so JSON numbers and sequences of any compatible type work.
    A missing key raises ``KeyError``.
    """
    cfg = config

    unet_kwargs = {
        "in_channels": int(cfg["in_channels"]),
        "out_channels": int(cfg["out_channels"]),
        "label_dim": int(cfg["label_dim"]),
        "base_channels": int(cfg["base_channels"]),
        "channel_multipliers": list(cfg["channel_multipliers"]),
        "attention_levels": list(cfg["attention_levels"]),
        "dropout": float(cfg["dropout"]),
    }
    denoiser = ConditionalUNet(**unet_kwargs)

    schedule_kwargs = {
        "timesteps": int(cfg["timesteps"]),
        "beta_start": float(cfg["beta_start"]),
        "beta_end": float(cfg["beta_end"]),
        "schedule_type": str(cfg["schedule_type"]),
    }
    noise_schedule = GaussianDiffusion(**schedule_kwargs)

    return ConditionalDiffusionModel(denoiser, noise_schedule)


def load_pretrained(
    model_dir: Union[str, Path],
    device: Union[str, torch.device] = "cuda",
) -> Tuple[ConditionalDiffusionModel, Dict]:
    """
    Load the model and its config from a directory containing:
      - config.json
      - model.safetensors  (or pytorch_model.bin as fallback)

    Args:
        model_dir: directory holding the files listed above.
        device: device the model is moved to and the weights are loaded onto.
            Accepts a string such as "cuda" / "cuda:0" / "cpu" or a
            ``torch.device``.

    Returns:
        ``(model, config)``. ``model`` is in eval mode with every parameter's
        ``requires_grad`` set to False. ``config`` is the parsed config.json.

    Raises:
        FileNotFoundError: if config.json is missing, or if neither
            model.safetensors nor pytorch_model.bin exists.
    """
    model_dir = Path(model_dir)
    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # encoding (e.g. cp1252 on Windows), which read_text() uses by default.
    config = json.loads((model_dir / "config.json").read_text(encoding="utf-8"))

    model = build_model(config).to(device)

    safetensors_path = model_dir / "model.safetensors"
    bin_path         = model_dir / "pytorch_model.bin"
    if safetensors_path.exists():
        from safetensors.torch import load_file
        # safetensors expects a device string/int; normalise torch.device too.
        state_dict = load_file(str(safetensors_path), device=str(device))
    elif bin_path.exists():
        # weights_only=True refuses arbitrary pickled objects in the checkpoint.
        state_dict = torch.load(bin_path, map_location=device, weights_only=True)
    else:
        raise FileNotFoundError(f"No model weights in {model_dir}")

    # Allow partial-match loading for backward compatibility: mismatched keys
    # are reported (first five of each) rather than raised.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"  Warning: missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if unexpected:
        print(f"  Warning: unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    # Inference-only: disable dropout and freeze all parameters.
    model.eval()
    for p in model.parameters():
        p.requires_grad_(False)

    return model, config


# Convenience for one-shot inference
def generate(
    model: ConditionalDiffusionModel,
    config: Dict,
    raw_labels: torch.Tensor,        # (B, label_dim) or (label_dim,) — un-normalised cosmological params
    n_samples: int = 1,
    use_ddim: bool = True,
    ddim_steps: Optional[int] = None,
    device: Union[str, torch.device] = "cuda",
) -> torch.Tensor:
    """
    Generate samples conditioned on raw (un-normalised) parameter values.

    Labels are standardised with ``config["label_mu"]`` / ``config["label_std"]``.
    Each label row is then repeated ``n_samples`` times (copies adjacent)
    before being passed to ``model.sample``.

    Args:
        model: model whose ``sample`` method is called (e.g. from
            ``load_pretrained``).
        config: dict providing ``label_mu``, ``label_std``, ``image_size`` and,
            when ``ddim_steps`` is None, ``ddim_steps_default``.
        raw_labels: tensor or array-like of shape ``(B, label_dim)``. A 1-D
            ``(label_dim,)`` vector is treated as a batch of one. Values are
            cast to float32.
        n_samples: number of samples drawn per label row; must be >= 1.
        use_ddim: forwarded to ``model.sample``.
        ddim_steps: forwarded to ``model.sample``; defaults to
            ``config["ddim_steps_default"]``.
        device: device the labels are created on; also forwarded to
            ``model.sample``.

    Returns: tensor of shape (B*n_samples, 1, H, W) in [-1, 1] model space.

    Raises:
        ValueError: if ``n_samples < 1``, ``raw_labels`` is not 1-D/2-D, or its
            number of parameters per row differs from ``label_mu``.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if ddim_steps is None:
        ddim_steps = config["ddim_steps_default"]

    label_mu  = torch.tensor(config["label_mu"],  dtype=torch.float32, device=device)
    label_std = torch.tensor(config["label_std"], dtype=torch.float32, device=device)

    # Force float32: float64 (e.g. numpy-derived) input would otherwise promote
    # the normalised labels to float64 and mismatch the float32 network.
    labels = torch.as_tensor(raw_labels, dtype=torch.float32, device=device)
    if labels.dim() == 1:
        # A single parameter vector. Without this, repeat_interleave(dim=0)
        # below would repeat individual scalars instead of the whole row.
        labels = labels.unsqueeze(0)
    if labels.dim() != 2:
        raise ValueError(
            "raw_labels must have shape (B, label_dim) or (label_dim,), "
            f"got {tuple(labels.shape)}"
        )
    if label_mu.dim() == 1 and labels.shape[1] != label_mu.shape[0]:
        # Guard against silent broadcasting, e.g. a (B, 1) tensor.
        raise ValueError(
            f"raw_labels has {labels.shape[1]} parameters per row but "
            f"config['label_mu'] has {label_mu.shape[0]}"
        )

    norm_labels = (labels - label_mu) / label_std
    norm_labels = norm_labels.repeat_interleave(n_samples, dim=0)

    H = W = config["image_size"]
    return model.sample(
        labels=norm_labels, channels=1, height=H, width=W,
        use_ddim=use_ddim, ddim_steps=ddim_steps,
        progress=False, device=device,
    )