"""Pipeline assembly helpers for an MNIST diffusion sampling demo.

Downloads a pretrained model from the Hugging Face hub, wires it into a
``Predictor``/``Denoiser`` stack, and exposes helpers to sample batches
with a choice of deterministic or stochastic integrators.
"""

import functools
import importlib.util
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import jax
from flax import nnx, serialization
from huggingface_hub import hf_hub_download

from diffuse.denoisers.denoiser import Denoiser
from diffuse.diffusion.sde import Flow
from diffuse.integrator.deterministic import (
    DDIMIntegrator,
    DPMpp2sIntegrator,
    EulerIntegrator,
    HeunIntegrator,
)
from diffuse.integrator.stochastic import EulerMaruyamaIntegrator
from diffuse.predictor import Predictor
from diffuse.timer.base import VpTimer


@dataclass(frozen=True)
class PipelineAssets:
    """Container holding the preloaded model artifacts."""

    # Restored nnx model (architecture comes from the hub's config module).
    model: Any
    flow: Flow
    predictor: Predictor
    # Per-sample output shape; fixed to MNIST's 28x28 single-channel images.
    x0_shape: Tuple[int, int, int]


@functools.lru_cache(maxsize=1)
def load_pipeline_assets() -> PipelineAssets:
    """Download the HF model and build the predictor stack once.

    Memoized via ``lru_cache`` so repeated calls reuse the downloaded
    weights and the constructed modules.

    Returns:
        PipelineAssets with the restored model, the flow, the velocity
        predictor, and the per-sample shape ``(28, 28, 1)``.

    Raises:
        RuntimeError: If the config module cannot be loaded from the hub.
    """
    model_path = hf_hub_download(repo_id="jcopo/mnist", filename="model.msgpack")
    config_path = hf_hub_download(repo_id="jcopo/mnist", filename="config.py")

    # The hub ships the architecture as an executable Python config module;
    # import it dynamically so it need not be on the local PYTHONPATH.
    spec = importlib.util.spec_from_file_location("model_config", config_path)
    if spec is None or spec.loader is None:
        raise RuntimeError("Unable to load model config from Hugging Face hub.")
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    model = config_module.model

    # Restore the serialized weights into the freshly built nnx module.
    with open(model_path, "rb") as f:
        state_dict = serialization.from_bytes(None, f.read())
    graphdef, state = nnx.split(model)
    state.replace_by_pure_dict(state_dict)
    model = nnx.merge(graphdef, state)
    model.eval()

    flow = Flow(tf=1.0)
    predictor = Predictor(
        model=flow,
        # The network wraps the nnx model; `.output` unwraps its result
        # into the raw prediction tensor the Predictor expects.
        network=lambda x, t: model(x, t).output,
        prediction_type="velocity",
    )
    return PipelineAssets(
        model=model,
        flow=flow,
        predictor=predictor,
        x0_shape=(28, 28, 1),
    )


# Registry of available integrators keyed by a short identifier. Each entry
# records the UI display label, the implementing class, a description, and
# whether the sampler accepts stochastic-churn parameters.
INTEGRATORS: Dict[str, Dict[str, Any]] = {
    "ddim": {
        "label": "DDIM (Deterministic)",
        "cls": DDIMIntegrator,
        "description": "Deterministic DDIM sampler.",
        "supports_churn": True,
    },
    "heun": {
        "label": "Heun (Deterministic 2nd order)",
        "cls": HeunIntegrator,
        "description": "Second-order deterministic integrator.",
        "supports_churn": True,
    },
    "euler": {
        "label": "Euler (Deterministic)",
        "cls": EulerIntegrator,
        "description": "Forward Euler integrator.",
        "supports_churn": True,
    },
    "dpmpp2s": {
        "label": "DPM++ 2S (Deterministic multi-step)",
        "cls": DPMpp2sIntegrator,
        "description": "Deterministic multi-step sampler with second-order accuracy.",
        "supports_churn": True,
    },
    "euler_maruyama": {
        "label": "Euler-Maruyama (Stochastic)",
        "cls": EulerMaruyamaIntegrator,
        "description": "Stochastic sampler with noise at each diffusion step.",
        "supports_churn": False,
    },
}

# Reverse lookup: display label -> registry key.
LABEL_TO_KEY = {spec["label"]: key for key, spec in INTEGRATORS.items()}


def resolve_integrator(identifier: str) -> Tuple[str, Dict[str, Any]]:
    """Resolve either an integrator key or display label to the configuration dict.

    Args:
        identifier: A key of ``INTEGRATORS`` (e.g. ``"ddim"``) or the
            human-readable label of an entry.

    Returns:
        ``(key, config)`` — the canonical registry key and its entry.

    Raises:
        KeyError: If the identifier matches neither a key nor a label.
    """
    if identifier in INTEGRATORS:
        return identifier, INTEGRATORS[identifier]
    if identifier in LABEL_TO_KEY:
        key = LABEL_TO_KEY[identifier]
        return key, INTEGRATORS[key]
    raise KeyError(f"Unknown integrator identifier: {identifier}")


def build_denoiser(
    integrator_key: str,
    n_steps: int,
    *,
    churn_params: Optional[Dict[str, float]] = None,
) -> Denoiser:
    """Instantiate a denoiser wired with the requested integrator and timer.

    Args:
        integrator_key: Integrator registry key or display label.
        n_steps: Number of diffusion steps for the timer (must be >= 1).
        churn_params: Optional extra keyword arguments forwarded to the
            integrator constructor; only valid for integrators whose
            registry entry has ``supports_churn``.

    Returns:
        A ``Denoiser`` ready to generate samples.

    Raises:
        ValueError: If ``n_steps`` < 1, or churn parameters are supplied
            for an integrator that does not support them.
        KeyError: If the integrator identifier is unknown.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1")
    assets = load_pipeline_assets()
    _, integrator_cfg = resolve_integrator(integrator_key)
    timer = VpTimer(n_steps=n_steps, eps=0.001, tf=1.0)

    integrator_kwargs: Dict[str, float] = {}
    if churn_params:
        if not integrator_cfg.get("supports_churn", False):
            raise ValueError(f"Integrator '{integrator_cfg['label']}' does not support stochastic churning.")
        # Defensive copy so later mutation of the caller's dict cannot
        # leak into the integrator configuration.
        integrator_kwargs = dict(churn_params)

    integrator = integrator_cfg["cls"](model=assets.flow, timer=timer, **integrator_kwargs)
    return Denoiser(
        integrator=integrator,
        model=assets.flow,
        predictor=assets.predictor,
        x0_shape=assets.x0_shape,
    )


def sample_batch(
    integrator_identifier: str,
    *,
    n_steps: int,
    n_samples: int,
    seed: int,
    keep_history: bool = False,
    churn_params: Optional[Dict[str, float]] = None,
):
    """Generate a batch of samples for the requested integrator.

    Args:
        integrator_identifier: Integrator registry key or display label.
        n_steps: Number of diffusion steps.
        n_samples: Number of samples to draw (must be >= 1).
        seed: PRNG seed for reproducible sampling.
        keep_history: Whether to retain the per-step trajectory.
        churn_params: Optional churn settings (see ``build_denoiser``).

    Returns:
        ``(state, history)`` as produced by ``Denoiser.generate``.

    Raises:
        ValueError: If ``n_samples`` < 1 (plus anything ``build_denoiser``
            raises).
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")
    denoiser = build_denoiser(integrator_identifier, n_steps, churn_params=churn_params)
    key = jax.random.PRNGKey(seed)
    # The denoiser expects the number of steps to match the timer configuration.
    state, history = denoiser.generate(key, n_steps, n_samples, keep_history=keep_history)
    return state, history