Spaces:
Sleeping
Sleeping
File size: 1,387 Bytes
3617dac 776256b 3617dac 776256b 3617dac 776256b 3617dac 776256b 3617dac 776256b 3617dac e0e7f09 776256b 3617dac 776256b ef548bf 3617dac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 | import torch
import gradio as gr
import spaces
from model import iter_trajectory_frames, load_model
# Becomes True once load_model() has completed successfully in this process.
MODEL_READY = False
def ensure_model_loaded():
    """Call ``load_model()`` the first time this is invoked; afterwards do nothing.

    The flag is flipped only after ``load_model()`` returns, so if loading
    raises, the next call will try again.
    """
    global MODEL_READY
    if MODEL_READY:
        return
    load_model()
    MODEL_READY = True
@spaces.GPU
@torch.inference_mode()
def predict(label: int, steps: int):
    """Stream the sampling trajectory for one generated digit.

    Args:
        label: Digit class to condition on. The UI dropdown supplies the
            selected choice as a string (e.g. ``"4"``), so it is coerced to
            ``int`` here; plain ints are accepted unchanged.
        steps: Number of denoising steps. Coerced to ``int`` because slider
            values may arrive as floats.

    Yields:
        ``(image, status)`` tuples, where ``image`` is the frame produced by
        ``iter_trajectory_frames`` and ``status`` is a progress string showing
        the trajectory checkpoint and the denoising step.
    """
    # Normalise UI inputs first so bad values fail before the model is loaded.
    label = int(label)
    steps = int(steps)
    ensure_model_loaded()
    frames = iter_trajectory_frames(label=label, steps=steps)
    for idx, (image, step_idx, total_steps, total) in enumerate(frames, start=1):
        yield image, (
            f"trajectory checkpoint {idx}/{total} | "
            f"denoising step {step_idx}/{total_steps}"
        )
with gr.Blocks(title="MNIST Diffusion") as demo:
    # Page header plus a one-paragraph description of the fixed sampler settings.
    gr.Markdown("# MNIST Diffusion")
    intro_text = (
        "Discrete diffusion model for MNIST digits. "
        "The demo streams one sample as masked tokens are resolved with fixed CFG=2.0, "
        "temperature=0.6, and top_p=0.99."
    )
    gr.Markdown(intro_text)

    # Outputs: the current trajectory image and a textual progress line.
    grid = gr.Image(label="Trajectory", show_label=True)
    status = gr.Textbox(label="Status")

    # Inputs: digit class (as string choices "0".."9") and the step count.
    digit_choices = [str(digit) for digit in range(10)]
    with gr.Row():
        label = gr.Dropdown(digit_choices, value="4", label="Label")
        steps = gr.Slider(32, 784, value=784, step=1, label="Steps")

    generate_btn = gr.Button("Generate")
    # predict yields (image, status) pairs, so each yield refreshes both outputs.
    generate_btn.click(
        fn=predict,
        inputs=[label, steps],
        outputs=[grid, status],
        scroll_to_output=True,
    )

# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()
|