"""Gradio demo: unconditional image sampling from a pretrained diffusion UNet."""

from types import SimpleNamespace

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torchvision.transforms import ToTensor

# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code -- only load files from a trusted source. Recent PyTorch
# releases may also require an explicit ``weights_only=False`` to load a
# fully pickled module; confirm against the installed version.
unet = torch.load("unet_checkpoint13.pt", map_location=torch.device('cpu')).to("cpu")
unet.eval()


@torch.no_grad()
def linear_sched(betamin=0.0001, betamax=0.02, n_steps=1000):
    """Build a linear beta noise schedule.

    Args:
        betamin: First beta value of the schedule.
        betamax: Last beta value of the schedule.
        n_steps: Number of diffusion steps (length of every returned tensor).

    Returns:
        SimpleNamespace with ``a`` (``1 - beta``), ``abar`` (cumulative
        product of ``a``) and ``sig`` (``sqrt(beta)``).
    """
    beta = torch.linspace(betamin, betamax, n_steps)
    a = 1. - beta
    return SimpleNamespace(a=a, abar=a.cumprod(dim=0), sig=beta.sqrt())


n_steps = 1000
# Pass n_steps explicitly so the schedule length always matches the loop
# range used by generate().
lin_abar = linear_sched(betamax=0.01, n_steps=n_steps)
alphabar = lin_abar.abar
alpha = lin_abar.a
sigma = lin_abar.sig


def _to_image(x):
    """Return the first sample of batch ``x`` as a squeezed numpy array.

    The sample is shifted by +0.5 and clamped to [-1, 1] before conversion.
    """
    return (x[0].float().cpu() + 0.5).squeeze().clamp(-1, 1).detach().numpy()


@torch.no_grad()
def generate():
    """Run the reverse diffusion loop and stream intermediate images.

    Starts from Gaussian noise of shape (1, 1, 32, 32) and iterates ``t``
    from ``n_steps - 1`` down to 0. The UNet is evaluated only on the steps
    in ``sample_at``; on every other step the most recent noise prediction
    is reused.

    Yields:
        A tuple ``(img_800, img_50, img_final)`` after every step:
        ``img_800`` stops updating once ``t`` drops below 799, ``img_50``
        stops updating once ``t`` drops below 50, and ``img_final`` always
        holds the current image.
    """
    model = unet
    sz = (1, 1, 32, 32)
    ps = next(model.parameters())
    x_t = torch.randn(sz).to(ps)
    # Steps on which the model is actually called. For n_steps == 1000 the
    # first iteration (t == 999) is in this set, so ``noise`` is always
    # defined before it is reused.
    sample_at = {t for t in range(n_steps) if (t + 101) % ((t + 101) // 100) == 0}
    img_final = img_800 = img_50 = _to_image(x_t)
    for t in reversed(range(n_steps)):
        t_batch = torch.full((x_t.shape[0],), t, device=ps.device, dtype=torch.long)
        # No fresh noise is injected on the very last step.
        z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(ps)
        # Float literal so that .sqrt() below operates on a floating tensor.
        abar_prev = alphabar[t - 1] if t > 0 else torch.tensor(1.0)
        bbar_t = 1 - alphabar[t]
        bbar_prev = 1 - abar_prev
        if t in sample_at:
            noise = model(x_t, t_batch).sample
        x_0_hat = (x_t - bbar_t.sqrt() * noise) / alphabar[t].sqrt()
        x_t = (
            x_0_hat * abar_prev.sqrt() * (1 - alpha[t]) / bbar_t
            + x_t * alpha[t].sqrt() * bbar_prev / bbar_t
            + sigma[t] * z
        )
        img = _to_image(x_t)
        if t >= 799:
            img_800 = img
        if t >= 50:
            img_50 = img
        img_final = img
        yield (img_800, img_50, img_final)


with gr.Blocks() as demo:
    # NOTE(review): "DPPM" below is probably meant to be "DDPM"; the
    # user-facing text is left unchanged here.
    gr.HTML("\n\nUNet with DPPM\n\n")
    gr.HTML("\n\ntrained with FashionMNIST\n\n")
    session_data = gr.State([])
    sampling_button = gr.Button("Unconditional image generation")
    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("\n\nimage at step 800\n\n")
            step_800_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("\n\nimage at step 50\n\n")
            step_50_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("\n\nfinal image\n\n")
            step_final_image = gr.Image(height=250, width=200)
    # generate() is a generator, so each yielded tuple updates the three
    # image components in order.
    sampling_button.click(
        generate,
        [],
        [step_800_image, step_50_image, step_final_image],
    )

demo.queue().launch(share=False, inbrowser=True)