# mnist-diff-demo / app.py
# trixyL
# exp: different init label value
# e0e7f09
import torch
import gradio as gr
import spaces
from model import generate_grid_image, load_model
# Process-wide flag: becomes True once load_model() has returned successfully.
MODEL_READY = False


def ensure_model_loaded():
    """Run ``load_model()`` on first use and skip it on later calls.

    The flag is flipped only after ``load_model()`` returns, so if loading
    raises, the next call will try again.

    NOTE(review): there is no lock here; if this can be reached concurrently,
    load_model() could run more than once — confirm whether that matters.
    """
    global MODEL_READY
    if MODEL_READY:
        return
    load_model()
    MODEL_READY = True
@spaces.GPU
@torch.inference_mode()
def predict(label: int, steps: int, num_samples: int):
    """Sample a grid of digits conditioned on ``label``.

    Inputs are coerced to ``int`` before use: the label dropdown in this app
    offers string choices (``"0"`` .. ``"9"``), and slider values may be
    delivered as floats, while the annotations promise ints.

    Args:
        label: Digit class to condition on (``int`` or digit string).
        steps: Number of sampling steps.
        num_samples: Number of samples to render in the grid.

    Returns:
        The value produced by ``generate_grid_image``, returned unchanged.
    """
    ensure_model_loaded()
    return generate_grid_image(
        label=int(label),
        steps=int(steps),
        num_samples=int(num_samples),
    )
with gr.Blocks(title="MNIST Diffusion") as demo:
    # Page header plus a one-line description of the fixed sampling settings.
    gr.Markdown("# MNIST Diffusion")
    gr.Markdown(
        "Discrete diffusion model for MNIST digits. "
        "Sampling uses fixed CFG=2.0, temperature=0.6, top_p=0.99."
    )

    # Output image is created first so it renders above the controls.
    grid = gr.Image(show_label=True, label="Samples")

    # Generation controls, laid out side by side.
    with gr.Row():
        label = gr.Dropdown(
            choices=list(map(str, range(10))),  # "0".."9" — selection arrives as a str
            value="4",
            label="Label",
        )
        steps = gr.Slider(minimum=1, maximum=784, step=1, value=784, label="Steps")
        num_samples = gr.Slider(minimum=1, maximum=36, step=1, value=16, label="Samples")

    generate_btn = gr.Button(value="Generate")
    # Wire the button to predict(); the result fills the image component.
    generate_btn.click(
        predict,
        inputs=[label, steps, num_samples],
        outputs=grid,
        scroll_to_output=True,
    )
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly,
    # not when it is imported.
    demo.launch()