File size: 2,856 Bytes
833dd01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr
import torch
from torch import nn
import torchvision
from torchvision.transforms import ToTensor
from types import SimpleNamespace
import matplotlib.pyplot as plt


# Load the full pickled UNet checkpoint for CPU inference.
# NOTE(review): torch.load unpickles arbitrary code — only load trusted
# checkpoints (consider weights_only=True / saving state_dicts instead).
unet = torch.load("unet_checkpoint13.pt", map_location=torch.device('cpu')).to("cpu")
unet.eval()  # inference mode: freeze dropout / batch-norm statistics

@torch.no_grad()  # parenthesized for consistency with the decorator on generate()
def linear_sched(betamin=0.0001, betamax=0.02, n_steps=1000):
    """Build a linear DDPM noise schedule.

    Args:
        betamin: smallest per-step variance beta.
        betamax: largest per-step variance beta.
        n_steps: number of diffusion timesteps.

    Returns:
        SimpleNamespace with three 1-D tensors of length n_steps:
            a:    alpha_t = 1 - beta_t
            abar: cumulative product of alpha_t up to each step
            sig:  sigma_t = sqrt(beta_t)
    """
    beta = torch.linspace(betamin, betamax, n_steps)
    alpha = 1. - beta  # computed once instead of twice
    return SimpleNamespace(a=alpha, abar=alpha.cumprod(dim=0), sig=beta.sqrt())
# Global diffusion schedule shared by generate().
n_steps = 1000
lin_abar = linear_sched(betamax=0.01)
alpha, alphabar, sigma = lin_abar.a, lin_abar.abar, lin_abar.sig

@torch.no_grad()
def generate():
    """Run the reverse DDPM sampling loop and stream snapshots to the UI.

    Yields:
        (img_at_800, img_at_50, img_final): three float numpy images —
        the sample as it looked at timestep >= 799, at timestep >= 50,
        and at the current (most recent) sampled step.
    """
    model = unet
    sz = (1, 1, 32, 32)  # (batch, channels, H, W) — assumes 32x32 inputs; TODO confirm vs checkpoint

    ps = next(model.parameters())  # reference tensor for target device/dtype
    x_t = torch.randn(sz).to(ps)   # start from pure Gaussian noise
    # Timesteps at which the UNet is actually queried; the last noise
    # prediction is reused for the steps in between. Spacing grows coarser
    # as t increases.
    sample_at = {t for t in range(n_steps) if (t + 101) % ((t + 101) // 100) == 0}
    preds = []
    img_final = img_at_800 = img_at_50 = (
        (x_t[0].float().cpu() + 0.5).squeeze().clamp(-1, 1).detach().numpy()
    )
    for t in reversed(range(n_steps)):
        t_batch = torch.full((x_t.shape[0],), t, device=ps.device, dtype=torch.long)
        # Fresh noise for the stochastic step; the final step (t == 0) is deterministic.
        z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(ps)
        abar_prev = alphabar[t - 1] if t > 0 else torch.tensor(1)
        one_minus_abar = 1 - alphabar[t]
        one_minus_abar_prev = 1 - abar_prev
        # Fix: always query the model on the very first iteration so `noise`
        # is bound even if sample_at ever excludes t == n_steps - 1
        # (previously a latent NameError); otherwise reuse the last prediction.
        if t in sample_at or t == n_steps - 1:
            noise = model(x_t, t_batch).sample
        # DDPM posterior mean: estimate x_0, then interpolate toward x_{t-1}.
        x_0_hat = (x_t - one_minus_abar.sqrt() * noise) / alphabar[t].sqrt()
        x_t = (x_0_hat * abar_prev.sqrt() * (1 - alpha[t]) / one_minus_abar
               + x_t * alpha[t].sqrt() * one_minus_abar_prev / one_minus_abar
               + sigma[t] * z)
        if t in sample_at:
            preds.append(x_t.float().cpu())
            img = (x_t[0].float().cpu() + 0.5).squeeze().clamp(-1, 1).detach().numpy()
            # Freeze the early/mid snapshots once the loop passes their thresholds
            # (matching the UI labels "step 800" and "step 50").
            if t >= 799:
                img_final = img_at_800 = img_at_50 = img
            elif t >= 50:
                img_final = img_at_50 = img
            else:
                img_final = img
            yield (img_at_800, img_at_50, img_final)
    
# Gradio UI: one button triggers the generator; each yield from generate()
# streams updated images into the three output slots.
with gr.Blocks() as demo:
    # Typo fixed: the method is DDPM (Denoising Diffusion Probabilistic Models).
    gr.HTML("""<h1 align="center">UNet with DDPM</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    session_data = gr.State([])  # per-session state (currently unused by generate)

    sampling_button = gr.Button("Unconditional image generation")

    with gr.Row():
        # Fixed mismatched markup: the <h3> headings were closed with </h1>.
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">image at step 800</h3>""")
            step_800_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">image at step 50</h3>""")
            step_50_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">final image</h3>""")
            step_final_image = gr.Image(height=250, width=200)

    # generate is a generator function, so Gradio streams its yields
    # into the three image components as sampling progresses.
    sampling_button.click(
        generate,
        [],
        [step_800_image, step_50_image, step_final_image],
    )

demo.queue().launch(share=False, inbrowser=True)