Spaces:
Sleeping
Sleeping
| import os | |
| from pathlib import Path | |
| from typing import List, Optional | |
| os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") | |
| import jax | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| from pipeline import INTEGRATORS, load_pipeline_assets, resolve_integrator, sample_batch | |
# Sampling limits shared by the UI sliders and the server-side validation in generate().
N_SAMPLES = 5
MAX_STEPS = 80
DEFAULT_STEPS = 20

# Optional hero logo that lives next to this file; LOGO_VALUE is None when it is absent.
ROOT_DIR = Path(__file__).parent
LOGO_PATH = ROOT_DIR / "logo.png"
if LOGO_PATH.exists():
    LOGO_VALUE = str(LOGO_PATH)
else:
    LOGO_VALUE = None

# Initial values / bounds for the stochastic-churning sliders.
DEFAULT_CHURN_RATE = 0.0
DEFAULT_CHURN_MIN = 0.0
DEFAULT_CHURN_MAX = 0.0
DEFAULT_NOISE_INFLATION = 1.0
MAX_NOISE_INFLATION = 1.02

# Summary card shown before the first generation (styled by .summary-card.is-empty).
SUMMARY_PLACEHOLDER_HTML = "\n".join(
    (
        '<div class="summary-card is-empty">',
        '<div class="summary-title">Ready to sample</div>',
        "<p>Select an integrator, adjust the controls, then generate digits to inspect their trajectories.</p>",
        "</div>",
    )
)
# Stylesheet passed to gr.Blocks(css=...) in build_ui(). The selectors target the
# elem_id / elem_classes values assigned to components there (e.g. "hero",
# "control-card", "gallery-card", "history-card") and the class names emitted by
# _format_summary() / SUMMARY_PLACEHOLDER_HTML ("summary-card", "summary-pill", ...).
# Keep both sides in sync when renaming a class. The [data-testid="upload-zone"]
# rule presumably hides the upload area Gradio renders for the interactive
# gallery — confirm against the installed Gradio version's DOM.
CUSTOM_CSS = """
body {background: radial-gradient(circle at top left, #ffe8d5, #fff7f0 55%, #fdf1f8);}
#hero {
display: flex;
align-items: center;
justify-content: center;
gap: 1.5rem;
background: rgba(255, 255, 255, 0.85);
padding: 1.5rem 2rem;
border-radius: 18px;
box-shadow: 0 18px 35px rgba(255, 135, 0, 0.15);
border: 1px solid rgba(255, 145, 0, 0.35);
}
.hero-logo img {max-width: 320px; width: 100%; object-fit: contain;}
.hero-copy {font-size: 1.05rem !important; color: #7a3b09;}
.control-card {
background: rgba(255, 255, 255, 0.92);
border-radius: 16px;
padding: 1.25rem;
border: 1px solid rgba(255, 166, 77, 0.35);
box-shadow: 0 14px 30px rgba(255, 140, 0, 0.12);
}
.generate-button button {
background: linear-gradient(135deg, #ff7e00, #ffb347);
color: #fff;
font-weight: 600;
border-radius: 12px;
box-shadow: 0 10px 20px rgba(255, 126, 0, 0.25);
}
.generate-button button:hover {filter: brightness(1.05);}
.control-heading {
font-weight: 600;
color: #7a3b09;
margin-bottom: 0.6rem !important;
}
.plot-card {
background: rgba(255, 255, 255, 0.88);
border-radius: 16px;
padding: 1rem;
border: 1px solid rgba(255, 166, 77, 0.35);
box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.4), 0 12px 28px rgba(255, 145, 0, 0.18);
}
.details-card {
border: none;
padding: 0;
}
.summary-card {
background: rgba(255, 255, 255, 0.9);
border-radius: 14px;
padding: 1.1rem 1.25rem;
border: 1px solid rgba(255, 166, 77, 0.35);
box-shadow: 0 12px 26px rgba(255, 145, 0, 0.16);
display: grid;
gap: 0.85rem;
}
.summary-card.is-empty {
border-style: dashed;
box-shadow: none;
}
.summary-title {
font-weight: 600;
font-size: 1.05rem;
color: #7a3b09;
}
.summary-section {
display: grid;
gap: 0.45rem;
}
.summary-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(120px, 1fr));
gap: 0.4rem;
}
.summary-pill {
background: rgba(255, 245, 233, 0.95);
border: 1px solid rgba(255, 166, 77, 0.45);
border-radius: 999px;
padding: 0.35rem 0.75rem;
font-size: 0.85rem;
display: inline-flex;
align-items: center;
gap: 0.35rem;
color: #7a3b09;
justify-content: center;
}
.summary-pill strong {font-weight: 600;}
.summary-pill.integrator {
background: rgba(255, 231, 206, 0.95);
border-color: rgba(255, 160, 72, 0.65);
font-weight: 600;
}
.summary-divider {
border: none;
border-top: 1px dashed rgba(255, 166, 77, 0.4);
margin: 0.2rem 0;
}
.accordion-card {
--tw-border-opacity: 0.45;
border: 1px dashed rgba(255, 166, 77, 0.45) !important;
border-radius: 14px !important;
background: rgba(255, 255, 255, 0.88) !important;
}
.accordion-card > div:nth-child(1) {
font-weight: 600;
color: #7a3b09;
}
.churn-card {
margin-top: 0.75rem;
background: rgba(255, 255, 255, 0.85);
border-radius: 14px;
padding: 0.9rem 1rem 1.1rem;
border: 1px dashed rgba(255, 166, 77, 0.5);
box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.55);
}
.churn-title {
font-size: 0.92rem !important;
color: #8a450f;
margin-bottom: 0.55rem !important;
}
.gallery-card {
background: rgba(255, 255, 255, 0.9);
border-radius: 16px;
padding: 0.3rem 0.4rem;
border: 1px solid rgba(255, 166, 77, 0.28);
box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.25), 0 8px 18px rgba(255, 145, 0, 0.12);
}
.gallery-card [data-testid="upload-zone"] {
display: none !important;
}
.gallery-card .grid {
min-height: 180px;
}
.gallery-card img {
border-radius: 10px;
transition: transform 0.15s ease, box-shadow 0.15s ease;
}
.gallery-card img:hover {
transform: translateY(-2px);
box-shadow: 0 8px 14px rgba(255, 145, 0, 0.18);
}
.history-card {
background: rgba(255, 255, 255, 0.88);
border-radius: 16px;
padding: 0.9rem;
border: 1px solid rgba(255, 166, 77, 0.35);
box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.35), 0 10px 22px rgba(255, 145, 0, 0.15);
}
.plot-title {
color: #7a3b09 !important;
text-align: center;
font-weight: 600 !important;
margin-bottom: 0.45rem !important;
}
.history-placeholder {
text-align: center;
color: #8a450f;
font-size: 0.9rem;
margin-top: 0.5rem;
}
.value-chip {
background: rgba(255, 231, 206, 0.9);
border-radius: 999px;
padding: 0.1rem 0.55rem;
font-size: 0.82rem;
margin-left: 0.4rem;
color: #84400e;
}
@media (max-width: 768px) {
.summary-grid {grid-template-columns: repeat(auto-fit, minmax(110px, 1fr));}
.gallery-card .grid {min-height: 150px;}
.plot-card, .history-card {padding: 0.7rem;}
}
"""
| def _prepare_gallery_images(samples: np.ndarray) -> List[np.ndarray]: | |
| """Convert normalized grayscale samples to RGB arrays for display.""" | |
| clipped = np.clip(samples, 0.0, 1.0) | |
| uint8_imgs = (clipped * 255).astype(np.uint8) | |
| if uint8_imgs.ndim == 3: | |
| uint8_imgs = uint8_imgs[..., np.newaxis] | |
| return [np.repeat(img, 3, axis=-1) for img in uint8_imgs] | |
def _make_history_plot(history_frames: np.ndarray, max_frames: int = 10) -> plt.Figure:
    """Render up to ``max_frames`` evenly spaced frames of one trajectory in a single row.

    Args:
        history_frames: frames of a single sample indexed along axis 0; each
            frame is expected to be a 2-D grayscale image. A trailing singleton
            channel axis on 4-D input is dropped.
        max_frames: maximum number of panels to draw (default 10).

    Returns:
        A matplotlib Figure with one axis-free panel per displayed frame, each
        titled with its 1-based step number.

    Raises:
        ValueError: if the history is empty or ``max_frames`` is below 1.
    """
    frames = np.asarray(history_frames)
    if frames.ndim == 4 and frames.shape[-1] == 1:
        frames = frames[..., 0]
    if max_frames < 1:
        raise ValueError("max_frames must be at least 1.")
    total_frames = frames.shape[0]
    n_display = min(max_frames, total_frames)
    if n_display < 1:
        raise ValueError("History sequence is empty.")
    # Evenly spaced (integer-truncated) indices from the first to the last frame.
    indices = np.linspace(0, total_frames - 1, n_display, dtype=int)
    selected = frames[indices]
    # Construct the Figure directly instead of via plt.subplots(): pyplot keeps
    # every figure it creates alive in a global registry until plt.close(), so a
    # long-running server would accumulate one figure per gallery click. A bare
    # Figure is released once the caller drops its reference.
    fig = plt.Figure(figsize=(2.2 * n_display, 2.2))
    # squeeze=False always yields a 2-D array, so one panel needs no special case.
    axes = fig.subplots(1, n_display, squeeze=False)[0]
    for ax, step_idx, frame in zip(axes, indices, selected):
        ax.axis("off")
        ax.imshow(frame, cmap="gray")
        ax.set_title(f"Step {step_idx + 1}", fontsize=8, color="#8a450f", pad=6)
    fig.tight_layout()
    return fig
def _format_summary(
    *,
    integrator_label: str,
    n_steps: int,
    history_len: int,
    churn_params: Optional[dict],
) -> str:
    """Build the HTML summary card shown above the gallery after a generation.

    The card always contains a "Sampler" section (integrator label, step count,
    N_SAMPLES and history length). When ``churn_params`` is truthy, a divider and
    a "Churning" section follow, showing rate/min/max with three decimals and
    the noise inflation factor with four.
    """
    sampler_lines = [
        '<div class="summary-grid">',
        f'<span class="summary-pill integrator">{integrator_label}</span>',
        f'<span class="summary-pill">Steps <strong>{n_steps}</strong></span>',
        f'<span class="summary-pill">Samples <strong>{N_SAMPLES}</strong></span>',
        f'<span class="summary-pill">History <strong>{history_len}</strong></span>',
        "</div>",
    ]

    churn_lines = []
    if churn_params:
        rate = churn_params["stochastic_churn_rate"]
        lower = churn_params["churn_min"]
        upper = churn_params["churn_max"]
        inflation = churn_params["noise_inflation_factor"]
        churn_lines = [
            '<hr class="summary-divider" />',
            '<div class="summary-section">',
            '<div class="summary-title">Churning</div>',
            '<div class="summary-grid">',
            f'<span class="summary-pill">Rate <strong>{rate:.3f}</strong></span>',
            f'<span class="summary-pill">Min <strong>{lower:.3f}</strong></span>',
            f'<span class="summary-pill">Max <strong>{upper:.3f}</strong></span>',
            f'<span class="summary-pill">Inflation <strong>{inflation:.4f}</strong></span>',
            "</div>",
            "</div>",
        ]

    # An empty churn section still contributes its own (blank) line to the card.
    card_lines = [
        '<div class="summary-card">',
        '<div class="summary-section">',
        '<div class="summary-title">Sampler</div>',
        "\n".join(sampler_lines),
        "</div>",
        "\n".join(churn_lines),
        "</div>",
    ]
    return "\n".join(card_lines)
def show_history(evt: gr.SelectData, histories: Optional[List[np.ndarray]]):
    """Return (plot update, placeholder update) for the gallery item that was clicked.

    The ``gr.SelectData`` annotation is what lets Gradio inject the select event.
    Whenever there is nothing to draw (no histories stored, an unusable or
    out-of-range selection index, or a missing entry), the plot is hidden and
    the click prompt is shown instead.
    """

    def _show_prompt():
        # Hide the plot and surface the instruction text.
        return gr.update(value=None, visible=False), gr.update(
            value="Click a digit above to explore its diffusion trajectory.",
            visible=True,
        )

    if histories is None or not len(histories):
        return _show_prompt()

    has_index = evt is not None and evt.index is not None
    selection = evt.index if has_index else 0
    # Some components report a (row, col)-style index; the last element is used.
    if isinstance(selection, (list, tuple)):
        selection = selection[-1]

    usable = isinstance(selection, (int, np.integer)) and 0 <= selection < len(histories)
    if not usable or histories[selection] is None:
        return _show_prompt()

    return (
        gr.update(value=_make_history_plot(histories[selection]), visible=True),
        gr.update(visible=False),
    )
def _validate_churn_params(
    churn_rate: float,
    churn_min_value: float,
    churn_max_value: float,
    noise_inflation_value: float,
) -> dict:
    """Coerce the churn slider values to floats, validate them, and return sample_batch's churn_params.

    Raises:
        gr.Error: if a value lies outside the range the UI advertises.
    """
    rate = float(churn_rate)
    churn_min = float(churn_min_value)
    churn_max = float(churn_max_value)
    inflation = float(noise_inflation_value)
    # Checks are written as negated range tests so NaN (which fails every
    # comparison) is rejected instead of slipping through.
    if not (0.0 <= rate <= 1.0):
        raise gr.Error("Churn rate must be within [0, 1].")
    if not (churn_min >= 0 and churn_max >= 0):
        raise gr.Error("Churn thresholds must be non-negative.")
    if not (churn_max >= churn_min):
        raise gr.Error("Churn max threshold must be greater than or equal to churn min threshold.")
    if not (1.0 <= inflation <= MAX_NOISE_INFLATION):
        raise gr.Error(f"Noise inflation factor must be within [1.0, {MAX_NOISE_INFLATION}].")
    return {
        "stochastic_churn_rate": rate,
        "churn_min": churn_min,
        "churn_max": churn_max,
        "noise_inflation_factor": inflation,
    }


def _to_unit_interval(values) -> np.ndarray:
    """Copy an array to host memory as NumPy and map [-1, 1] onto [0, 1], clipping outliers."""
    host = np.asarray(jax.device_get(values))
    # Diffusion models typically output data in [-1, 1]. Rescale to [0, 1].
    return np.clip(0.5 * (host + 1.0), 0.0, 1.0)


def generate(
    integrator_label: str,
    n_steps: int,
    seed: int,
    enable_churn: bool,
    churn_rate: float,
    churn_min_value: float,
    churn_max_value: float,
    noise_inflation_value: float,
):
    """Run sampling with the requested configuration and return UI artifacts.

    Args:
        integrator_label: dropdown label, resolved via ``resolve_integrator``.
        n_steps: number of diffusion steps, must lie in ``[1, MAX_STEPS]``.
        seed: random seed forwarded to ``sample_batch``.
        enable_churn: whether to pass churn parameters to the sampler.
        churn_rate, churn_min_value, churn_max_value, noise_inflation_value:
            churn slider values; only read when ``enable_churn`` is true.

    Returns:
        A 5-tuple matching the ``generate_button.click`` outputs: gallery update,
        summary-card update, history-plot reset, trajectory-placeholder update,
        and the list of per-sample trajectory arrays (empty when ``sample_batch``
        returns no history).

    Raises:
        gr.Error: for an out-of-range step count, a missing seed, churn requested
            on an integrator without ``supports_churn``, or invalid churn values.
    """
    _, integrator_cfg = resolve_integrator(integrator_label)
    n_steps = int(n_steps)
    # gr.Number can deliver None (e.g. a cleared field); int(None) would surface
    # as an opaque TypeError instead of a readable message.
    if seed is None:
        raise gr.Error("Please enter a random seed.")
    seed = int(seed)
    if not (1 <= n_steps <= MAX_STEPS):
        raise gr.Error(f"Number of steps must be between 1 and {MAX_STEPS}.")

    churn_params = None
    if enable_churn:
        if not integrator_cfg.get("supports_churn", False):
            raise gr.Error("Stochastic churning is only available for deterministic integrators.")
        churn_params = _validate_churn_params(
            churn_rate, churn_min_value, churn_max_value, noise_inflation_value
        )

    denoiser_state, history = sample_batch(
        integrator_label,
        n_steps=n_steps,
        n_samples=N_SAMPLES,
        seed=seed,
        keep_history=True,
        churn_params=churn_params,
    )

    samples = _to_unit_interval(denoiser_state.integrator_state.position)
    if samples.ndim == 4 and samples.shape[-1] == 1:
        samples = samples[..., 0]
    gallery_images = _prepare_gallery_images(samples)

    sample_histories: List[np.ndarray] = []
    history_len = 0
    if history is not None:
        history_np = _to_unit_interval(history)
        # Axis 0 is taken as the recorded steps and axis 1 as the sample index
        # (one trajectory per gallery image) — assumed layout, confirm in pipeline.
        sample_histories = [history_np[:, idx] for idx in range(history_np.shape[1])]
        history_len = int(history_np.shape[0])

    summary_html = _format_summary(
        integrator_label=integrator_cfg["label"],
        n_steps=n_steps,
        history_len=history_len,
        churn_params=churn_params,
    )
    gallery_update = gr.update(
        value=gallery_images,
        visible=True,
        interactive=True,
        height=220,
    )
    summary_update = gr.update(value=summary_html)
    # A new batch invalidates any trajectory currently on screen.
    history_reset = gr.update(value=None, visible=False)
    placeholder_update = gr.update(
        value="Click a digit above to explore its diffusion trajectory.",
        visible=True,
    )
    gr.Info(
        f"Generated {N_SAMPLES} samples with {integrator_cfg['label']} ({n_steps} steps).",
        duration=3,
    )
    return gallery_update, summary_update, history_reset, placeholder_update, sample_histories
def _handle_churn_toggle(integrator_label: str, enable_churn: bool):
    """Return (churn column update, churn accordion update) for the current checkbox state.

    The churn parameter column is shown, and the accordion opened, only when the
    resolved integrator config has a truthy ``supports_churn`` flag and
    ``enable_churn`` is set; otherwise both are hidden or closed.
    """
    _resolved, cfg = resolve_integrator(integrator_label)
    show_panel = cfg.get("supports_churn", False) and enable_churn
    return gr.update(visible=show_panel), gr.update(open=show_panel)
def _handle_integrator_change(integrator_label: str, enable_churn: bool):
    """Return (checkbox, churn column, churn accordion) updates after the integrator changes.

    If the new integrator's config does not flag ``supports_churn``, the checkbox
    is made non-interactive and unchecked. Otherwise it stays interactive and
    keeps its current value. The column and accordion updates come from
    ``_handle_churn_toggle`` using that effective checkbox value.
    """
    _resolved, cfg = resolve_integrator(integrator_label)
    churn_supported = cfg.get("supports_churn", False)
    keep_checked = False
    if churn_supported:
        keep_checked = enable_churn
    panel_update, fold_update = _handle_churn_toggle(integrator_label, keep_checked)
    return (
        gr.update(interactive=churn_supported, value=keep_checked),
        panel_update,
        fold_update,
    )
def _sync_churn_max(churn_min_value: float, current_max_value: float):
    """Return a churn_max slider update that raises it to churn_min when it fell below.

    Both inputs are coerced to float; the max is kept as-is whenever it is
    already >= the new min.
    """
    lower = float(churn_min_value)
    upper = float(current_max_value)
    if upper >= lower:
        return gr.update(value=upper)
    return gr.update(value=lower)
def _sync_churn_min(churn_max_value: float, current_min_value: float):
    """Return a churn_min slider update that lowers it to churn_max when it rose above.

    Both inputs are coerced to float; the min is kept as-is whenever it is
    already <= the new max.
    """
    upper = float(churn_max_value)
    lower = float(current_min_value)
    if lower <= upper:
        return gr.update(value=lower)
    return gr.update(value=upper)
def build_ui() -> gr.Blocks:
    """Create the Gradio Blocks interface.

    Lays out a hero banner, a row with a control column (integrator, step
    count, seed, churn accordion, generate button) followed by a results column
    (summary card, digit gallery, trajectory plot), and wires the event
    handlers defined in this module. The default integrator is the ``"ddim"``
    entry of ``INTEGRATORS``; indexing it raises ``KeyError`` if it is missing.

    Returns:
        The constructed, not yet launched ``gr.Blocks`` app; the caller decides
        whether to enable the queue and launch it.
    """
    # The dropdown shows human-readable labels; handlers map a label back to its
    # config through resolve_integrator().
    available_labels = [spec["label"] for spec in INTEGRATORS.values()]
    default_label = INTEGRATORS["ddim"]["label"]
    with gr.Blocks(
        title="Diffuse Integrator Explorer",
        css=CUSTOM_CSS,
        theme=gr.themes.Soft(primary_hue="orange", secondary_hue="orange"),
    ) as demo:
        # Hero banner: logo (LOGO_VALUE is None when logo.png is absent) plus intro text.
        with gr.Row(elem_id="hero"):
            gr.Image(
                value=LOGO_VALUE,
                show_label=False,
                interactive=False,
                elem_classes="hero-logo",
            )
            gr.Markdown(
                """
### Diffuse Integrator Explorer
Experiment with deterministic or stochastic samplers from the
[`diffuse-jax` library](https://diffuse.readthedocs.io/en/latest/index.html). Adjust the number of diffusion steps,
hit **Generate Samples**, and compare the five digits rendered
in the panel on the right.
""".strip(),
                elem_classes="hero-copy",
            )
        with gr.Row():
            # Control column: sampler configuration.
            with gr.Column(elem_classes="control-card"):
                gr.Markdown("#### Sampling Controls", elem_classes="control-heading")
                integrator_input = gr.Dropdown(
                    choices=available_labels,
                    value=default_label,
                    label="Integrator",
                )
                # Slider bounds mirror the server-side check in generate().
                steps_input = gr.Slider(
                    minimum=1,
                    maximum=MAX_STEPS,
                    value=DEFAULT_STEPS,
                    step=1,
                    label="Number of steps",
                )
                seed_input = gr.Number(
                    value=0,
                    precision=0,
                    label="Random seed",
                    info="Use a different seed to explore new digits.",
                )
                # Collapsed by default; _handle_churn_toggle opens it when churn is enabled.
                with gr.Accordion("Churning controls", open=False, elem_classes="accordion-card") as churn_accordion:
                    churn_checkbox = gr.Checkbox(
                        value=False,
                        label="Enable stochastic churning",
                        info="Add controlled noise for deterministic integrators.",
                    )
                    # Hidden until churn is enabled for an integrator that supports it.
                    with gr.Column(visible=False, elem_classes="churn-card") as churn_column:
                        gr.Markdown(
                            "**Churning parameters** · tweak how strongly noise is injected during sampling.",
                            elem_classes="churn-title",
                        )
                        churn_rate_input = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=DEFAULT_CHURN_RATE,
                            step=0.01,
                            label="Churn rate",
                        )
                        churn_min_input = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=DEFAULT_CHURN_MIN,
                            step=0.01,
                            label="Churn min threshold",
                        )
                        churn_max_input = gr.Slider(
                            minimum=0.0,
                            maximum=1.0,
                            value=DEFAULT_CHURN_MAX,
                            step=0.01,
                            label="Churn max threshold",
                        )
                        noise_inflation_input = gr.Slider(
                            minimum=1.0,
                            maximum=MAX_NOISE_INFLATION,
                            value=DEFAULT_NOISE_INFLATION,
                            step=0.001,
                            label="Noise inflation factor",
                        )
                generate_button = gr.Button(
                    "Generate Samples",
                    variant="primary",
                    elem_classes="generate-button",
                )
            # Results column: summary card, gallery, and per-sample trajectory.
            with gr.Column():
                details = gr.HTML(
                    SUMMARY_PLACEHOLDER_HTML,
                    elem_classes="details-card",
                    container=False,
                )
                gr.Markdown("#### Generated Digit Strip", elem_classes="plot-title")
                # Starts hidden and empty; generate() returns an update with visible=True.
                digit_strip = gr.Gallery(
                    columns=5,
                    allow_preview=False,
                    show_fullscreen_button=False,
                    object_fit="contain",
                    rows=1,
                    height=220,
                    show_label=False,
                    interactive=True,
                    elem_classes="gallery-card",
                    value=[],
                    container=False,
                    visible=False,
                )
                gr.Markdown("#### Sample Trajectory", elem_classes="plot-title")
                # Plot and placeholder are toggled against each other by show_history / generate.
                history_plot = gr.Plot(elem_classes="history-card", show_label=False, visible=False)
                history_placeholder = gr.Markdown(
                    "Generate samples, then click a digit above to explore its diffusion trajectory.",
                    elem_classes="history-placeholder",
                    visible=True,
                    container=False,
                )
        # Per-sample trajectories written by generate() and read by show_history().
        histories_state = gr.State([])
        # Event wiring.
        integrator_input.change(
            fn=_handle_integrator_change,
            inputs=[integrator_input, churn_checkbox],
            outputs=[churn_checkbox, churn_column, churn_accordion],
        )
        churn_checkbox.change(
            fn=_handle_churn_toggle,
            inputs=[integrator_input, churn_checkbox],
            outputs=[churn_column, churn_accordion],
        )
        # The two sync handlers keep churn_min <= churn_max whichever slider moves.
        churn_min_input.change(
            fn=_sync_churn_max,
            inputs=[churn_min_input, churn_max_input],
            outputs=churn_max_input,
        )
        churn_max_input.change(
            fn=_sync_churn_min,
            inputs=[churn_max_input, churn_min_input],
            outputs=churn_min_input,
        )
        # Output order must match the 5-tuple returned by generate().
        generate_button.click(
            fn=generate,
            inputs=[
                integrator_input,
                steps_input,
                seed_input,
                churn_checkbox,
                churn_rate_input,
                churn_min_input,
                churn_max_input,
                noise_inflation_input,
            ],
            outputs=[digit_strip, details, history_plot, history_placeholder, histories_state],
        )
        # The select event payload is injected through show_history's gr.SelectData annotation.
        digit_strip.select(
            fn=show_history,
            inputs=[histories_state],
            outputs=[history_plot, history_placeholder],
        )
    return demo
# Runs at import time, before any UI is built, so pipeline assets are ready
# when the first request arrives (presumably to avoid a slow first click —
# confirm what load_pipeline_assets caches).
load_pipeline_assets()
if __name__ == "__main__":
    demo = build_ui()
    # queue() is enabled so event handlers run through Gradio's request queue.
    demo.queue().launch()