Spaces:
Sleeping
Sleeping
| import functools | |
| import importlib.util | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, Optional, Tuple | |
| import jax | |
| from flax import nnx, serialization | |
| from huggingface_hub import hf_hub_download | |
| from diffuse.diffusion.sde import Flow | |
| from diffuse.integrator.deterministic import DDIMIntegrator, DPMpp2sIntegrator, EulerIntegrator, HeunIntegrator | |
| from diffuse.integrator.stochastic import EulerMaruyamaIntegrator | |
| from diffuse.predictor import Predictor | |
| from diffuse.timer.base import VpTimer | |
| from diffuse.denoisers.denoiser import Denoiser | |
@dataclass
class PipelineAssets:
    """Container holding the preloaded model artifacts.

    Built once by :func:`load_pipeline_assets` and consumed by
    :func:`build_denoiser`.  The ``@dataclass`` decorator is required:
    ``load_pipeline_assets`` instantiates this class with keyword
    arguments, which needs the generated ``__init__``.
    """

    model: Any  # nnx model restored from the HF msgpack checkpoint
    flow: Flow  # Flow(tf=1.0) SDE/ODE definition shared by integrators
    predictor: Predictor  # velocity predictor wrapping the model's forward pass
    x0_shape: Tuple[int, int, int]  # sample shape, (28, 28, 1) for MNIST
@functools.lru_cache(maxsize=None)
def load_pipeline_assets() -> PipelineAssets:
    """Download the HF model and build the predictor stack once.

    The result is memoized with ``functools.lru_cache`` so repeated calls
    (one per ``build_denoiser`` invocation) reuse the same artifacts instead
    of re-downloading the checkpoint and rebuilding the model each time.

    Returns:
        PipelineAssets with the restored model, the Flow, the velocity
        Predictor, and the MNIST sample shape.

    Raises:
        RuntimeError: if the config module downloaded from the hub cannot
            be loaded as a Python module.
    """
    model_path = hf_hub_download(repo_id="jcopo/mnist", filename="model.msgpack")
    config_path = hf_hub_download(repo_id="jcopo/mnist", filename="config.py")

    # The repo ships its architecture definition as an executable config.py;
    # import it dynamically to obtain the (uninitialized) model object.
    spec = importlib.util.spec_from_file_location("model_config", config_path)
    if spec is None or spec.loader is None:
        raise RuntimeError("Unable to load model config from Hugging Face hub.")
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    model = config_module.model

    # Restore the trained weights: deserialize the msgpack state dict, then
    # splice it into the nnx module via split -> replace -> merge.
    with open(model_path, "rb") as f:
        state_dict = serialization.from_bytes(None, f.read())
    graphdef, state = nnx.split(model)
    state.replace_by_pure_dict(state_dict)
    model = nnx.merge(graphdef, state)
    model.eval()

    flow = Flow(tf=1.0)
    predictor = Predictor(
        model=flow,
        network=lambda x, t: model(x, t).output,
        prediction_type="velocity",
    )
    return PipelineAssets(
        model=model,
        flow=flow,
        predictor=predictor,
        x0_shape=(28, 28, 1),
    )
# Registry of available sampling integrators, keyed by a short identifier.
# Each entry holds:
#   label:          human-readable name (used for display and reverse lookup)
#   cls:            integrator class, instantiated as cls(model=..., timer=..., **kwargs)
#   description:    one-line help text
#   supports_churn: whether extra churn kwargs may be forwarded to the class
#                   (checked by build_denoiser before passing churn_params)
INTEGRATORS: Dict[str, Dict[str, Any]] = {
    "ddim": {
        "label": "DDIM (Deterministic)",
        "cls": DDIMIntegrator,
        "description": "Deterministic DDIM sampler.",
        "supports_churn": True,
    },
    "heun": {
        "label": "Heun (Deterministic 2nd order)",
        "cls": HeunIntegrator,
        "description": "Second-order deterministic integrator.",
        "supports_churn": True,
    },
    "euler": {
        "label": "Euler (Deterministic)",
        "cls": EulerIntegrator,
        "description": "Forward Euler integrator.",
        "supports_churn": True,
    },
    "dpmpp2s": {
        "label": "DPM++ 2S (Deterministic multi-step)",
        "cls": DPMpp2sIntegrator,
        "description": "Deterministic multi-step sampler with second-order accuracy.",
        "supports_churn": True,
    },
    "euler_maruyama": {
        "label": "Euler-Maruyama (Stochastic)",
        "cls": EulerMaruyamaIntegrator,
        "description": "Stochastic sampler with noise at each diffusion step.",
        "supports_churn": False,
    },
}

# Reverse lookup: display label -> registry key (labels above are unique).
LABEL_TO_KEY = {spec["label"]: key for key, spec in INTEGRATORS.items()}
def resolve_integrator(identifier: str) -> Tuple[str, Dict[str, Any]]:
    """Map an integrator key or its display label to ``(key, config)``.

    Accepts either a registry key (e.g. ``"ddim"``) or the human-readable
    label stored under ``"label"``; raises ``KeyError`` for anything else.
    """
    spec = INTEGRATORS.get(identifier)
    if spec is not None:
        return identifier, spec
    key = LABEL_TO_KEY.get(identifier)
    if key is not None:
        return key, INTEGRATORS[key]
    raise KeyError(f"Unknown integrator identifier: {identifier}")
def build_denoiser(
    integrator_key: str,
    n_steps: int,
    *,
    churn_params: Optional[Dict[str, float]] = None,
) -> Denoiser:
    """Instantiate a denoiser wired with the requested integrator and timer.

    Args:
        integrator_key: registry key or display label (see INTEGRATORS).
        n_steps: number of diffusion steps; must be >= 1.
        churn_params: optional extra kwargs forwarded to the integrator;
            rejected when the integrator does not support churning.

    Raises:
        ValueError: on invalid ``n_steps`` or unsupported churn request.
        KeyError: when the integrator identifier is unknown.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1")

    assets = load_pipeline_assets()
    _, integrator_cfg = resolve_integrator(integrator_key)

    # Guard churn kwargs: only pass them through when the integrator opts in.
    integrator_kwargs: Dict[str, float] = {}
    if churn_params:
        if not integrator_cfg.get("supports_churn", False):
            raise ValueError(f"Integrator '{integrator_cfg['label']}' does not support stochastic churning.")
        integrator_kwargs = churn_params

    stepper = integrator_cfg["cls"](
        model=assets.flow,
        timer=VpTimer(n_steps=n_steps, eps=0.001, tf=1.0),
        **integrator_kwargs,
    )
    return Denoiser(
        integrator=stepper,
        model=assets.flow,
        predictor=assets.predictor,
        x0_shape=assets.x0_shape,
    )
def sample_batch(
    integrator_identifier: str,
    *,
    n_steps: int,
    n_samples: int,
    seed: int,
    keep_history: bool = False,
    churn_params: Optional[Dict[str, float]] = None,
):
    """Generate a batch of samples for the requested integrator.

    Args:
        integrator_identifier: registry key or display label.
        n_steps: diffusion steps; forwarded unchanged to both the denoiser
            builder and ``generate`` so it matches the timer configuration.
        n_samples: batch size; must be >= 1.
        seed: PRNG seed for jax.
        keep_history: when True, also return the intermediate trajectory.
        churn_params: optional churn kwargs, see ``build_denoiser``.

    Returns:
        The ``(state, history)`` pair produced by ``Denoiser.generate``.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")

    rng = jax.random.PRNGKey(seed)
    denoiser = build_denoiser(integrator_identifier, n_steps, churn_params=churn_params)
    return denoiser.generate(rng, n_steps, n_samples, keep_history=keep_history)