# NOTE(review): removed Hugging Face Space page scrape residue here (status
# badges, file-size line, commit hash, and a line-number gutter) — it was not
# valid Python and carried no program content.
import functools
import importlib.util
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import jax
from flax import nnx, serialization
from huggingface_hub import hf_hub_download
from diffuse.diffusion.sde import Flow
from diffuse.integrator.deterministic import DDIMIntegrator, DPMpp2sIntegrator, EulerIntegrator, HeunIntegrator
from diffuse.integrator.stochastic import EulerMaruyamaIntegrator
from diffuse.predictor import Predictor
from diffuse.timer.base import VpTimer
from diffuse.denoisers.denoiser import Denoiser
@dataclass(frozen=True)
class PipelineAssets:
    """Container holding the preloaded model artifacts.

    Frozen so the cached singleton returned by ``load_pipeline_assets``
    (via ``functools.lru_cache``) cannot be mutated by callers.
    """
    # Restored nnx model in eval mode (see load_pipeline_assets).
    model: Any
    # Flow process definition shared by the predictor and the integrators.
    flow: Flow
    # Predictor wrapping the network with prediction_type="velocity".
    predictor: Predictor
    # Shape of a single sample, e.g. (28, 28, 1) for MNIST images.
    x0_shape: Tuple[int, int, int]
@functools.lru_cache(maxsize=1)
def load_pipeline_assets() -> PipelineAssets:
    """Download the HF model and build the predictor stack once.

    Cached with ``lru_cache(maxsize=1)`` so the hub download and weight
    deserialization happen only on the first call.

    Returns:
        PipelineAssets: the restored model, flow, predictor, and sample shape.

    Raises:
        RuntimeError: if an import spec cannot be built for the downloaded
            config module.
    """
    # Fetch the serialized weights and their companion config module.
    model_path = hf_hub_download(repo_id="jcopo/mnist", filename="model.msgpack")
    config_path = hf_hub_download(repo_id="jcopo/mnist", filename="config.py")
    # Dynamically import config.py; it is expected to expose a `model`
    # attribute (the unseeded architecture skeleton).
    spec = importlib.util.spec_from_file_location("model_config", config_path)
    if spec is None or spec.loader is None:
        raise RuntimeError("Unable to load model config from Hugging Face hub.")
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    model = config_module.model
    # Deserialize the msgpack checkpoint into a plain pytree/dict
    # (target=None), then graft it onto the skeleton via nnx split/merge.
    with open(model_path, "rb") as f:
        state_dict = serialization.from_bytes(None, f.read())
    graphdef, state = nnx.split(model)
    state.replace_by_pure_dict(state_dict)
    model = nnx.merge(graphdef, state)
    model.eval()
    flow = Flow(tf=1.0)
    # NOTE(review): the Predictor's `model` argument receives the Flow, not
    # the network — the network enters only through the closure. Presumably
    # intended by the diffuse API; confirm against Predictor's signature.
    predictor = Predictor(
        model=flow,
        network=lambda x, t: model(x, t).output,
        prediction_type="velocity",
    )
    return PipelineAssets(
        model=model,
        flow=flow,
        predictor=predictor,
        x0_shape=(28, 28, 1),
    )
# Registry of the samplers the UI can select. Each entry maps a short key to:
#   label           - human-readable display name (also resolvable, see
#                     resolve_integrator)
#   cls             - integrator class to instantiate in build_denoiser
#   description     - one-line blurb for the UI
#   supports_churn  - whether build_denoiser may forward churn_params to cls
INTEGRATORS: Dict[str, Dict[str, Any]] = {
    "ddim": {
        "label": "DDIM (Deterministic)",
        "cls": DDIMIntegrator,
        "description": "Deterministic DDIM sampler.",
        "supports_churn": True,
    },
    "heun": {
        "label": "Heun (Deterministic 2nd order)",
        "cls": HeunIntegrator,
        "description": "Second-order deterministic integrator.",
        "supports_churn": True,
    },
    "euler": {
        "label": "Euler (Deterministic)",
        "cls": EulerIntegrator,
        "description": "Forward Euler integrator.",
        "supports_churn": True,
    },
    "dpmpp2s": {
        "label": "DPM++ 2S (Deterministic multi-step)",
        "cls": DPMpp2sIntegrator,
        "description": "Deterministic multi-step sampler with second-order accuracy.",
        "supports_churn": True,
    },
    "euler_maruyama": {
        "label": "Euler-Maruyama (Stochastic)",
        "cls": EulerMaruyamaIntegrator,
        "description": "Stochastic sampler with noise at each diffusion step.",
        "supports_churn": False,
    },
}
# Reverse lookup so display labels can be resolved back to registry keys.
LABEL_TO_KEY = {spec["label"]: key for key, spec in INTEGRATORS.items()}
def resolve_integrator(identifier: str) -> Tuple[str, Dict[str, Any]]:
    """Resolve either an integrator key or display label to the configuration dict."""
    # Accept registry keys directly; otherwise try the label-to-key mapping.
    key = identifier if identifier in INTEGRATORS else LABEL_TO_KEY.get(identifier)
    if key is None:
        raise KeyError(f"Unknown integrator identifier: {identifier}")
    return key, INTEGRATORS[key]
def build_denoiser(
    integrator_key: str,
    n_steps: int,
    *,
    churn_params: Optional[Dict[str, float]] = None,
) -> Denoiser:
    """Instantiate a denoiser wired with the requested integrator and timer.

    Args:
        integrator_key: registry key (or display label) naming the integrator.
        n_steps: number of diffusion steps; must be >= 1.
        churn_params: optional extra keyword arguments forwarded to the
            integrator constructor; rejected for integrators that do not
            advertise churn support.

    Raises:
        ValueError: if ``n_steps < 1`` or churn is requested on an
            integrator that does not support it.
    """
    if n_steps < 1:
        raise ValueError("n_steps must be >= 1")
    assets = load_pipeline_assets()
    _, cfg = resolve_integrator(integrator_key)
    timer = VpTimer(n_steps=n_steps, eps=0.001, tf=1.0)
    extra_kwargs: Dict[str, float] = {}
    if churn_params:
        if not cfg.get("supports_churn", False):
            raise ValueError(
                f"Integrator '{cfg['label']}' does not support stochastic churning."
            )
        extra_kwargs = churn_params
    sampler = cfg["cls"](model=assets.flow, timer=timer, **extra_kwargs)
    return Denoiser(
        integrator=sampler,
        model=assets.flow,
        predictor=assets.predictor,
        x0_shape=assets.x0_shape,
    )
def sample_batch(
    integrator_identifier: str,
    *,
    n_steps: int,
    n_samples: int,
    seed: int,
    keep_history: bool = False,
    churn_params: Optional[Dict[str, float]] = None,
):
    """Generate a batch of samples for the requested integrator.

    Args:
        integrator_identifier: registry key or display label of the sampler.
        n_steps: diffusion steps (validated downstream by build_denoiser).
        n_samples: number of images to draw; must be >= 1.
        seed: PRNG seed for reproducible sampling.
        keep_history: whether to retain the per-step trajectory.
        churn_params: optional churn settings forwarded to the integrator.

    Returns:
        A ``(state, history)`` pair as produced by ``Denoiser.generate``.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be >= 1")
    rng = jax.random.PRNGKey(seed)
    denoiser = build_denoiser(
        integrator_identifier, n_steps, churn_params=churn_params
    )
    # The denoiser expects the number of steps to match the timer configuration.
    return denoiser.generate(rng, n_steps, n_samples, keep_history=keep_history)