"""Gradio app for exploring the diffusion integrators exposed by ``pipeline``.

The user picks an integrator, the number of sampling steps, a random seed and
(optionally) stochastic-churn settings, generates ``N_SAMPLES`` digits, and can
click any generated digit to plot frames from its sampling trajectory.
"""

import html
import os
from pathlib import Path
from typing import List, Optional

# Must run before JAX is imported so XLA sees it: disables XLA's up-front GPU
# memory preallocation. ``setdefault`` keeps any value already exported.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

import jax  # noqa: E402  (must follow the XLA env var above)
import numpy as np  # noqa: E402
import gradio as gr  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402

from pipeline import (  # noqa: E402
    INTEGRATORS,
    load_pipeline_assets,
    resolve_integrator,
    sample_batch,
)

# Number of digits generated per click, and bounds for the step slider.
N_SAMPLES = 5
MAX_STEPS = 80
DEFAULT_STEPS = 20

ROOT_DIR = Path(__file__).parent
LOGO_PATH = ROOT_DIR / "logo.png"
# None when logo.png is missing; the hero image component is then hidden.
LOGO_VALUE = str(LOGO_PATH) if LOGO_PATH.exists() else None

# Initial values / limits for the churn sliders.
DEFAULT_CHURN_RATE = 0.0
DEFAULT_CHURN_MIN = 0.0
DEFAULT_CHURN_MAX = 0.0
DEFAULT_NOISE_INFLATION = 1.0
MAX_NOISE_INFLATION = 1.02

# Hint shown in place of the trajectory plot when no digit is selected.
HISTORY_HINT = "Click a digit above to explore its diffusion trajectory."

HERO_MARKDOWN = (
    "### Diffuse Integrator Explorer\n\n"
    "Experiment with deterministic or stochastic samplers from the "
    "[`diffuse-jax` library](https://diffuse.readthedocs.io/en/latest/index.html). "
    "Adjust the number of diffusion steps, hit **Generate Samples**, and compare "
    "the five digits rendered in the panel on the right."
)

# Empty-state card; styled by the ``.summary-card.is-empty`` rules below.
SUMMARY_PLACEHOLDER_HTML = (
    '<div class="summary-card is-empty">'
    '<div class="summary-title">Ready to sample</div>'
    "<div>Select an integrator, adjust the controls, then generate digits "
    "to inspect their trajectories.</div>"
    "</div>"
)

CUSTOM_CSS = """
body {background: radial-gradient(circle at top left, #ffe8d5, #fff7f0 55%, #fdf1f8);}
#hero {
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 1.5rem;
    background: rgba(255, 255, 255, 0.85);
    padding: 1.5rem 2rem;
    border-radius: 18px;
    box-shadow: 0 18px 35px rgba(255, 135, 0, 0.15);
    border: 1px solid rgba(255, 145, 0, 0.35);
}
.hero-logo img {max-width: 320px; width: 100%; object-fit: contain;}
.hero-copy {font-size: 1.05rem !important; color: #7a3b09;}
.control-card {
    background: rgba(255, 255, 255, 0.92);
    border-radius: 16px;
    padding: 1.25rem;
    border: 1px solid rgba(255, 166, 77, 0.35);
    box-shadow: 0 14px 30px rgba(255, 140, 0, 0.12);
}
.generate-button button {
    background: linear-gradient(135deg, #ff7e00, #ffb347);
    color: #fff;
    font-weight: 600;
    border-radius: 12px;
    box-shadow: 0 10px 20px rgba(255, 126, 0, 0.25);
}
.generate-button button:hover {filter: brightness(1.05);}
.control-heading {
    font-weight: 600;
    color: #7a3b09;
    margin-bottom: 0.6rem !important;
}
.plot-card {
    background: rgba(255, 255, 255, 0.88);
    border-radius: 16px;
    padding: 1rem;
    border: 1px solid rgba(255, 166, 77, 0.35);
    box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.4), 0 12px 28px rgba(255, 145, 0, 0.18);
}
.details-card {
    border: none;
    padding: 0;
}
.summary-card {
    background: rgba(255, 255, 255, 0.9);
    border-radius: 14px;
    padding: 1.1rem 1.25rem;
    border: 1px solid rgba(255, 166, 77, 0.35);
    box-shadow: 0 12px 26px rgba(255, 145, 0, 0.16);
    display: grid;
    gap: 0.85rem;
}
.summary-card.is-empty {
    border-style: dashed;
    box-shadow: none;
}
.summary-title {
    font-weight: 600;
    font-size: 1.05rem;
    color: #7a3b09;
}
.summary-section {
    display: grid;
    gap: 0.45rem;
}
.summary-grid {
    display: grid;
    grid-template-columns: repeat(auto-fit, minmax(120px, 1fr));
    gap: 0.4rem;
}
.summary-pill {
    background: rgba(255, 245, 233, 0.95);
    border: 1px solid rgba(255, 166, 77, 0.45);
    border-radius: 999px;
    padding: 0.35rem 0.75rem;
    font-size: 0.85rem;
    display: inline-flex;
    align-items: center;
    gap: 0.35rem;
    color: #7a3b09;
    justify-content: center;
}
.summary-pill strong {font-weight: 600;}
.summary-pill.integrator {
    background: rgba(255, 231, 206, 0.95);
    border-color: rgba(255, 160, 72, 0.65);
    font-weight: 600;
}
.summary-divider {
    border: none;
    border-top: 1px dashed rgba(255, 166, 77, 0.4);
    margin: 0.2rem 0;
}
.accordion-card {
    --tw-border-opacity: 0.45;
    border: 1px dashed rgba(255, 166, 77, 0.45) !important;
    border-radius: 14px !important;
    background: rgba(255, 255, 255, 0.88) !important;
}
.accordion-card > div:nth-child(1) {
    font-weight: 600;
    color: #7a3b09;
}
.churn-card {
    margin-top: 0.75rem;
    background: rgba(255, 255, 255, 0.85);
    border-radius: 14px;
    padding: 0.9rem 1rem 1.1rem;
    border: 1px dashed rgba(255, 166, 77, 0.5);
    box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.55);
}
.churn-title {
    font-size: 0.92rem !important;
    color: #8a450f;
    margin-bottom: 0.55rem !important;
}
.gallery-card {
    background: rgba(255, 255, 255, 0.9);
    border-radius: 16px;
    padding: 0.3rem 0.4rem;
    border: 1px solid rgba(255, 166, 77, 0.28);
    box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.25), 0 8px 18px rgba(255, 145, 0, 0.12);
}
.gallery-card [data-testid="upload-zone"] {
    display: none !important;
}
.gallery-card .grid {
    min-height: 180px;
}
.gallery-card img {
    border-radius: 10px;
    transition: transform 0.15s ease, box-shadow 0.15s ease;
}
.gallery-card img:hover {
    transform: translateY(-2px);
    box-shadow: 0 8px 14px rgba(255, 145, 0, 0.18);
}
.history-card {
    background: rgba(255, 255, 255, 0.88);
    border-radius: 16px;
    padding: 0.9rem;
    border: 1px solid rgba(255, 166, 77, 0.35);
    box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.35), 0 10px 22px rgba(255, 145, 0, 0.15);
}
.plot-title {
    color: #7a3b09 !important;
    text-align: center;
    font-weight: 600 !important;
    margin-bottom: 0.45rem !important;
}
.history-placeholder {
    text-align: center;
    color: #8a450f;
    font-size: 0.9rem;
    margin-top: 0.5rem;
}
.value-chip {
    background: rgba(255, 231, 206, 0.9);
    border-radius: 999px;
    padding: 0.1rem 0.55rem;
    font-size: 0.82rem;
    margin-left: 0.4rem;
    color: #84400e;
}
@media (max-width: 768px) {
    .summary-grid {grid-template-columns: repeat(auto-fit, minmax(110px, 1fr));}
    .gallery-card .grid {min-height: 150px;}
    .plot-card, .history-card {padding: 0.7rem;}
}
"""


def _prepare_gallery_images(samples: np.ndarray) -> List[np.ndarray]:
    """Convert normalized ``[0, 1]`` samples into uint8 arrays for the gallery.

    Args:
        samples: Batch shaped ``(N, H, W)`` or ``(N, H, W, C)``. Values are
            clipped to ``[0, 1]`` before quantization.

    Returns:
        One uint8 ``(H, W, C')`` array per sample. Single-channel (grayscale)
        images are replicated to three channels; images that already have
        several channels keep them unchanged.
    """
    clipped = np.clip(samples, 0.0, 1.0)
    # Round instead of truncating so e.g. 0.999 maps to 255 rather than 254.
    uint8_imgs = np.rint(clipped * 255).astype(np.uint8)
    if uint8_imgs.ndim == 3:
        uint8_imgs = uint8_imgs[..., np.newaxis]
    # Only grayscale needs expanding; repeating an RGB image would yield 9 channels.
    if uint8_imgs.shape[-1] == 1:
        uint8_imgs = np.repeat(uint8_imgs, 3, axis=-1)
    return list(uint8_imgs)


def _make_history_plot(history_frames: np.ndarray, max_frames: int = 10) -> plt.Figure:
    """Render up to ``max_frames`` evenly spaced frames of one trajectory in a row.

    Args:
        history_frames: Frames for a single sample, shaped ``(T, H, W)`` or
            ``(T, H, W, 1)``, with values expected in ``[0, 1]``.
        max_frames: Maximum number of frames to draw (defaults to 10).

    Returns:
        A matplotlib Figure, already detached from pyplot's figure registry.

    Raises:
        ValueError: If ``history_frames`` contains no frames.
    """
    if history_frames.ndim == 4 and history_frames.shape[-1] == 1:
        history_frames = history_frames[..., 0]

    total_frames = history_frames.shape[0]
    n_display = min(max_frames, total_frames)
    if n_display < 1:
        raise ValueError("History sequence is empty.")

    # Evenly spaced indices that always include the first and last frame.
    indices = np.linspace(0, total_frames - 1, n_display, dtype=int)

    # squeeze=False always yields a 2-D axes array, even for a single frame.
    fig, axes = plt.subplots(1, n_display, figsize=(2.2 * n_display, 2.2), squeeze=False)
    for ax, frame_idx in zip(axes[0], indices):
        ax.axis("off")
        # Fixed [0, 1] range so every frame shares the same intensity scale
        # instead of being auto-contrasted individually.
        ax.imshow(history_frames[frame_idx], cmap="gray", vmin=0.0, vmax=1.0)
        ax.set_title(f"Step {frame_idx + 1}", fontsize=8, color="#8a450f", pad=6)
    fig.tight_layout()
    # Release pyplot's global reference so figures don't accumulate across
    # clicks in the long-running server; the Figure object can still be saved.
    plt.close(fig)
    return fig


def _summary_pill(name: str, value: object) -> str:
    """Return one ``.summary-pill`` span showing ``name`` in bold followed by ``value``."""
    return f'<span class="summary-pill"><strong>{name}</strong> {value}</span>'


def _summary_section(title: str, body_html: str) -> str:
    """Wrap ``body_html`` in a ``.summary-section`` headed by ``title``."""
    return (
        '<div class="summary-section">'
        f'<div class="summary-title">{title}</div>'
        f"{body_html}"
        "</div>"
    )


def _format_summary(
    *,
    integrator_label: str,
    n_steps: int,
    history_len: int,
    churn_params: Optional[dict],
) -> str:
    """Build the HTML summary card shown after a generation run.

    Args:
        integrator_label: Display name of the integrator (HTML-escaped here).
        n_steps: Number of sampling steps used.
        history_len: Number of stored trajectory frames (0 when none).
        churn_params: Churn settings passed to ``sample_batch``, or ``None``
            when churning was disabled; when given, a "Churning" section is added.

    Returns:
        An HTML string styled by the ``summary-*`` rules in ``CUSTOM_CSS``.
    """
    sampler_grid = (
        '<div class="summary-grid">'
        f'<span class="summary-pill integrator">{html.escape(str(integrator_label))}</span>'
        + _summary_pill("Steps", n_steps)
        + _summary_pill("Samples", N_SAMPLES)
        + _summary_pill("History", history_len)
        + "</div>"
    )
    sections = [_summary_section("Sampler", sampler_grid)]

    if churn_params:
        churn_grid = (
            '<div class="summary-grid">'
            + _summary_pill("Rate", f"{churn_params['stochastic_churn_rate']:.3f}")
            + _summary_pill("Min", f"{churn_params['churn_min']:.3f}")
            + _summary_pill("Max", f"{churn_params['churn_max']:.3f}")
            + _summary_pill("Inflation", f"{churn_params['noise_inflation_factor']:.4f}")
            + "</div>"
        )
        sections.append('<hr class="summary-divider">')
        sections.append(_summary_section("Churning", churn_grid))

    return '<div class="summary-card">' + "".join(sections) + "</div>"


def _history_placeholder_updates():
    """Return (plot, placeholder) updates that hide the plot and show the hint text."""
    return gr.update(value=None, visible=False), gr.update(value=HISTORY_HINT, visible=True)


def show_history(evt: gr.SelectData, histories: Optional[List[np.ndarray]]):
    """Render the trajectory plot for the gallery item the user selected.

    Args:
        evt: Gradio selection event; ``evt.index`` identifies the clicked item
            (when it is a list/tuple, its last element is used).
        histories: Per-sample trajectories stored by ``generate``.

    Returns:
        A (plot update, placeholder update) pair. Falls back to the hint text
        when there is no history or the index is invalid.
    """
    if histories is None or len(histories) == 0:
        return _history_placeholder_updates()

    index = 0
    if evt is not None and evt.index is not None:
        index = evt.index
    if isinstance(index, (list, tuple)):
        index = index[-1]

    valid_index = isinstance(index, (int, np.integer)) and 0 <= index < len(histories)
    if not valid_index or histories[index] is None:
        return _history_placeholder_updates()

    figure = _make_history_plot(histories[index])
    return gr.update(value=figure, visible=True), gr.update(visible=False)


def _validate_churn_params(
    churn_rate: float,
    churn_min_value: float,
    churn_max_value: float,
    noise_inflation_value: float,
) -> dict:
    """Validate the churn slider values and pack them into ``sample_batch``'s dict.

    Raises:
        gr.Error: If any value is outside its allowed range (NaN is rejected too).
    """
    churn_rate = float(churn_rate)
    churn_min_value = float(churn_min_value)
    churn_max_value = float(churn_max_value)
    noise_inflation_value = float(noise_inflation_value)

    if not 0.0 <= churn_rate <= 1.0:
        raise gr.Error("Churn rate must be within [0, 1].")
    if churn_min_value < 0 or churn_max_value < 0:
        raise gr.Error("Churn thresholds must be non-negative.")
    if churn_max_value < churn_min_value:
        raise gr.Error("Churn max threshold must be greater than or equal to churn min threshold.")
    if not 1.0 <= noise_inflation_value <= MAX_NOISE_INFLATION:
        raise gr.Error(f"Noise inflation factor must be within [1.0, {MAX_NOISE_INFLATION}].")

    return {
        "stochastic_churn_rate": churn_rate,
        "churn_min": churn_min_value,
        "churn_max": churn_max_value,
        "noise_inflation_factor": noise_inflation_value,
    }


def _to_unit_interval(values) -> np.ndarray:
    """Fetch ``values`` to host memory and map them from ``[-1, 1]`` to ``[0, 1]`` (clipped).

    NOTE(review): assumes the model outputs data in [-1, 1], as the original
    rescaling did — confirm against the pipeline.
    """
    array = np.asarray(jax.device_get(values))
    return np.clip(0.5 * (array + 1.0), 0.0, 1.0)


def generate(
    integrator_label: str,
    n_steps: int,
    seed: int,
    enable_churn: bool,
    churn_rate: float,
    churn_min_value: float,
    churn_max_value: float,
    noise_inflation_value: float,
):
    """Run sampling with the requested configuration and return UI artifacts.

    Returns:
        A 5-tuple matching the ``generate_button.click`` outputs: gallery
        update, summary HTML update, trajectory-plot reset, placeholder update,
        and the list of per-sample trajectories stored in ``histories_state``.

    Raises:
        gr.Error: If the step count is out of range, the seed is empty, or the
            churn settings are unsupported/invalid.
    """
    _, integrator_cfg = resolve_integrator(integrator_label)

    if seed is None:
        # An empty gr.Number field can arrive as None; int(None) would crash.
        raise gr.Error("Please enter a random seed.")
    n_steps = int(n_steps)
    seed = int(seed)
    if not 1 <= n_steps <= MAX_STEPS:
        raise gr.Error(f"Number of steps must be between 1 and {MAX_STEPS}.")

    churn_params = None
    if enable_churn:
        if not integrator_cfg.get("supports_churn", False):
            raise gr.Error("Stochastic churning is only available for deterministic integrators.")
        churn_params = _validate_churn_params(
            churn_rate, churn_min_value, churn_max_value, noise_inflation_value
        )

    denoiser_state, history = sample_batch(
        integrator_label,
        n_steps=n_steps,
        n_samples=N_SAMPLES,
        seed=seed,
        keep_history=True,
        churn_params=churn_params,
    )

    samples = _to_unit_interval(denoiser_state.integrator_state.position)
    gallery_images = _prepare_gallery_images(samples)

    sample_histories: List[np.ndarray] = []
    history_len = 0
    if history is not None:
        history_np = _to_unit_interval(history)
        # Axis 0 is treated as the step axis and axis 1 as the sample axis.
        sample_histories = [history_np[:, i] for i in range(history_np.shape[1])]
        history_len = int(history_np.shape[0])

    summary_html = _format_summary(
        integrator_label=integrator_cfg["label"],
        n_steps=n_steps,
        history_len=history_len,
        churn_params=churn_params,
    )

    gallery_update = gr.update(
        value=gallery_images,
        visible=True,
        interactive=True,
        height=220,
    )
    summary_update = gr.update(value=summary_html)
    history_reset, placeholder_update = _history_placeholder_updates()

    gr.Info(
        f"Generated {N_SAMPLES} samples with {integrator_cfg['label']} ({n_steps} steps).",
        duration=3,
    )
    return gallery_update, summary_update, history_reset, placeholder_update, sample_histories


def _supports_churn(integrator_label: str) -> bool:
    """Return whether the integrator's config sets ``supports_churn``."""
    _, integrator_cfg = resolve_integrator(integrator_label)
    return bool(integrator_cfg.get("supports_churn", False))


def _handle_churn_toggle(integrator_label: str, enable_churn: bool):
    """Show/open the churn controls only when churn is enabled and supported.

    Returns:
        (column visibility update, accordion open-state update).
    """
    show = _supports_churn(integrator_label) and bool(enable_churn)
    return gr.update(visible=show), gr.update(open=show)


def _handle_integrator_change(integrator_label: str, enable_churn: bool):
    """Sync the churn checkbox and panel when the selected integrator changes.

    The checkbox is disabled (and unchecked) for integrators without churn
    support.

    Returns:
        (checkbox update, column visibility update, accordion open-state update).
    """
    supports = _supports_churn(integrator_label)
    effective_enable = bool(enable_churn) and supports
    checkbox_update = gr.update(interactive=supports, value=effective_enable)
    return checkbox_update, gr.update(visible=effective_enable), gr.update(open=effective_enable)


def _sync_churn_max(churn_min_value: float, current_max_value: float):
    """Raise churn_max to churn_min when churn_min moves above it.

    Returns a no-op update otherwise, so a stale value read at request time
    never overwrites a newer slider position.
    """
    churn_min_value = float(churn_min_value)
    if float(current_max_value) >= churn_min_value:
        return gr.update()
    return gr.update(value=churn_min_value)


def _sync_churn_min(churn_max_value: float, current_min_value: float):
    """Lower churn_min to churn_max when churn_max moves below it.

    Returns a no-op update otherwise, so a stale value read at request time
    never overwrites a newer slider position.
    """
    churn_max_value = float(churn_max_value)
    if float(current_min_value) <= churn_max_value:
        return gr.update()
    return gr.update(value=churn_max_value)


def build_ui() -> gr.Blocks:
    """Create the Gradio Blocks interface and wire up its event handlers."""
    available_labels = [spec["label"] for spec in INTEGRATORS.values()]
    default_label = INTEGRATORS["ddim"]["label"]

    with gr.Blocks(
        title="Diffuse Integrator Explorer",
        css=CUSTOM_CSS,
        theme=gr.themes.Soft(primary_hue="orange", secondary_hue="orange"),
    ) as demo:
        with gr.Row(elem_id="hero"):
            gr.Image(
                value=LOGO_VALUE,
                show_label=False,
                interactive=False,
                # Hide the component instead of showing an empty frame when
                # logo.png is missing.
                visible=LOGO_VALUE is not None,
                elem_classes="hero-logo",
            )
            gr.Markdown(HERO_MARKDOWN, elem_classes="hero-copy")

        with gr.Row():
            with gr.Column(elem_classes="control-card"):
                gr.Markdown("#### Sampling Controls", elem_classes="control-heading")
                integrator_input = gr.Dropdown(
                    choices=available_labels,
                    value=default_label,
                    label="Integrator",
                )
                steps_input = gr.Slider(
                    minimum=1,
                    maximum=MAX_STEPS,
                    value=DEFAULT_STEPS,
                    step=1,
                    label="Number of steps",
                )
                seed_input = gr.Number(
                    value=0,
                    precision=0,
                    label="Random seed",
                    info="Use a different seed to explore new digits.",
                )
                with gr.Accordion(
                    "Churning controls", open=False, elem_classes="accordion-card"
                ) as churn_accordion:
                    churn_checkbox = gr.Checkbox(
                        value=False,
                        label="Enable stochastic churning",
                        info="Add controlled noise for deterministic integrators.",
                    )
                    with gr.Column(visible=False, elem_classes="churn-card") as churn_column:
                        gr.Markdown(
                            "**Churning parameters** · tweak how strongly noise is "
                            "injected during sampling.",
                            elem_classes="churn-title",
                        )
                        churn_rate_input = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=DEFAULT_CHURN_RATE,
                            step=0.01,
                            label="Churn rate",
                        )
                        churn_min_input = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=DEFAULT_CHURN_MIN,
                            step=0.01,
                            label="Churn min threshold",
                        )
                        churn_max_input = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=DEFAULT_CHURN_MAX,
                            step=0.01,
                            label="Churn max threshold",
                        )
                        noise_inflation_input = gr.Slider(
                            minimum=1.0,
                            maximum=MAX_NOISE_INFLATION,
                            value=DEFAULT_NOISE_INFLATION,
                            step=0.001,
                            label="Noise inflation factor",
                        )
                generate_button = gr.Button(
                    "Generate Samples",
                    variant="primary",
                    elem_classes="generate-button",
                )

            with gr.Column():
                details = gr.HTML(
                    SUMMARY_PLACEHOLDER_HTML,
                    elem_classes="details-card",
                    container=False,
                )
                gr.Markdown("#### Generated Digit Strip", elem_classes="plot-title")
                digit_strip = gr.Gallery(
                    columns=5,
                    allow_preview=False,
                    show_fullscreen_button=False,
                    object_fit="contain",
                    rows=1,
                    height=220,
                    show_label=False,
                    interactive=True,
                    elem_classes="gallery-card",
                    value=[],
                    container=False,
                    visible=False,
                )
                gr.Markdown("#### Sample Trajectory", elem_classes="plot-title")
                history_plot = gr.Plot(elem_classes="history-card", show_label=False, visible=False)
                history_placeholder = gr.Markdown(
                    "Generate samples, then click a digit above to explore its diffusion trajectory.",
                    elem_classes="history-placeholder",
                    visible=True,
                    container=False,
                )

        # Per-session list of trajectories, one entry per generated sample.
        histories_state = gr.State([])

        integrator_input.change(
            fn=_handle_integrator_change,
            inputs=[integrator_input, churn_checkbox],
            outputs=[churn_checkbox, churn_column, churn_accordion],
        )
        churn_checkbox.change(
            fn=_handle_churn_toggle,
            inputs=[integrator_input, churn_checkbox],
            outputs=[churn_column, churn_accordion],
        )
        churn_min_input.change(
            fn=_sync_churn_max,
            inputs=[churn_min_input, churn_max_input],
            outputs=churn_max_input,
        )
        churn_max_input.change(
            fn=_sync_churn_min,
            inputs=[churn_max_input, churn_min_input],
            outputs=churn_min_input,
        )
        generate_button.click(
            fn=generate,
            inputs=[
                integrator_input,
                steps_input,
                seed_input,
                churn_checkbox,
                churn_rate_input,
                churn_min_input,
                churn_max_input,
                noise_inflation_input,
            ],
            outputs=[digit_strip, details, history_plot, history_placeholder, histories_state],
        )
        digit_strip.select(
            fn=show_history,
            inputs=[histories_state],
            outputs=[history_plot, history_placeholder],
        )

    return demo


# Runs at import time (kept from the original module layout).
load_pipeline_assets()

if __name__ == "__main__":
    demo = build_ui()
    demo.queue().launch()