Geoffroy38000 commited on
Commit
67d59a3
·
1 Parent(s): 32b646d

Polish diffusion UI and churning controls

Browse files
Files changed (4) hide show
  1. app.py +634 -0
  2. pipeline.py +158 -0
  3. requirements.txt +7 -0
  4. script.py +22 -0
app.py ADDED
@@ -0,0 +1,634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Optional
4
+
5
+ os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
6
+
7
+ import jax
8
+ import numpy as np
9
+ import gradio as gr
10
+ import matplotlib.pyplot as plt
11
+
12
+ from pipeline import INTEGRATORS, load_pipeline_assets, resolve_integrator, sample_batch
13
+
14
+ N_SAMPLES = 5
15
+ MAX_STEPS = 80
16
+ DEFAULT_STEPS = 20
17
+ ROOT_DIR = Path(__file__).parent
18
+ LOGO_PATH = ROOT_DIR / "logo.png"
19
+ LOGO_VALUE = str(LOGO_PATH) if LOGO_PATH.exists() else None
20
+ DEFAULT_CHURN_RATE = 0.0
21
+ DEFAULT_CHURN_MIN = 0.0
22
+ DEFAULT_CHURN_MAX = 0.0
23
+ DEFAULT_NOISE_INFLATION = 1.0
24
+ MAX_NOISE_INFLATION = 1.02
25
+ SUMMARY_PLACEHOLDER_HTML = """
26
+ <div class="summary-card is-empty">
27
+ <div class="summary-title">Ready to sample</div>
28
+ <p>Select an integrator, adjust the controls, then generate digits to inspect their trajectories.</p>
29
+ </div>
30
+ """.strip()
31
+
32
+ CUSTOM_CSS = """
33
+ body {background: radial-gradient(circle at top left, #ffe8d5, #fff7f0 55%, #fdf1f8);}
34
+ #hero {
35
+ display: flex;
36
+ align-items: center;
37
+ justify-content: center;
38
+ gap: 1.5rem;
39
+ background: rgba(255, 255, 255, 0.85);
40
+ padding: 1.5rem 2rem;
41
+ border-radius: 18px;
42
+ box-shadow: 0 18px 35px rgba(255, 135, 0, 0.15);
43
+ border: 1px solid rgba(255, 145, 0, 0.35);
44
+ }
45
+ .hero-logo img {max-width: 320px; width: 100%; object-fit: contain;}
46
+ .hero-copy {font-size: 1.05rem !important; color: #7a3b09;}
47
+ .control-card {
48
+ background: rgba(255, 255, 255, 0.92);
49
+ border-radius: 16px;
50
+ padding: 1.25rem;
51
+ border: 1px solid rgba(255, 166, 77, 0.35);
52
+ box-shadow: 0 14px 30px rgba(255, 140, 0, 0.12);
53
+ }
54
+ .generate-button button {
55
+ background: linear-gradient(135deg, #ff7e00, #ffb347);
56
+ color: #fff;
57
+ font-weight: 600;
58
+ border-radius: 12px;
59
+ box-shadow: 0 10px 20px rgba(255, 126, 0, 0.25);
60
+ }
61
+ .generate-button button:hover {filter: brightness(1.05);}
62
+ .control-heading {
63
+ font-weight: 600;
64
+ color: #7a3b09;
65
+ margin-bottom: 0.6rem !important;
66
+ }
67
+ .plot-card {
68
+ background: rgba(255, 255, 255, 0.88);
69
+ border-radius: 16px;
70
+ padding: 1rem;
71
+ border: 1px solid rgba(255, 166, 77, 0.35);
72
+ box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.4), 0 12px 28px rgba(255, 145, 0, 0.18);
73
+ }
74
+ .details-card {
75
+ border: none;
76
+ padding: 0;
77
+ }
78
+ .summary-card {
79
+ background: rgba(255, 255, 255, 0.9);
80
+ border-radius: 14px;
81
+ padding: 1.1rem 1.25rem;
82
+ border: 1px solid rgba(255, 166, 77, 0.35);
83
+ box-shadow: 0 12px 26px rgba(255, 145, 0, 0.16);
84
+ display: grid;
85
+ gap: 0.85rem;
86
+ }
87
+ .summary-card.is-empty {
88
+ border-style: dashed;
89
+ box-shadow: none;
90
+ }
91
+ .summary-title {
92
+ font-weight: 600;
93
+ font-size: 1.05rem;
94
+ color: #7a3b09;
95
+ }
96
+ .summary-section {
97
+ display: grid;
98
+ gap: 0.45rem;
99
+ }
100
+ .summary-grid {
101
+ display: grid;
102
+ grid-template-columns: repeat(auto-fit, minmax(120px, 1fr));
103
+ gap: 0.4rem;
104
+ }
105
+ .summary-pill {
106
+ background: rgba(255, 245, 233, 0.95);
107
+ border: 1px solid rgba(255, 166, 77, 0.45);
108
+ border-radius: 999px;
109
+ padding: 0.35rem 0.75rem;
110
+ font-size: 0.85rem;
111
+ display: inline-flex;
112
+ align-items: center;
113
+ gap: 0.35rem;
114
+ color: #7a3b09;
115
+ justify-content: center;
116
+ }
117
+ .summary-pill strong {font-weight: 600;}
118
+ .summary-pill.integrator {
119
+ background: rgba(255, 231, 206, 0.95);
120
+ border-color: rgba(255, 160, 72, 0.65);
121
+ font-weight: 600;
122
+ }
123
+ .summary-divider {
124
+ border: none;
125
+ border-top: 1px dashed rgba(255, 166, 77, 0.4);
126
+ margin: 0.2rem 0;
127
+ }
128
+ .accordion-card {
129
+ --tw-border-opacity: 0.45;
130
+ border: 1px dashed rgba(255, 166, 77, 0.45) !important;
131
+ border-radius: 14px !important;
132
+ background: rgba(255, 255, 255, 0.88) !important;
133
+ }
134
+ .accordion-card > div:nth-child(1) {
135
+ font-weight: 600;
136
+ color: #7a3b09;
137
+ }
138
+ .churn-card {
139
+ margin-top: 0.75rem;
140
+ background: rgba(255, 255, 255, 0.85);
141
+ border-radius: 14px;
142
+ padding: 0.9rem 1rem 1.1rem;
143
+ border: 1px dashed rgba(255, 166, 77, 0.5);
144
+ box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.55);
145
+ }
146
+ .churn-title {
147
+ font-size: 0.92rem !important;
148
+ color: #8a450f;
149
+ margin-bottom: 0.55rem !important;
150
+ }
151
+ .gallery-card {
152
+ background: rgba(255, 255, 255, 0.9);
153
+ border-radius: 16px;
154
+ padding: 0.3rem 0.4rem;
155
+ border: 1px solid rgba(255, 166, 77, 0.28);
156
+ box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.25), 0 8px 18px rgba(255, 145, 0, 0.12);
157
+ }
158
+ .gallery-card [data-testid="upload-zone"] {
159
+ display: none !important;
160
+ }
161
+ .gallery-card .grid {
162
+ min-height: 180px;
163
+ }
164
+ .gallery-card img {
165
+ border-radius: 10px;
166
+ transition: transform 0.15s ease, box-shadow 0.15s ease;
167
+ }
168
+ .gallery-card img:hover {
169
+ transform: translateY(-2px);
170
+ box-shadow: 0 8px 14px rgba(255, 145, 0, 0.18);
171
+ }
172
+ .history-card {
173
+ background: rgba(255, 255, 255, 0.88);
174
+ border-radius: 16px;
175
+ padding: 0.9rem;
176
+ border: 1px solid rgba(255, 166, 77, 0.35);
177
+ box-shadow: inset 0 0 0 1px rgba(255, 255, 255, 0.35), 0 10px 22px rgba(255, 145, 0, 0.15);
178
+ }
179
+ .plot-title {
180
+ color: #7a3b09 !important;
181
+ text-align: center;
182
+ font-weight: 600 !important;
183
+ margin-bottom: 0.45rem !important;
184
+ }
185
+ .history-placeholder {
186
+ text-align: center;
187
+ color: #8a450f;
188
+ font-size: 0.9rem;
189
+ margin-top: 0.5rem;
190
+ }
191
+ .value-chip {
192
+ background: rgba(255, 231, 206, 0.9);
193
+ border-radius: 999px;
194
+ padding: 0.1rem 0.55rem;
195
+ font-size: 0.82rem;
196
+ margin-left: 0.4rem;
197
+ color: #84400e;
198
+ }
199
+ @media (max-width: 768px) {
200
+ .summary-grid {grid-template-columns: repeat(auto-fit, minmax(110px, 1fr));}
201
+ .gallery-card .grid {min-height: 150px;}
202
+ .plot-card, .history-card {padding: 0.7rem;}
203
+ }
204
+ """
205
+
206
+
207
+ def _prepare_gallery_images(samples: np.ndarray) -> List[np.ndarray]:
208
+ """Convert normalized grayscale samples to RGB arrays for display."""
209
+ clipped = np.clip(samples, 0.0, 1.0)
210
+ uint8_imgs = (clipped * 255).astype(np.uint8)
211
+ if uint8_imgs.ndim == 3:
212
+ uint8_imgs = uint8_imgs[..., np.newaxis]
213
+ return [np.repeat(img, 3, axis=-1) for img in uint8_imgs]
214
+
215
+
216
+ def _make_history_plot(history_frames: np.ndarray) -> plt.Figure:
217
+ """Render up to 10 frames from a sample trajectory in a single row."""
218
+ if history_frames.ndim == 4 and history_frames.shape[-1] == 1:
219
+ history_frames = history_frames[..., 0]
220
+ total_frames = history_frames.shape[0]
221
+ n_display = min(10, total_frames)
222
+ if n_display < 1:
223
+ raise ValueError("History sequence is empty.")
224
+ indices = np.linspace(0, total_frames - 1, n_display, dtype=int)
225
+ selected = history_frames[indices]
226
+
227
+ fig, axes = plt.subplots(1, n_display, figsize=(2.2 * n_display, 2.2))
228
+ if n_display == 1:
229
+ axes = np.array([axes])
230
+ for idx, ax in enumerate(np.atleast_1d(axes)):
231
+ ax.axis("off")
232
+ ax.imshow(selected[idx], cmap="gray")
233
+ ax.set_title(f"Step {indices[idx] + 1}", fontsize=8, color="#8a450f", pad=6)
234
+ fig.tight_layout()
235
+ return fig
236
+
237
+
238
+ def _format_summary(
239
+ *,
240
+ integrator_label: str,
241
+ n_steps: int,
242
+ history_len: int,
243
+ churn_params: Optional[dict],
244
+ ) -> str:
245
+ sampler_grid = f"""
246
+ <div class="summary-grid">
247
+ <span class="summary-pill integrator">{integrator_label}</span>
248
+ <span class="summary-pill">Steps <strong>{n_steps}</strong></span>
249
+ <span class="summary-pill">Samples <strong>{N_SAMPLES}</strong></span>
250
+ <span class="summary-pill">History <strong>{history_len}</strong></span>
251
+ </div>
252
+ """.strip()
253
+
254
+ churn_block = ""
255
+ if churn_params:
256
+ churn_block = f"""
257
+ <hr class="summary-divider" />
258
+ <div class="summary-section">
259
+ <div class="summary-title">Churning</div>
260
+ <div class="summary-grid">
261
+ <span class="summary-pill">Rate <strong>{churn_params['stochastic_churn_rate']:.3f}</strong></span>
262
+ <span class="summary-pill">Min <strong>{churn_params['churn_min']:.3f}</strong></span>
263
+ <span class="summary-pill">Max <strong>{churn_params['churn_max']:.3f}</strong></span>
264
+ <span class="summary-pill">Inflation <strong>{churn_params['noise_inflation_factor']:.4f}</strong></span>
265
+ </div>
266
+ </div>
267
+ """.strip()
268
+
269
+ return f"""
270
+ <div class="summary-card">
271
+ <div class="summary-section">
272
+ <div class="summary-title">Sampler</div>
273
+ {sampler_grid}
274
+ </div>
275
+ {churn_block}
276
+ </div>
277
+ """.strip()
278
+
279
+
280
+ def show_history(evt: gr.SelectData, histories: Optional[List[np.ndarray]]):
281
+ """Render the trajectory plot for the selected sample."""
282
+ if histories is None or len(histories) == 0:
283
+ return gr.update(value=None, visible=False), gr.update(
284
+ value="Click a digit above to explore its diffusion trajectory.",
285
+ visible=True,
286
+ )
287
+
288
+ index = 0
289
+ if evt is not None and evt.index is not None:
290
+ index = evt.index
291
+ if isinstance(index, (list, tuple)):
292
+ index = index[-1]
293
+
294
+ if not isinstance(index, (int, np.integer)) or index < 0 or index >= len(histories):
295
+ return gr.update(value=None, visible=False), gr.update(
296
+ value="Click a digit above to explore its diffusion trajectory.",
297
+ visible=True,
298
+ )
299
+ if histories[index] is None:
300
+ return gr.update(value=None, visible=False), gr.update(
301
+ value="Click a digit above to explore its diffusion trajectory.",
302
+ visible=True,
303
+ )
304
+ figure = _make_history_plot(histories[index])
305
+ return gr.update(value=figure, visible=True), gr.update(visible=False)
306
+
307
+
308
+ def generate(
309
+ integrator_label: str,
310
+ n_steps: int,
311
+ seed: int,
312
+ enable_churn: bool,
313
+ churn_rate: float,
314
+ churn_min_value: float,
315
+ churn_max_value: float,
316
+ noise_inflation_value: float,
317
+ ):
318
+ """Run sampling with the requested configuration and return UI artifacts."""
319
+ _, integrator_cfg = resolve_integrator(integrator_label)
320
+
321
+ n_steps = int(n_steps)
322
+ seed = int(seed)
323
+
324
+ if not (1 <= n_steps <= MAX_STEPS):
325
+ raise gr.Error(f"Number of steps must be between 1 and {MAX_STEPS}.")
326
+
327
+ supports_churn = integrator_cfg.get("supports_churn", False)
328
+ churn_params = None
329
+
330
+ if enable_churn:
331
+ if not supports_churn:
332
+ raise gr.Error("Stochastic churning is only available for deterministic integrators.")
333
+
334
+ churn_rate = float(churn_rate)
335
+ churn_min_value = float(churn_min_value)
336
+ churn_max_value = float(churn_max_value)
337
+ noise_inflation_value = float(noise_inflation_value)
338
+
339
+ if churn_rate < 0 or churn_rate > 1:
340
+ raise gr.Error("Churn rate must be within [0, 1].")
341
+ if churn_min_value < 0 or churn_max_value < 0:
342
+ raise gr.Error("Churn thresholds must be non-negative.")
343
+ if churn_max_value < churn_min_value:
344
+ raise gr.Error("Churn max threshold must be greater than or equal to churn min threshold.")
345
+ if noise_inflation_value < 1.0 or noise_inflation_value > MAX_NOISE_INFLATION:
346
+ raise gr.Error(f"Noise inflation factor must be within [1.0, {MAX_NOISE_INFLATION}].")
347
+
348
+ churn_params = {
349
+ "stochastic_churn_rate": churn_rate,
350
+ "churn_min": churn_min_value,
351
+ "churn_max": churn_max_value,
352
+ "noise_inflation_factor": noise_inflation_value,
353
+ }
354
+
355
+ denoiser_state, history = sample_batch(
356
+ integrator_label,
357
+ n_steps=n_steps,
358
+ n_samples=N_SAMPLES,
359
+ seed=seed,
360
+ keep_history=True,
361
+ churn_params=churn_params,
362
+ )
363
+
364
+ integrator_state = denoiser_state.integrator_state
365
+ samples = jax.device_get(integrator_state.position)
366
+ samples = np.asarray(samples)
367
+
368
+ if samples.ndim == 4 and samples.shape[-1] == 1:
369
+ samples = samples[..., 0]
370
+
371
+ # Diffusion models typically output data in [-1, 1]. Rescale to [0, 1].
372
+ samples = 0.5 * (samples + 1.0)
373
+ samples = np.clip(samples, 0.0, 1.0)
374
+
375
+ gallery_images = _prepare_gallery_images(samples)
376
+
377
+ sample_histories: Optional[List[np.ndarray]] = None
378
+ if history is not None:
379
+ history_np = jax.device_get(history)
380
+ history_np = np.asarray(history_np)
381
+ history_np = 0.5 * (history_np + 1.0)
382
+ history_np = np.clip(history_np, 0.0, 1.0)
383
+ sample_histories = [
384
+ history_np[:, sample_idx]
385
+ for sample_idx in range(history_np.shape[1])
386
+ ]
387
+ if sample_histories is None:
388
+ sample_histories = []
389
+
390
+ history_len = int(history.shape[0]) if history is not None else 0
391
+
392
+ summary_html = _format_summary(
393
+ integrator_label=integrator_cfg["label"],
394
+ n_steps=n_steps,
395
+ history_len=history_len,
396
+ churn_params=churn_params,
397
+ )
398
+
399
+ gallery_update = gr.update(
400
+ value=gallery_images,
401
+ visible=True,
402
+ interactive=True,
403
+ height=220,
404
+ )
405
+ summary_update = gr.update(value=summary_html)
406
+ history_reset = gr.update(value=None, visible=False)
407
+ placeholder_update = gr.update(
408
+ value="Click a digit above to explore its diffusion trajectory.",
409
+ visible=True,
410
+ )
411
+
412
+ gr.Info(
413
+ f"Generated {N_SAMPLES} samples with {integrator_cfg['label']} ({n_steps} steps).",
414
+ duration=3,
415
+ )
416
+
417
+ return gallery_update, summary_update, history_reset, placeholder_update, sample_histories
418
+
419
+
420
+ def _handle_churn_toggle(integrator_label: str, enable_churn: bool):
421
+ """Toggle churn controls visibility/open state based on integrator support."""
422
+ _, integrator_cfg = resolve_integrator(integrator_label)
423
+ supports = integrator_cfg.get("supports_churn", False)
424
+ enable_effective = supports and enable_churn
425
+ column_update = gr.update(visible=enable_effective)
426
+ accordion_update = gr.update(open=enable_effective)
427
+ return column_update, accordion_update
428
+
429
+
430
+ def _handle_integrator_change(integrator_label: str, enable_churn: bool):
431
+ """Adjust checkbox interactivity and churn panel visibility when integrator changes."""
432
+ _, integrator_cfg = resolve_integrator(integrator_label)
433
+ supports = integrator_cfg.get("supports_churn", False)
434
+ effective_enable = enable_churn if supports else False
435
+ checkbox_update = gr.update(
436
+ interactive=supports,
437
+ value=effective_enable,
438
+ )
439
+ column_update, accordion_update = _handle_churn_toggle(integrator_label, effective_enable)
440
+ return checkbox_update, column_update, accordion_update
441
+
442
+
443
+ def _sync_churn_max(churn_min_value: float, current_max_value: float):
444
+ """Ensure churn_max stays >= churn_min when churn_min changes."""
445
+ churn_min_value = float(churn_min_value)
446
+ current_max_value = float(current_max_value)
447
+ adjusted_max = current_max_value if current_max_value >= churn_min_value else churn_min_value
448
+ return gr.update(value=adjusted_max)
449
+
450
+
451
+ def _sync_churn_min(churn_max_value: float, current_min_value: float):
452
+ """Ensure churn_min stays <= churn_max when churn_max changes."""
453
+ churn_max_value = float(churn_max_value)
454
+ current_min_value = float(current_min_value)
455
+ adjusted_min = current_min_value if current_min_value <= churn_max_value else churn_max_value
456
+ return gr.update(value=adjusted_min)
457
+
458
+
459
+ def build_ui() -> gr.Blocks:
460
+ """Create the Gradio Blocks interface."""
461
+ available_labels = [spec["label"] for spec in INTEGRATORS.values()]
462
+ default_label = INTEGRATORS["ddim"]["label"]
463
+
464
+ with gr.Blocks(
465
+ title="Diffuse Integrator Explorer",
466
+ css=CUSTOM_CSS,
467
+ theme=gr.themes.Soft(primary_hue="orange", secondary_hue="orange"),
468
+ ) as demo:
469
+ with gr.Row(elem_id="hero"):
470
+ gr.Image(
471
+ value=LOGO_VALUE,
472
+ show_label=False,
473
+ interactive=False,
474
+ elem_classes="hero-logo",
475
+ )
476
+ gr.Markdown(
477
+ """
478
+ ### Diffuse Integrator Explorer
479
+ Experiment with deterministic or stochastic samplers from the
480
+ `diffuse-jax` library. Adjust the number of diffusion steps,
481
+ hit **Generate Samples**, and compare the five digits rendered
482
+ in the panel on the right.
483
+ """.strip(),
484
+ elem_classes="hero-copy",
485
+ )
486
+
487
+ with gr.Row():
488
+ with gr.Column(elem_classes="control-card"):
489
+ gr.Markdown("#### Sampling Controls", elem_classes="control-heading")
490
+ integrator_input = gr.Dropdown(
491
+ choices=available_labels,
492
+ value=default_label,
493
+ label="Integrator",
494
+ )
495
+ steps_input = gr.Slider(
496
+ minimum=1,
497
+ maximum=MAX_STEPS,
498
+ value=DEFAULT_STEPS,
499
+ step=1,
500
+ label="Number of steps",
501
+ )
502
+ seed_input = gr.Number(
503
+ value=0,
504
+ precision=0,
505
+ label="Random seed",
506
+ info="Use a different seed to explore new digits.",
507
+ )
508
+ with gr.Accordion("Churning controls", open=False, elem_classes="accordion-card") as churn_accordion:
509
+ churn_checkbox = gr.Checkbox(
510
+ value=False,
511
+ label="Enable stochastic churning",
512
+ info="Add controlled noise for deterministic integrators.",
513
+ )
514
+ with gr.Column(visible=False, elem_classes="churn-card") as churn_column:
515
+ gr.Markdown(
516
+ "**Churning parameters** · tweak how strongly noise is injected during sampling.",
517
+ elem_classes="churn-title",
518
+ )
519
+ churn_rate_input = gr.Slider(
520
+ minimum=0.0,
521
+ maximum=1.0,
522
+ value=DEFAULT_CHURN_RATE,
523
+ step=0.01,
524
+ label="Churn rate",
525
+ )
526
+ churn_min_input = gr.Slider(
527
+ minimum=0.0,
528
+ maximum=1.0,
529
+ value=DEFAULT_CHURN_MIN,
530
+ step=0.01,
531
+ label="Churn min threshold",
532
+ )
533
+ churn_max_input = gr.Slider(
534
+ minimum=0.0,
535
+ maximum=1.0,
536
+ value=DEFAULT_CHURN_MAX,
537
+ step=0.01,
538
+ label="Churn max threshold",
539
+ )
540
+ noise_inflation_input = gr.Slider(
541
+ minimum=1.0,
542
+ maximum=MAX_NOISE_INFLATION,
543
+ value=DEFAULT_NOISE_INFLATION,
544
+ step=0.001,
545
+ label="Noise inflation factor",
546
+ )
547
+ generate_button = gr.Button(
548
+ "Generate Samples",
549
+ variant="primary",
550
+ elem_classes="generate-button",
551
+ )
552
+
553
+ with gr.Column():
554
+ details = gr.HTML(
555
+ SUMMARY_PLACEHOLDER_HTML,
556
+ elem_classes="details-card",
557
+ container=False,
558
+ )
559
+ gr.Markdown("#### Generated Digit Strip", elem_classes="plot-title")
560
+ digit_strip = gr.Gallery(
561
+ columns=5,
562
+ allow_preview=False,
563
+ show_fullscreen_button=False,
564
+ object_fit="contain",
565
+ rows=1,
566
+ height=220,
567
+ show_label=False,
568
+ interactive=True,
569
+ elem_classes="gallery-card",
570
+ value=[],
571
+ container=False,
572
+ visible=False,
573
+ )
574
+ gr.Markdown("#### Sample Trajectory", elem_classes="plot-title")
575
+ history_plot = gr.Plot(elem_classes="history-card", show_label=False, visible=False)
576
+ history_placeholder = gr.Markdown(
577
+ "Generate samples, then click a digit above to explore its diffusion trajectory.",
578
+ elem_classes="history-placeholder",
579
+ visible=True,
580
+ container=False,
581
+ )
582
+
583
+ histories_state = gr.State([])
584
+
585
+ integrator_input.change(
586
+ fn=_handle_integrator_change,
587
+ inputs=[integrator_input, churn_checkbox],
588
+ outputs=[churn_checkbox, churn_column, churn_accordion],
589
+ )
590
+ churn_checkbox.change(
591
+ fn=_handle_churn_toggle,
592
+ inputs=[integrator_input, churn_checkbox],
593
+ outputs=[churn_column, churn_accordion],
594
+ )
595
+ churn_min_input.change(
596
+ fn=_sync_churn_max,
597
+ inputs=[churn_min_input, churn_max_input],
598
+ outputs=churn_max_input,
599
+ )
600
+ churn_max_input.change(
601
+ fn=_sync_churn_min,
602
+ inputs=[churn_max_input, churn_min_input],
603
+ outputs=churn_min_input,
604
+ )
605
+ generate_button.click(
606
+ fn=generate,
607
+ inputs=[
608
+ integrator_input,
609
+ steps_input,
610
+ seed_input,
611
+ churn_checkbox,
612
+ churn_rate_input,
613
+ churn_min_input,
614
+ churn_max_input,
615
+ noise_inflation_input,
616
+ ],
617
+ outputs=[digit_strip, details, history_plot, history_placeholder, histories_state],
618
+ )
619
+
620
+ digit_strip.select(
621
+ fn=show_history,
622
+ inputs=[histories_state],
623
+ outputs=[history_plot, history_placeholder],
624
+ )
625
+
626
+ return demo
627
+
628
+
629
+ load_pipeline_assets()
630
+
631
+
632
+ if __name__ == "__main__":
633
+ demo = build_ui()
634
+ demo.queue().launch()
pipeline.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import importlib.util
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, Optional, Tuple
5
+
6
+ import jax
7
+ from flax import nnx, serialization
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ from diffuse.diffusion.sde import Flow
11
+ from diffuse.integrator.deterministic import DDIMIntegrator, DPMpp2sIntegrator, EulerIntegrator, HeunIntegrator
12
+ from diffuse.integrator.stochastic import EulerMaruyamaIntegrator
13
+ from diffuse.predictor import Predictor
14
+ from diffuse.timer.base import VpTimer
15
+ from diffuse.denoisers.denoiser import Denoiser
16
+
17
+
18
+ @dataclass(frozen=True)
19
+ class PipelineAssets:
20
+ """Container holding the preloaded model artifacts."""
21
+
22
+ model: Any
23
+ flow: Flow
24
+ predictor: Predictor
25
+ x0_shape: Tuple[int, int, int]
26
+
27
+
28
+ @functools.lru_cache(maxsize=1)
29
+ def load_pipeline_assets() -> PipelineAssets:
30
+ """Download the HF model and build the predictor stack once."""
31
+ model_path = hf_hub_download(repo_id="jcopo/mnist", filename="model.msgpack")
32
+ config_path = hf_hub_download(repo_id="jcopo/mnist", filename="config.py")
33
+
34
+ spec = importlib.util.spec_from_file_location("model_config", config_path)
35
+ if spec is None or spec.loader is None:
36
+ raise RuntimeError("Unable to load model config from Hugging Face hub.")
37
+ config_module = importlib.util.module_from_spec(spec)
38
+ spec.loader.exec_module(config_module)
39
+
40
+ model = config_module.model
41
+
42
+ with open(model_path, "rb") as f:
43
+ state_dict = serialization.from_bytes(None, f.read())
44
+
45
+ graphdef, state = nnx.split(model)
46
+ state.replace_by_pure_dict(state_dict)
47
+ model = nnx.merge(graphdef, state)
48
+ model.eval()
49
+
50
+ flow = Flow(tf=1.0)
51
+ predictor = Predictor(
52
+ model=flow,
53
+ network=lambda x, t: model(x, t).output,
54
+ prediction_type="velocity",
55
+ )
56
+
57
+ return PipelineAssets(
58
+ model=model,
59
+ flow=flow,
60
+ predictor=predictor,
61
+ x0_shape=(28, 28, 1),
62
+ )
63
+
64
+
65
+ INTEGRATORS: Dict[str, Dict[str, Any]] = {
66
+ "ddim": {
67
+ "label": "DDIM (Deterministic)",
68
+ "cls": DDIMIntegrator,
69
+ "description": "Deterministic DDIM sampler.",
70
+ "supports_churn": True,
71
+ },
72
+ "heun": {
73
+ "label": "Heun (Deterministic 2nd order)",
74
+ "cls": HeunIntegrator,
75
+ "description": "Second-order deterministic integrator.",
76
+ "supports_churn": True,
77
+ },
78
+ "euler": {
79
+ "label": "Euler (Deterministic)",
80
+ "cls": EulerIntegrator,
81
+ "description": "Forward Euler integrator.",
82
+ "supports_churn": True,
83
+ },
84
+ "dpmpp2s": {
85
+ "label": "DPM++ 2S (Deterministic multi-step)",
86
+ "cls": DPMpp2sIntegrator,
87
+ "description": "Deterministic multi-step sampler with second-order accuracy.",
88
+ "supports_churn": True,
89
+ },
90
+ "euler_maruyama": {
91
+ "label": "Euler-Maruyama (Stochastic)",
92
+ "cls": EulerMaruyamaIntegrator,
93
+ "description": "Stochastic sampler with noise at each diffusion step.",
94
+ "supports_churn": False,
95
+ },
96
+ }
97
+
98
+ LABEL_TO_KEY = {spec["label"]: key for key, spec in INTEGRATORS.items()}
99
+
100
+
101
+ def resolve_integrator(identifier: str) -> Tuple[str, Dict[str, Any]]:
102
+ """Resolve either an integrator key or display label to the configuration dict."""
103
+ if identifier in INTEGRATORS:
104
+ return identifier, INTEGRATORS[identifier]
105
+ if identifier in LABEL_TO_KEY:
106
+ key = LABEL_TO_KEY[identifier]
107
+ return key, INTEGRATORS[key]
108
+ raise KeyError(f"Unknown integrator identifier: {identifier}")
109
+
110
+
111
+ def build_denoiser(
112
+ integrator_key: str,
113
+ n_steps: int,
114
+ *,
115
+ churn_params: Optional[Dict[str, float]] = None,
116
+ ) -> Denoiser:
117
+ """Instantiate a denoiser wired with the requested integrator and timer."""
118
+ if n_steps < 1:
119
+ raise ValueError("n_steps must be >= 1")
120
+
121
+ assets = load_pipeline_assets()
122
+ _, integrator_cfg = resolve_integrator(integrator_key)
123
+
124
+ timer = VpTimer(n_steps=n_steps, eps=0.001, tf=1.0)
125
+ integrator_kwargs: Dict[str, float] = {}
126
+ if churn_params:
127
+ if not integrator_cfg.get("supports_churn", False):
128
+ raise ValueError(f"Integrator '{integrator_cfg['label']}' does not support stochastic churning.")
129
+ integrator_kwargs = churn_params
130
+
131
+ integrator = integrator_cfg["cls"](model=assets.flow, timer=timer, **integrator_kwargs)
132
+ return Denoiser(
133
+ integrator=integrator,
134
+ model=assets.flow,
135
+ predictor=assets.predictor,
136
+ x0_shape=assets.x0_shape,
137
+ )
138
+
139
+
140
+ def sample_batch(
141
+ integrator_identifier: str,
142
+ *,
143
+ n_steps: int,
144
+ n_samples: int,
145
+ seed: int,
146
+ keep_history: bool = False,
147
+ churn_params: Optional[Dict[str, float]] = None,
148
+ ):
149
+ """Generate a batch of samples for the requested integrator."""
150
+ if n_samples < 1:
151
+ raise ValueError("n_samples must be >= 1")
152
+
153
+ denoiser = build_denoiser(integrator_identifier, n_steps, churn_params=churn_params)
154
+ key = jax.random.PRNGKey(seed)
155
+
156
+ # The denoiser expects the number of steps to match the timer configuration.
157
+ state, history = denoiser.generate(key, n_steps, n_samples, keep_history=keep_history)
158
+ return state, history
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ diffuse-jax
3
+ jax[cpu]
4
+ flax<0.12
5
+ huggingface_hub
6
+ matplotlib
7
+ numpy
script.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pipeline import load_pipeline_assets, sample_batch
2
+
3
+
4
+ def main() -> None:
5
+ """Example script showing how to invoke the diffusion pipeline."""
6
+ load_pipeline_assets()
7
+ print("✅ Model loaded successfully!")
8
+
9
+ denoiser_state, history = sample_batch(
10
+ "ddim",
11
+ n_steps=10,
12
+ n_samples=5,
13
+ seed=456,
14
+ keep_history=True,
15
+ )
16
+
17
+ print(f"state.position shape: {denoiser_state.integrator_state.position.shape}")
18
+ print(f"History length: {len(history)}")
19
+
20
+
21
+ if __name__ == "__main__":
22
+ main()