File size: 5,138 Bytes
67d59a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import functools
import importlib.util
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import jax
from flax import nnx, serialization
from huggingface_hub import hf_hub_download

from diffuse.diffusion.sde import Flow
from diffuse.integrator.deterministic import DDIMIntegrator, DPMpp2sIntegrator, EulerIntegrator, HeunIntegrator
from diffuse.integrator.stochastic import EulerMaruyamaIntegrator
from diffuse.predictor import Predictor
from diffuse.timer.base import VpTimer
from diffuse.denoisers.denoiser import Denoiser


@dataclass(frozen=True)
class PipelineAssets:
    """Container holding the preloaded model artifacts."""

    # The deserialized nnx model restored from the hub checkpoint.
    model: Any
    # The flow/SDE object shared by the predictor and all integrators.
    flow: Flow
    # Velocity-prediction wrapper around the model's forward pass.
    predictor: Predictor
    # Sample shape (H, W, C) for generated images — (28, 28, 1) for MNIST.
    x0_shape: Tuple[int, int, int]


@functools.lru_cache(maxsize=1)
def load_pipeline_assets() -> PipelineAssets:
    """Download the HF model and build the predictor stack once.

    The ``lru_cache(maxsize=1)`` makes this a lazy singleton: the hub
    download, config import, and weight restoration run on first call only.

    Returns:
        PipelineAssets with the restored model, flow, predictor, and
        sample shape.

    Raises:
        RuntimeError: if the downloaded config file cannot be imported.
    """
    # Pull both the serialized weights and the config script from the hub.
    weights_path = hf_hub_download(repo_id="jcopo/mnist", filename="model.msgpack")
    cfg_path = hf_hub_download(repo_id="jcopo/mnist", filename="config.py")

    # Import the config script as a throwaway module to obtain the model.
    cfg_spec = importlib.util.spec_from_file_location("model_config", cfg_path)
    if cfg_spec is None or cfg_spec.loader is None:
        raise RuntimeError("Unable to load model config from Hugging Face hub.")
    cfg_module = importlib.util.module_from_spec(cfg_spec)
    cfg_spec.loader.exec_module(cfg_module)

    net = cfg_module.model

    # Deserialize the raw parameter tree from the msgpack checkpoint.
    with open(weights_path, "rb") as fh:
        restored = serialization.from_bytes(None, fh.read())

    # Splice the restored parameters into the freshly constructed model.
    graph_def, params = nnx.split(net)
    params.replace_by_pure_dict(restored)
    net = nnx.merge(graph_def, params)
    net.eval()

    flow = Flow(tf=1.0)
    predictor = Predictor(
        model=flow,
        network=lambda x, t: net(x, t).output,
        prediction_type="velocity",
    )

    return PipelineAssets(
        model=net,
        flow=flow,
        predictor=predictor,
        x0_shape=(28, 28, 1),
    )


# Registry of available samplers. Each entry maps a short key to a spec dict:
#   label           — human-readable name shown in UIs,
#   cls             — integrator class to instantiate,
#   description     — one-line summary,
#   supports_churn  — whether extra churn kwargs may be passed to the class.
INTEGRATORS: Dict[str, Dict[str, Any]] = {
    "ddim": {
        "label": "DDIM (Deterministic)",
        "cls": DDIMIntegrator,
        "description": "Deterministic DDIM sampler.",
        "supports_churn": True,
    },
    "heun": {
        "label": "Heun (Deterministic 2nd order)",
        "cls": HeunIntegrator,
        "description": "Second-order deterministic integrator.",
        "supports_churn": True,
    },
    "euler": {
        "label": "Euler (Deterministic)",
        "cls": EulerIntegrator,
        "description": "Forward Euler integrator.",
        "supports_churn": True,
    },
    "dpmpp2s": {
        "label": "DPM++ 2S (Deterministic multi-step)",
        "cls": DPMpp2sIntegrator,
        "description": "Deterministic multi-step sampler with second-order accuracy.",
        "supports_churn": True,
    },
    "euler_maruyama": {
        "label": "Euler-Maruyama (Stochastic)",
        "cls": EulerMaruyamaIntegrator,
        "description": "Stochastic sampler with noise at each diffusion step.",
        "supports_churn": False,
    },
}

# Reverse lookup so display labels can be resolved back to registry keys.
LABEL_TO_KEY = {spec["label"]: key for key, spec in INTEGRATORS.items()}


def resolve_integrator(identifier: str) -> Tuple[str, Dict[str, Any]]:
    """Resolve either an integrator key or display label to the configuration dict.

    Args:
        identifier: a registry key (e.g. ``"ddim"``) or a display label
            (e.g. ``"DDIM (Deterministic)"``).

    Returns:
        ``(key, spec)`` — the canonical registry key and its config dict.

    Raises:
        KeyError: if the identifier matches neither a key nor a label.
    """
    # Prefer a direct key match; otherwise translate a display label.
    key = identifier if identifier in INTEGRATORS else LABEL_TO_KEY.get(identifier)
    if key is None:
        raise KeyError(f"Unknown integrator identifier: {identifier}")
    return key, INTEGRATORS[key]


def build_denoiser(
    integrator_key: str,
    n_steps: int,
    *,
    churn_params: Optional[Dict[str, float]] = None,
) -> Denoiser:
    """Instantiate a denoiser wired with the requested integrator and timer.

    Args:
        integrator_key: registry key or display label (see ``resolve_integrator``).
        n_steps: number of diffusion steps for the timer; must be >= 1.
        churn_params: optional stochastic-churn kwargs forwarded to the
            integrator class; rejected when the integrator does not
            advertise churn support.

    Raises:
        ValueError: on ``n_steps < 1`` or churn params for a non-churn
            integrator.
        KeyError: when the integrator identifier is unknown.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1")

    assets = load_pipeline_assets()
    _, spec = resolve_integrator(integrator_key)

    schedule = VpTimer(n_steps=n_steps, eps=0.001, tf=1.0)

    # Only forward churn kwargs to integrators that declare support for them.
    extra_kwargs: Dict[str, float] = {}
    if churn_params:
        if not spec.get("supports_churn", False):
            raise ValueError(f"Integrator '{spec['label']}' does not support stochastic churning.")
        extra_kwargs = churn_params

    sampler = spec["cls"](model=assets.flow, timer=schedule, **extra_kwargs)
    return Denoiser(
        integrator=sampler,
        model=assets.flow,
        predictor=assets.predictor,
        x0_shape=assets.x0_shape,
    )


def sample_batch(
    integrator_identifier: str,
    *,
    n_steps: int,
    n_samples: int,
    seed: int,
    keep_history: bool = False,
    churn_params: Optional[Dict[str, float]] = None,
):
    """Generate a batch of samples for the requested integrator.

    Args:
        integrator_identifier: registry key or display label.
        n_steps: diffusion steps; validated downstream by ``build_denoiser``.
        n_samples: batch size; must be >= 1.
        seed: PRNG seed for reproducible sampling.
        keep_history: when True, also return the per-step trajectory.
        churn_params: optional churn kwargs (see ``build_denoiser``).

    Returns:
        ``(state, history)`` as produced by ``Denoiser.generate``.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")

    rng = jax.random.PRNGKey(seed)
    sampler = build_denoiser(integrator_identifier, n_steps, churn_params=churn_params)

    # n_steps is passed again so it agrees with the timer baked into the denoiser.
    final_state, trajectory = sampler.generate(rng, n_steps, n_samples, keep_history=keep_history)
    return final_state, trajectory