Spaces:
Sleeping
Sleeping
| import torch | |
| import gradio as gr | |
| import spaces | |
| from model import generate_grid_image, load_model | |
# Module-level flag: set to True once load_model() has returned successfully.
MODEL_READY = False


def ensure_model_loaded():
    """Call load_model() on first use; subsequent calls do nothing.

    The flag is only set after load_model() returns, so if loading raises,
    the next call will attempt the load again.
    """
    global MODEL_READY
    if MODEL_READY:
        return
    load_model()
    MODEL_READY = True
def predict(label: int, steps: int, num_samples: int):
    """Gradio click handler: load the model if needed and render a sample grid.

    Args:
        label: Digit class to generate. The UI's "Label" dropdown is built
            from string choices ("0".."9"), so the value arrives as a string
            and is coerced to ``int`` here.
        steps: Value of the "Steps" slider.
        num_samples: Value of the "Samples" slider.

    Returns:
        Whatever ``generate_grid_image`` returns; it is passed straight to
        the ``gr.Image`` output component.

    Raises:
        ValueError: If any argument cannot be converted to an integer.
    """
    ensure_model_loaded()
    # Normalise at the UI boundary so generate_grid_image always receives the
    # ints this function's signature promises (dropdown -> str, sliders may
    # deliver floats).
    return generate_grid_image(
        label=int(label),
        steps=int(steps),
        num_samples=int(num_samples),
    )
# UI definition: a title and description, the output image, a row of sampling
# controls, and a button that feeds the control values to `predict`.
with gr.Blocks(title="MNIST Diffusion") as demo:
    gr.Markdown("# MNIST Diffusion")
    gr.Markdown(
        "Discrete diffusion model for MNIST digits. "
        "Sampling uses fixed CFG=2.0, temperature=0.6, top_p=0.99."
    )

    # Output component that receives the value returned by `predict`.
    grid = gr.Image(label="Samples", show_label=True)

    with gr.Row():
        # Choices are the strings "0".."9"; the selected value is a string.
        label = gr.Dropdown(
            choices=list(map(str, range(10))),
            value="4",
            label="Label",
        )
        steps = gr.Slider(
            minimum=1,
            maximum=784,
            value=784,
            step=1,
            label="Steps",
        )
        num_samples = gr.Slider(
            minimum=1,
            maximum=36,
            value=16,
            step=1,
            label="Samples",
        )

    # NOTE(review): the source's indentation was lost; the button is assumed
    # to sit below the row rather than inside it -- confirm against the app.
    generate_btn = gr.Button("Generate")
    generate_btn.click(
        fn=predict,
        inputs=[label, steps, num_samples],
        outputs=grid,
        scroll_to_output=True,
    )


if __name__ == "__main__":
    # Start the Gradio server only when run as a script.
    demo.launch()