"""Gradio demo app for sampling MNIST digits from a discrete diffusion model."""

import torch
import gradio as gr
import spaces

from model import generate_grid_image, load_model

# Set to True once load_model() has completed in this process.
MODEL_READY = False


def ensure_model_loaded():
    """Call ``load_model()`` the first time this is invoked; no-op afterwards.

    Loading is deferred until the first prediction request instead of being
    done at import time. The flag is only set after ``load_model()`` returns,
    so a failed load is retried on the next call.
    """
    global MODEL_READY
    if not MODEL_READY:
        load_model()
        MODEL_READY = True


# NOTE(review): spaces.GPU is the Hugging Face Spaces decorator; presumably it
# attaches a GPU for the duration of the call on ZeroGPU hardware — confirm.
@spaces.GPU
@torch.inference_mode()
def predict(label: int, steps: int, num_samples: int):
    """Generate a grid image of samples for the requested digit label.

    Args:
        label: Digit class to sample. The UI dropdown below is built from
            string choices, so the value may arrive as a numeric string
            (e.g. ``"4"``); it is converted to ``int`` here.
        steps: Number of sampling steps (slider value, converted to ``int``).
        num_samples: Number of samples to draw (slider value, converted to
            ``int``).

    Returns:
        Whatever ``generate_grid_image`` returns; it is wired to the
        ``gr.Image`` output component.

    Raises:
        ValueError: If ``label``, ``steps`` or ``num_samples`` cannot be
            converted to an integer.
    """
    ensure_model_loaded()
    # Normalise UI values: the dropdown yields strings and sliders may yield
    # floats, while the signature promises integers.
    return generate_grid_image(
        label=int(label),
        steps=int(steps),
        num_samples=int(num_samples),
    )


with gr.Blocks(title="MNIST Diffusion") as demo:
    gr.Markdown("# MNIST Diffusion")
    gr.Markdown(
        "Discrete diffusion model for MNIST digits. "
        "Sampling uses fixed CFG=2.0, temperature=0.6, top_p=0.99."
    )

    grid = gr.Image(label="Samples", show_label=True)

    with gr.Row():
        label = gr.Dropdown([str(i) for i in range(10)], value="4", label="Label")
        steps = gr.Slider(1, 784, value=784, step=1, label="Steps")
        num_samples = gr.Slider(1, 36, value=16, step=1, label="Samples")

    generate_btn = gr.Button("Generate")
    # Event listeners must be registered inside the Blocks context.
    generate_btn.click(
        fn=predict,
        inputs=[label, steps, num_samples],
        outputs=grid,
        scroll_to_output=True,
    )


if __name__ == "__main__":
    demo.launch()