Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from torch import nn | |
| import torchvision | |
| from diffusers import UNet2DModel, UNet2DConditionModel, DDPMScheduler, DDPMPipeline, DDIMScheduler | |
| from fastprogress.fastprogress import progress_bar | |
# Integer class index -> human-readable FashionMNIST label, in class-index order.
labels_map = dict(
    enumerate(
        [
            "T-Shirt",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle Boot",
        ]
    )
)

# Reverse lookup table: label string -> class index.
l2i = {name: idx for idx, name in labels_map.items()}


def label2idx(l):
    """Return the integer class index for the label string *l*.

    Raises:
        KeyError: if *l* is not one of the values of ``labels_map``.
    """
    return l2i[l]
# Load the pickled model objects onto the CPU. They appear to be whole module
# objects rather than state_dicts (they are moved/eval'd/called below), so
# torch.load must be allowed to unpickle arbitrary objects: torch >= 2.6
# defaults to weights_only=True, which would reject them.
# SECURITY: unpickling can execute code -- only ship checkpoints you trust.
unet = torch.load("unconditional01.pt", map_location=torch.device("cpu"), weights_only=False)
Emb = torch.load("unconditional_emb_01.pt", map_location=torch.device("cpu"), weights_only=False)

# Inference only: put both modules in eval mode so any train-time-only layers
# (e.g. dropout/batch-norm updates, if present) are disabled.
unet.eval()
Emb.eval()

# DDIM sampler with 20 inference steps. beta_end=0.01 (library default is
# 0.02) presumably matches the schedule used at training time -- confirm.
sched = DDIMScheduler(beta_end=0.01)
sched.set_timesteps(20)
def diff_sample(model, sz, sched, hidden, **kwargs):
    """Run the scheduler's reverse-diffusion loop starting from Gaussian noise.

    Args:
        model: denoiser called as ``model(sample, t, hidden)``; its result must
            expose the predicted noise as ``.sample``.
        sz: shape of the initial noise tensor.
        sched: scheduler providing ``timesteps`` and ``step(...)``.
        hidden: conditioning passed as the model's third positional argument
            (presumably encoder hidden states -- confirm against the model).
        **kwargs: forwarded unchanged to ``sched.step`` (e.g. ``eta``).

    Returns:
        List with the sample after every step, each as a float CPU tensor;
        the last element is the final generated sample.
    """
    sample = torch.randn(sz)
    history = []
    for timestep in progress_bar(sched.timesteps):
        # Only the network call needs gradient tracking disabled.
        with torch.no_grad():
            eps = model(sample, timestep, hidden).sample
        step_out = sched.step(eps, timestep, sample, **kwargs)
        sample = step_out.prev_sample
        history.append(sample.float().cpu())
    return history
def generate(classChoice):
    """Gradio callback: sample one 32x32 image conditioned on a class label.

    Args:
        classChoice: label string chosen in the UI; must be a value of
            ``labels_map``.

    Returns:
        2-D numpy array (32x32) of floats clamped to [-1, 1].

    Raises:
        gr.Error: if ``classChoice`` is missing or not a known label, so the
            message is shown in the UI instead of a bare KeyError.
    """
    sz = (1, 1, 32, 32)  # (batch, channels, height, width) of the noise tensor
    print(classChoice)
    try:
        class_idx = label2idx(classChoice)
    except KeyError:
        raise gr.Error(
            f"Unknown class {classChoice!r}; choose one of: "
            + ", ".join(labels_map.values())
        ) from None

    # Embedding lookup is inference-only; skip autograd bookkeeping entirely.
    with torch.no_grad():
        labels = torch.tensor([[class_idx]])  # shape (1, 1)
        hidden = Emb(labels).to("cpu")

    preds = diff_sample(unet, sz, sched, hidden, eta=1.0)
    # +0.5 presumably undoes a training-time shift of pixels to [-0.5, 0.5]
    # (TODO confirm); clamping keeps values in the [-1, 1] float range that
    # Gradio's image conversion accepts.
    final = preds[-1][0] + 0.5
    return final.squeeze().clamp(-1, 1).detach().numpy()
# Gradio UI: pick a FashionMNIST class, click the button, and show the sampled
# image. (An unused gr.State component was removed; no callback used it.)
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Conditional Diffusion with DDIM</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    classChoice = gr.Radio(
        list(labels_map.values()),
        value="T-Shirt",
        label="Select the type of image to generate",
        info="",
    )
    sampling_button = gr.Button("Conditional image generation")
    final_image = gr.Image(height=250, width=200)
    sampling_button.click(
        generate,
        [classChoice],
        [final_image],
    )

demo.queue().launch(share=False, inbrowser=True)