import html
import os
from pathlib import Path
from typing import List, Optional
# Must execute before the `import jax` that follows, so the setting is in
# place when JAX initialises. Per JAX's GPU memory allocation docs, "false"
# disables XLA's default preallocation of most GPU memory. setdefault keeps
# any value the user already exported in the environment.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
import jax
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from pipeline import INTEGRATORS, load_pipeline_assets, resolve_integrator, sample_batch
# Sampling limits surfaced in the UI (gallery size and the steps slider).
N_SAMPLES = 5
MAX_STEPS = 80
DEFAULT_STEPS = 20

# Optional banner logo next to this file; LOGO_VALUE is None when absent.
ROOT_DIR = Path(__file__).parent
LOGO_PATH = ROOT_DIR / "logo.png"
if LOGO_PATH.exists():
    LOGO_VALUE = str(LOGO_PATH)
else:
    LOGO_VALUE = None

# Initial values of the churning sliders.
DEFAULT_CHURN_RATE = 0.0
DEFAULT_CHURN_MIN = 0.0
DEFAULT_CHURN_MAX = 0.0
DEFAULT_NOISE_INFLATION = 1.0
# Upper bound for the noise-inflation slider and its server-side validation.
MAX_NOISE_INFLATION = 1.02
# Shown in the details panel (a gr.HTML component) before the first run.
# Real markup is required: as plain text the title and hint collapse into one
# unstyled line, and the `.summary-card.is-empty` / `.summary-title` rules in
# CUSTOM_CSS never apply.
SUMMARY_PLACEHOLDER_HTML = """
<div class="summary-card is-empty">
<div class="summary-title">Ready to sample</div>
<div>Select an integrator, adjust the controls, then generate digits to inspect their trajectories.</div>
</div>
""".strip()
# Stylesheet passed to gr.Blocks(css=CUSTOM_CSS) in build_ui.
# Selectors target the elem_id / elem_classes assigned to components there
# (#hero, .hero-logo, .control-card, .accordion-card, .churn-card,
# .gallery-card, .history-card, .plot-title, .history-placeholder, ...).
# The .summary-* rules are presumably meant for the HTML rendered into the
# details panel. The `.gallery-card [data-testid="upload-zone"]` rule
# presumably hides the upload area of the interactive gallery. The @media
# block tightens spacing on screens narrower than 768px.
# NOTE: this is a runtime string; any edit inside it changes the page styling.
CUSTOM_CSS = """
body {background: radial-gradient(circle at top left, #ffe8d5, #fff7f0 55%, #fdf1f8);}
#hero {
display: flex;
align-items: center;
justify-content: center;
gap: 1.5rem;
background: rgba(255, 255, 255, 0.85);
padding: 1.5rem 2rem;
border-radius: 18px;
box-shadow: 0 18px 35px rgba(255, 135, 0, 0.15);
border: 1px solid rgba(255, 145, 0, 0.35);
}
.hero-logo img {max-width: 320px; width: 100%; object-fit: contain;}
.hero-copy {font-size: 1.05rem !important; color: #7a3b09;}
.control-card {
background: rgba(255, 255, 255, 0.92);
border-radius: 16px;
padding: 1.25rem;
border: 1px solid rgba(255, 166, 77, 0.35);
box-shadow: 0 14px 30px rgba(255, 140, 0, 0.12);
}
.generate-button button {
background: linear-gradient(135deg, #ff7e00, #ffb347);
color: #fff;
font-weight: 600;
border-radius: 12px;
box-shadow: 0 10px 20px rgba(255, 126, 0, 0.25);
}
.generate-button button:hover {filter: brightness(1.05);}
.control-heading {
font-weight: 600;
color: #7a3b09;
margin-bottom: 0.6rem !important;
}
.plot-card {
background: rgba(255, 255, 255, 0.88);
border-radius: 16px;
padding: 1rem;
border: 1px solid rgba(255, 166, 77, 0.35);
box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.4), 0 12px 28px rgba(255, 145, 0, 0.18);
}
.details-card {
border: none;
padding: 0;
}
.summary-card {
background: rgba(255, 255, 255, 0.9);
border-radius: 14px;
padding: 1.1rem 1.25rem;
border: 1px solid rgba(255, 166, 77, 0.35);
box-shadow: 0 12px 26px rgba(255, 145, 0, 0.16);
display: grid;
gap: 0.85rem;
}
.summary-card.is-empty {
border-style: dashed;
box-shadow: none;
}
.summary-title {
font-weight: 600;
font-size: 1.05rem;
color: #7a3b09;
}
.summary-section {
display: grid;
gap: 0.45rem;
}
.summary-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(120px, 1fr));
gap: 0.4rem;
}
.summary-pill {
background: rgba(255, 245, 233, 0.95);
border: 1px solid rgba(255, 166, 77, 0.45);
border-radius: 999px;
padding: 0.35rem 0.75rem;
font-size: 0.85rem;
display: inline-flex;
align-items: center;
gap: 0.35rem;
color: #7a3b09;
justify-content: center;
}
.summary-pill strong {font-weight: 600;}
.summary-pill.integrator {
background: rgba(255, 231, 206, 0.95);
border-color: rgba(255, 160, 72, 0.65);
font-weight: 600;
}
.summary-divider {
border: none;
border-top: 1px dashed rgba(255, 166, 77, 0.4);
margin: 0.2rem 0;
}
.accordion-card {
--tw-border-opacity: 0.45;
border: 1px dashed rgba(255, 166, 77, 0.45) !important;
border-radius: 14px !important;
background: rgba(255, 255, 255, 0.88) !important;
}
.accordion-card > div:nth-child(1) {
font-weight: 600;
color: #7a3b09;
}
.churn-card {
margin-top: 0.75rem;
background: rgba(255, 255, 255, 0.85);
border-radius: 14px;
padding: 0.9rem 1rem 1.1rem;
border: 1px dashed rgba(255, 166, 77, 0.5);
box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.55);
}
.churn-title {
font-size: 0.92rem !important;
color: #8a450f;
margin-bottom: 0.55rem !important;
}
.gallery-card {
background: rgba(255, 255, 255, 0.9);
border-radius: 16px;
padding: 0.3rem 0.4rem;
border: 1px solid rgba(255, 166, 77, 0.28);
box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.25), 0 8px 18px rgba(255, 145, 0, 0.12);
}
.gallery-card [data-testid="upload-zone"] {
display: none !important;
}
.gallery-card .grid {
min-height: 180px;
}
.gallery-card img {
border-radius: 10px;
transition: transform 0.15s ease, box-shadow 0.15s ease;
}
.gallery-card img:hover {
transform: translateY(-2px);
box-shadow: 0 8px 14px rgba(255, 145, 0, 0.18);
}
.history-card {
background: rgba(255, 255, 255, 0.88);
border-radius: 16px;
padding: 0.9rem;
border: 1px solid rgba(255, 166, 77, 0.35);
box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.35), 0 10px 22px rgba(255, 145, 0, 0.15);
}
.plot-title {
color: #7a3b09 !important;
text-align: center;
font-weight: 600 !important;
margin-bottom: 0.45rem !important;
}
.history-placeholder {
text-align: center;
color: #8a450f;
font-size: 0.9rem;
margin-top: 0.5rem;
}
.value-chip {
background: rgba(255, 231, 206, 0.9);
border-radius: 999px;
padding: 0.1rem 0.55rem;
font-size: 0.82rem;
margin-left: 0.4rem;
color: #84400e;
}
@media (max-width: 768px) {
.summary-grid {grid-template-columns: repeat(auto-fit, minmax(110px, 1fr));}
.gallery-card .grid {min-height: 150px;}
.plot-card, .history-card {padding: 0.7rem;}
}
"""
def _prepare_gallery_images(samples: np.ndarray) -> List[np.ndarray]:
"""Convert normalized grayscale samples to RGB arrays for display."""
clipped = np.clip(samples, 0.0, 1.0)
uint8_imgs = (clipped * 255).astype(np.uint8)
if uint8_imgs.ndim == 3:
uint8_imgs = uint8_imgs[..., np.newaxis]
return [np.repeat(img, 3, axis=-1) for img in uint8_imgs]
def _make_history_plot(history_frames: np.ndarray) -> plt.Figure:
    """Render up to 10 evenly spaced frames of one sample's trajectory in a row.

    Args:
        history_frames: Frames stacked along axis 0, shaped ``(T, H, W)`` or
            ``(T, H, W, 1)``. Values are displayed on a fixed [0, 1] grayscale
            range (``generate`` clips histories to [0, 1] before storing them).

    Returns:
        A standalone matplotlib Figure with one axis per displayed frame. Each
        axis is titled with its 1-based step number.

    Raises:
        ValueError: If ``history_frames`` contains no frames.
    """
    frames = np.asarray(history_frames)
    if frames.ndim == 4 and frames.shape[-1] == 1:
        frames = frames[..., 0]
    total_frames = frames.shape[0]
    if total_frames < 1:
        raise ValueError("History sequence is empty.")
    n_display = min(10, total_frames)
    indices = np.linspace(0, total_frames - 1, n_display, dtype=int)

    # Build the Figure directly rather than through plt.subplots(). pyplot keeps
    # every figure it creates alive until plt.close(), so one figure per gallery
    # click would accumulate for the lifetime of the server.
    fig = plt.Figure(figsize=(2.2 * n_display, 2.2))
    axes = fig.subplots(1, n_display, squeeze=False)[0]
    for ax, frame_idx in zip(axes, indices):
        ax.axis("off")
        # Fixed vmin/vmax keep brightness comparable across frames; otherwise
        # imshow stretches every frame to its own min/max.
        ax.imshow(frames[frame_idx], cmap="gray", vmin=0.0, vmax=1.0)
        ax.set_title(f"Step {frame_idx + 1}", fontsize=8, color="#8a450f", pad=6)
    fig.tight_layout()
    return fig
def _format_summary(
*,
integrator_label: str,
n_steps: int,
history_len: int,
churn_params: Optional[dict],
) -> str:
sampler_grid = f"""
{integrator_label}
Steps {n_steps}
Samples {N_SAMPLES}
History {history_len}
""".strip()
churn_block = ""
if churn_params:
churn_block = f"""
Churning
Rate {churn_params['stochastic_churn_rate']:.3f}
Min {churn_params['churn_min']:.3f}
Max {churn_params['churn_max']:.3f}
Inflation {churn_params['noise_inflation_factor']:.4f}
""".strip()
return f"""
""".strip()
def show_history(evt: gr.SelectData, histories: Optional[List[np.ndarray]]):
    """Plot the diffusion trajectory of the gallery item the user selected.

    Args:
        evt: Gradio selection event. ``evt.index`` identifies the clicked item;
            a list/tuple index is reduced to its last element, and a missing
            event or index falls back to item 0.
        histories: Per-sample trajectories stored in ``gr.State`` by ``generate``.

    Returns:
        ``(plot update, placeholder update)``. On success the plot is shown and
        the hint hidden; otherwise the plot is hidden and the hint shown.
    """

    def _show_hint():
        hidden_plot = gr.update(value=None, visible=False)
        hint = gr.update(
            value="Click a digit above to explore its diffusion trajectory.",
            visible=True,
        )
        return hidden_plot, hint

    if histories is None or len(histories) == 0:
        return _show_hint()

    chosen = evt.index if (evt is not None and evt.index is not None) else 0
    if isinstance(chosen, (list, tuple)):
        chosen = chosen[-1]

    in_range = isinstance(chosen, (int, np.integer)) and 0 <= chosen < len(histories)
    if not in_range or histories[chosen] is None:
        return _show_hint()

    trajectory_fig = _make_history_plot(histories[chosen])
    return gr.update(value=trajectory_fig, visible=True), gr.update(visible=False)
def _to_unit_range(values) -> np.ndarray:
    """Copy ``values`` to host memory and map [-1, 1] to [0, 1], clipping the rest."""
    host = np.asarray(jax.device_get(values))
    # Diffusion models typically output data in [-1, 1]. Rescale to [0, 1].
    return np.clip(0.5 * (host + 1.0), 0.0, 1.0)


def _validate_churn_params(
    churn_rate: float,
    churn_min_value: float,
    churn_max_value: float,
    noise_inflation_value: float,
) -> dict:
    """Coerce the churn slider values to float, validate them, and build ``churn_params``.

    Raises:
        gr.Error: If any value is outside the range the UI accepts (NaN included).
    """
    rate = float(churn_rate)
    churn_min = float(churn_min_value)
    churn_max = float(churn_max_value)
    inflation = float(noise_inflation_value)
    # Chained "not (lo <= x <= hi)" checks also reject NaN, which the
    # "x < lo or x > hi" form would let through.
    if not 0.0 <= rate <= 1.0:
        raise gr.Error("Churn rate must be within [0, 1].")
    if not (churn_min >= 0 and churn_max >= 0):
        raise gr.Error("Churn thresholds must be non-negative.")
    if churn_max < churn_min:
        raise gr.Error("Churn max threshold must be greater than or equal to churn min threshold.")
    if not 1.0 <= inflation <= MAX_NOISE_INFLATION:
        raise gr.Error(f"Noise inflation factor must be within [1.0, {MAX_NOISE_INFLATION}].")
    return {
        "stochastic_churn_rate": rate,
        "churn_min": churn_min,
        "churn_max": churn_max,
        "noise_inflation_factor": inflation,
    }


def generate(
    integrator_label: str,
    n_steps: int,
    seed: int,
    enable_churn: bool,
    churn_rate: float,
    churn_min_value: float,
    churn_max_value: float,
    noise_inflation_value: float,
):
    """Run sampling with the requested configuration and return UI artifacts.

    Returns:
        A 5-tuple matching the ``generate_button.click`` outputs:
        - gallery update with the rendered samples;
        - summary HTML update;
        - update that hides/resets the trajectory plot;
        - placeholder-hint update;
        - list of per-sample trajectories for ``gr.State``.

    Raises:
        gr.Error: On out-of-range steps, a missing or non-integer seed,
            churning requested for an integrator without churn support, or
            invalid churn settings.
    """
    _, integrator_cfg = resolve_integrator(integrator_label)
    n_steps = int(n_steps)
    try:
        seed = int(seed)
    except (TypeError, ValueError, OverflowError) as err:
        # gr.Number yields None when the field is cleared; surface that as a
        # readable UI error instead of an opaque TypeError.
        raise gr.Error("Random seed must be an integer.") from err
    if not (1 <= n_steps <= MAX_STEPS):
        raise gr.Error(f"Number of steps must be between 1 and {MAX_STEPS}.")

    churn_params = None
    if enable_churn:
        if not integrator_cfg.get("supports_churn", False):
            raise gr.Error("Stochastic churning is only available for deterministic integrators.")
        churn_params = _validate_churn_params(
            churn_rate, churn_min_value, churn_max_value, noise_inflation_value
        )

    denoiser_state, history = sample_batch(
        integrator_label,
        n_steps=n_steps,
        n_samples=N_SAMPLES,
        seed=seed,
        keep_history=True,
        churn_params=churn_params,
    )

    samples = _to_unit_range(denoiser_state.integrator_state.position)
    if samples.ndim == 4 and samples.shape[-1] == 1:
        samples = samples[..., 0]
    gallery_images = _prepare_gallery_images(samples)

    sample_histories: List[np.ndarray] = []
    history_len = 0
    if history is not None:
        history_np = _to_unit_range(history)
        # History is indexed as (step, sample, ...): split it per sample along axis 1.
        sample_histories = [history_np[:, idx] for idx in range(history_np.shape[1])]
        history_len = int(history_np.shape[0])

    summary_html = _format_summary(
        integrator_label=integrator_cfg["label"],
        n_steps=n_steps,
        history_len=history_len,
        churn_params=churn_params,
    )
    gallery_update = gr.update(
        value=gallery_images,
        visible=True,
        interactive=True,
        height=220,
    )
    summary_update = gr.update(value=summary_html)
    history_reset = gr.update(value=None, visible=False)
    placeholder_update = gr.update(
        value="Click a digit above to explore its diffusion trajectory.",
        visible=True,
    )
    gr.Info(
        f"Generated {N_SAMPLES} samples with {integrator_cfg['label']} ({n_steps} steps).",
        duration=3,
    )
    return gallery_update, summary_update, history_reset, placeholder_update, sample_histories
def _handle_churn_toggle(integrator_label: str, enable_churn: bool):
    """Show and expand the churn parameter panel only if churning is requested AND supported.

    Returns:
        ``(churn_column update, churn_accordion update)``.
    """
    _, spec = resolve_integrator(integrator_label)
    churn_supported = spec.get("supports_churn", False)
    # Equivalent to `churn_supported and enable_churn`.
    show_panel = enable_churn if churn_supported else churn_supported
    return gr.update(visible=show_panel), gr.update(open=show_panel)
def _handle_integrator_change(integrator_label: str, enable_churn: bool):
    """React to a new integrator selection.

    Locks and unchecks the churn checkbox when the integrator lacks churn
    support; otherwise keeps the user's choice. Then re-syncs the churn panel.

    Returns:
        ``(checkbox update, churn_column update, churn_accordion update)``.
    """
    _, spec = resolve_integrator(integrator_label)
    churn_supported = spec.get("supports_churn", False)
    keep_enabled = False
    if churn_supported:
        keep_enabled = enable_churn
    checkbox_state = gr.update(interactive=churn_supported, value=keep_enabled)
    panel_update, accordion_state = _handle_churn_toggle(integrator_label, keep_enabled)
    return checkbox_state, panel_update, accordion_state
def _sync_churn_max(churn_min_value: float, current_max_value: float):
    """Pull the churn max slider up to churn min whenever it has fallen below it."""
    floor = float(churn_min_value)
    candidate = float(current_max_value)
    if candidate >= floor:
        return gr.update(value=candidate)
    return gr.update(value=floor)
def _sync_churn_min(churn_max_value: float, current_min_value: float):
    """Push the churn min slider down to churn max whenever it has risen above it."""
    ceiling = float(churn_max_value)
    candidate = float(current_min_value)
    if candidate <= ceiling:
        return gr.update(value=candidate)
    return gr.update(value=ceiling)
def build_ui() -> gr.Blocks:
    """Create the Gradio Blocks interface.

    Layout:
    - a hero banner;
    - a sampling-controls column (integrator, steps, seed, churning accordion,
      generate button);
    - a results column (summary HTML, digit gallery, trajectory plot).

    Events:
    - integrator and checkbox changes toggle the churn panel;
    - the churn min/max sliders keep each other ordered;
    - the button runs ``generate``;
    - selecting a gallery item renders its trajectory via ``show_history``.

    Returns:
        The configured, not-yet-launched ``gr.Blocks`` app.
    """
    available_labels = [spec["label"] for spec in INTEGRATORS.values()]
    default_label = INTEGRATORS["ddim"]["label"]
    # _handle_integrator_change only fires when the dropdown changes, so derive
    # the checkbox's initial interactivity from the default integrator here.
    # Otherwise it would be clickable on first load even without churn support.
    _, default_cfg = resolve_integrator(default_label)
    default_supports_churn = bool(default_cfg.get("supports_churn", False))

    with gr.Blocks(
        title="Diffuse Integrator Explorer",
        css=CUSTOM_CSS,
        theme=gr.themes.Soft(primary_hue="orange", secondary_hue="orange"),
    ) as demo:
        with gr.Row(elem_id="hero"):
            # Skip the logo when logo.png is missing instead of rendering an
            # empty image frame.
            if LOGO_VALUE is not None:
                gr.Image(
                    value=LOGO_VALUE,
                    show_label=False,
                    interactive=False,
                    elem_classes="hero-logo",
                )
            gr.Markdown(
                f"""
### Diffuse Integrator Explorer
Experiment with deterministic or stochastic samplers from the
[`diffuse-jax` library](https://diffuse.readthedocs.io/en/latest/index.html). Adjust the number of diffusion steps,
hit **Generate Samples**, and compare the {N_SAMPLES} digits rendered
in the panel on the right.
""".strip(),
                elem_classes="hero-copy",
            )
        with gr.Row():
            with gr.Column(elem_classes="control-card"):
                gr.Markdown("#### Sampling Controls", elem_classes="control-heading")
                integrator_input = gr.Dropdown(
                    choices=available_labels,
                    value=default_label,
                    label="Integrator",
                )
                steps_input = gr.Slider(
                    minimum=1,
                    maximum=MAX_STEPS,
                    value=DEFAULT_STEPS,
                    step=1,
                    label="Number of steps",
                )
                seed_input = gr.Number(
                    value=0,
                    precision=0,
                    label="Random seed",
                    info="Use a different seed to explore new digits.",
                )
                with gr.Accordion("Churning controls", open=False, elem_classes="accordion-card") as churn_accordion:
                    churn_checkbox = gr.Checkbox(
                        value=False,
                        label="Enable stochastic churning",
                        info="Add controlled noise for deterministic integrators.",
                        interactive=default_supports_churn,
                    )
                    with gr.Column(visible=False, elem_classes="churn-card") as churn_column:
                        gr.Markdown(
                            "**Churning parameters** · tweak how strongly noise is injected during sampling.",
                            elem_classes="churn-title",
                        )
                        churn_rate_input = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=DEFAULT_CHURN_RATE,
                            step=0.01,
                            label="Churn rate",
                        )
                        churn_min_input = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=DEFAULT_CHURN_MIN,
                            step=0.01,
                            label="Churn min threshold",
                        )
                        churn_max_input = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=DEFAULT_CHURN_MAX,
                            step=0.01,
                            label="Churn max threshold",
                        )
                        noise_inflation_input = gr.Slider(
                            minimum=1.0,
                            maximum=MAX_NOISE_INFLATION,
                            value=DEFAULT_NOISE_INFLATION,
                            step=0.001,
                            label="Noise inflation factor",
                        )
                generate_button = gr.Button(
                    "Generate Samples",
                    variant="primary",
                    elem_classes="generate-button",
                )
            with gr.Column():
                details = gr.HTML(
                    SUMMARY_PLACEHOLDER_HTML,
                    elem_classes="details-card",
                    container=False,
                )
                gr.Markdown("#### Generated Digit Strip", elem_classes="plot-title")
                digit_strip = gr.Gallery(
                    columns=5,
                    allow_preview=False,
                    show_fullscreen_button=False,
                    object_fit="contain",
                    rows=1,
                    height=220,
                    show_label=False,
                    interactive=True,
                    elem_classes="gallery-card",
                    value=[],
                    container=False,
                    visible=False,
                )
                gr.Markdown("#### Sample Trajectory", elem_classes="plot-title")
                history_plot = gr.Plot(elem_classes="history-card", show_label=False, visible=False)
                history_placeholder = gr.Markdown(
                    "Generate samples, then click a digit above to explore its diffusion trajectory.",
                    elem_classes="history-placeholder",
                    visible=True,
                    container=False,
                )
        # Per-sample trajectories from the latest run, consumed by show_history.
        histories_state = gr.State([])

        integrator_input.change(
            fn=_handle_integrator_change,
            inputs=[integrator_input, churn_checkbox],
            outputs=[churn_checkbox, churn_column, churn_accordion],
        )
        churn_checkbox.change(
            fn=_handle_churn_toggle,
            inputs=[integrator_input, churn_checkbox],
            outputs=[churn_column, churn_accordion],
        )
        # Keep churn_min <= churn_max whichever slider moves.
        churn_min_input.change(
            fn=_sync_churn_max,
            inputs=[churn_min_input, churn_max_input],
            outputs=churn_max_input,
        )
        churn_max_input.change(
            fn=_sync_churn_min,
            inputs=[churn_max_input, churn_min_input],
            outputs=churn_min_input,
        )
        generate_button.click(
            fn=generate,
            inputs=[
                integrator_input,
                steps_input,
                seed_input,
                churn_checkbox,
                churn_rate_input,
                churn_min_input,
                churn_max_input,
                noise_inflation_input,
            ],
            outputs=[digit_strip, details, history_plot, history_placeholder, histories_state],
        )
        digit_strip.select(
            fn=show_history,
            inputs=[histories_state],
            outputs=[history_plot, history_placeholder],
        )
    return demo
# Runs on import (not only under __main__), before any UI is built. What it
# loads or caches is defined in pipeline.load_pipeline_assets (presumably
# model weights/assets; confirm there).
load_pipeline_assets()
if __name__ == "__main__":
    # Build the Blocks app and launch it with Gradio's request queue enabled.
    demo = build_ui()
    demo.queue().launch()