# DiffuseIntegrator / pipeline.py
# Geoffroy38000's picture
# Polish diffusion UI and churning controls
# 67d59a3
import functools
import importlib.util
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import jax
from flax import nnx, serialization
from huggingface_hub import hf_hub_download
from diffuse.diffusion.sde import Flow
from diffuse.integrator.deterministic import DDIMIntegrator, DPMpp2sIntegrator, EulerIntegrator, HeunIntegrator
from diffuse.integrator.stochastic import EulerMaruyamaIntegrator
from diffuse.predictor import Predictor
from diffuse.timer.base import VpTimer
from diffuse.denoisers.denoiser import Denoiser
@dataclass(frozen=True)
class PipelineAssets:
    """Immutable bundle of the objects produced by one model-loading pass.

    The dataclass is frozen, so fields cannot be reassigned once an
    instance has been constructed.
    """

    # Network object restored from the downloaded checkpoint.
    model: Any
    # Flow process handed to the predictor, the integrators and the denoiser.
    flow: Flow
    # Predictor wrapping the network's output.
    predictor: Predictor
    # Three-integer shape of a single sample, forwarded to the denoiser.
    x0_shape: Tuple[int, int, int]
@functools.lru_cache(maxsize=1)
def load_pipeline_assets() -> PipelineAssets:
    """Fetch the checkpoint and its config from the Hub and assemble the assets.

    The function is wrapped in ``functools.lru_cache``, so after the first
    successful call the same ``PipelineAssets`` instance is returned without
    repeating the downloads or the model reconstruction.

    Returns:
        A ``PipelineAssets`` holding the restored network (after ``eval()``
        has been called on it), a ``Flow(tf=1.0)`` process, a ``Predictor``
        configured with ``prediction_type="velocity"`` and the sample shape
        ``(28, 28, 1)``.

    Raises:
        RuntimeError: If an import spec (or its loader) cannot be created
            for the downloaded ``config.py``.
    """
    repo_id = "jcopo/mnist"
    weights_file = hf_hub_download(repo_id=repo_id, filename="model.msgpack")
    config_file = hf_hub_download(repo_id=repo_id, filename="config.py")

    # NOTE(review): the steps below execute a Python file downloaded from the
    # Hub, and no ``revision`` is pinned in the download calls above. Confirm
    # the repository is trusted, or pin a revision, before relying on this.
    config_spec = importlib.util.spec_from_file_location("model_config", config_file)
    if config_spec is None or config_spec.loader is None:
        raise RuntimeError("Unable to load model config from Hugging Face hub.")
    config_module = importlib.util.module_from_spec(config_spec)
    config_spec.loader.exec_module(config_module)

    # The config module exposes the model object as a module-level ``model``.
    template_network = config_module.model

    with open(weights_file, "rb") as handle:
        raw_weights = handle.read()
    # A ``None`` target is passed, so no template structure is supplied here.
    saved_params = serialization.from_bytes(None, raw_weights)

    # Split the template, overwrite its state with the saved values, and merge
    # the pieces back into a usable module.
    graphdef, params = nnx.split(template_network)
    params.replace_by_pure_dict(saved_params)
    network = nnx.merge(graphdef, params)
    network.eval()

    def velocity_network(x, t):
        # The predictor only needs the ``output`` field of the model's result.
        return network(x, t).output

    flow = Flow(tf=1.0)
    predictor = Predictor(
        model=flow,
        network=velocity_network,
        prediction_type="velocity",
    )
    return PipelineAssets(
        model=network,
        flow=flow,
        predictor=predictor,
        x0_shape=(28, 28, 1),
    )
# Registry of the selectable samplers, keyed by a short identifier.
# Every entry provides, in this order:
#   label          - display string; resolve_integrator also accepts it as an id
#   cls            - integrator class that build_denoiser instantiates
#   description    - one-line explanation of the sampler
#   supports_churn - whether build_denoiser accepts churn_params for this entry
INTEGRATORS: Dict[str, Dict[str, Any]] = {
    "ddim": dict(
        label="DDIM (Deterministic)",
        cls=DDIMIntegrator,
        description="Deterministic DDIM sampler.",
        supports_churn=True,
    ),
    "heun": dict(
        label="Heun (Deterministic 2nd order)",
        cls=HeunIntegrator,
        description="Second-order deterministic integrator.",
        supports_churn=True,
    ),
    "euler": dict(
        label="Euler (Deterministic)",
        cls=EulerIntegrator,
        description="Forward Euler integrator.",
        supports_churn=True,
    ),
    # NOTE(review): "2S" usually denotes the single-step DPM-Solver++ variant,
    # while the strings below say "multi-step" - confirm the intended wording.
    "dpmpp2s": dict(
        label="DPM++ 2S (Deterministic multi-step)",
        cls=DPMpp2sIntegrator,
        description="Deterministic multi-step sampler with second-order accuracy.",
        supports_churn=True,
    ),
    "euler_maruyama": dict(
        label="Euler-Maruyama (Stochastic)",
        cls=EulerMaruyamaIntegrator,
        description="Stochastic sampler with noise at each diffusion step.",
        supports_churn=False,
    ),
}

# Reverse index: display label -> registry key.
LABEL_TO_KEY = {entry["label"]: name for name, entry in INTEGRATORS.items()}
def resolve_integrator(identifier: str) -> Tuple[str, Dict[str, Any]]:
    """Map a registry key or a display label onto its registry entry.

    Registry keys take precedence: the identifier is first looked up in
    ``INTEGRATORS`` and only then treated as a display label.

    Args:
        identifier: Either a key of ``INTEGRATORS`` or one of the ``label``
            strings indexed by ``LABEL_TO_KEY``.

    Returns:
        A ``(key, config)`` pair where ``key`` is the registry key and
        ``config`` is the corresponding ``INTEGRATORS`` entry.

    Raises:
        KeyError: If the identifier is neither a known key nor a known label.
    """
    if identifier in INTEGRATORS:
        canonical = identifier
    else:
        # Falls back to the label index; ``None`` means "not a label either".
        canonical = LABEL_TO_KEY.get(identifier)

    if canonical is None:
        raise KeyError(f"Unknown integrator identifier: {identifier}")
    return canonical, INTEGRATORS[canonical]
def build_denoiser(
    integrator_key: str,
    n_steps: int,
    *,
    churn_params: Optional[Dict[str, float]] = None,
) -> Denoiser:
    """Create a ``Denoiser`` using the chosen integrator and a fresh timer.

    Args:
        integrator_key: Registry key or display label of the integrator; it
            is passed through ``resolve_integrator``.
        n_steps: Number of steps given to the ``VpTimer``; must be at least 1.
        churn_params: Optional extra keyword arguments forwarded to the
            integrator constructor. ``None`` and an empty dict both mean
            "no extra arguments".

    Returns:
        A ``Denoiser`` built from the integrator, the cached flow, the cached
        predictor and the cached sample shape.

    Raises:
        ValueError: If ``n_steps`` is below 1, or if non-empty
            ``churn_params`` are given for an integrator whose registry entry
            does not set ``supports_churn``.
        KeyError: Propagated from ``resolve_integrator`` for an unknown
            identifier.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1")

    assets = load_pipeline_assets()
    integrator_cfg = resolve_integrator(integrator_key)[1]
    timer = VpTimer(n_steps=n_steps, eps=0.001, tf=1.0)

    # Only a non-empty mapping counts as a churn request.
    churn_requested = bool(churn_params)
    if churn_requested and not integrator_cfg.get("supports_churn", False):
        raise ValueError(f"Integrator '{integrator_cfg['label']}' does not support stochastic churning.")
    extra_kwargs: Dict[str, float] = churn_params if churn_requested else {}

    return Denoiser(
        integrator=integrator_cfg["cls"](model=assets.flow, timer=timer, **extra_kwargs),
        model=assets.flow,
        predictor=assets.predictor,
        x0_shape=assets.x0_shape,
    )
def sample_batch(
    integrator_identifier: str,
    *,
    n_steps: int,
    n_samples: int,
    seed: int,
    keep_history: bool = False,
    churn_params: Optional[Dict[str, float]] = None,
):
    """Draw ``n_samples`` samples with the selected integrator.

    Args:
        integrator_identifier: Registry key or display label of the integrator.
        n_steps: Number of integration steps; used both to configure the
            denoiser and as the step count passed to ``generate``.
        n_samples: Number of samples to generate; must be at least 1.
        seed: Integer used to create the JAX PRNG key.
        keep_history: Forwarded unchanged to ``denoiser.generate``.
        churn_params: Optional integrator keyword arguments, forwarded to
            ``build_denoiser``.

    Returns:
        The ``(state, history)`` pair produced by ``denoiser.generate``.

    Raises:
        ValueError: If ``n_samples`` is below 1 (checked before anything
            else), or for the conditions rejected by ``build_denoiser``.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")

    denoiser = build_denoiser(
        integrator_identifier,
        n_steps,
        churn_params=churn_params,
    )
    rng_key = jax.random.PRNGKey(seed)

    # The timer inside the denoiser was built with this same n_steps, so the
    # identical value is passed to generate().
    final_state, trajectory = denoiser.generate(
        rng_key,
        n_steps,
        n_samples,
        keep_history=keep_history,
    )
    return final_state, trajectory