wb-droid committed on
Commit
833dd01
·
1 Parent(s): 8de7c1e

initial commit.

Browse files
Files changed (3) hide show
  1. app.py +80 -0
  2. requirements.txt +2 -0
  3. unet_checkpoint13.pt +3 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import gradio as gr
import torch
from torch import nn
import torchvision
from torchvision.transforms import ToTensor
from types import SimpleNamespace
import matplotlib.pyplot as plt


# Load the pre-trained UNet checkpoint for CPU inference.
# map_location already places every tensor on the CPU, so the trailing
# .to("cpu") in the original was redundant and has been dropped.
# NOTE(review): torch.load unpickles arbitrary objects — acceptable only
# because this checkpoint ships with the app; never point it at an
# untrusted file.
unet = torch.load("unet_checkpoint13.pt", map_location=torch.device("cpu"))
unet.eval()  # inference mode: freezes dropout / batch-norm statistics
@torch.no_grad()  # call the decorator — consistent with generate() and required on older torch
def linear_sched(betamin=0.0001, betamax=0.02, n_steps=1000):
    """Build a linear DDPM noise schedule.

    Args:
        betamin: noise variance (beta) at the first diffusion step.
        betamax: noise variance at the last diffusion step.
        n_steps: number of diffusion steps.

    Returns:
        SimpleNamespace with
          a:    per-step alpha_t = 1 - beta_t, shape (n_steps,)
          abar: running product of alpha_t (alpha-bar)
          sig:  per-step sigma_t = sqrt(beta_t)
    """
    beta = torch.linspace(betamin, betamax, n_steps)
    alpha = 1.0 - beta  # compute once; the original evaluated (1 - beta) twice
    return SimpleNamespace(a=alpha, abar=alpha.cumprod(dim=0), sig=beta.sqrt())


# Module-level schedule shared by generate().
n_steps = 1000
lin_abar = linear_sched(betamax=0.01)  # NOTE: betamax differs from the function default
alphabar = lin_abar.abar
alpha = lin_abar.a
sigma = lin_abar.sig
@torch.no_grad()
def generate():
    """Run DDPM reverse diffusion from pure noise and stream snapshots.

    Generator used as a Gradio event handler: yields a triple of numpy
    images (snapshot near step 800, near step 50, and the latest image)
    each time a model-evaluation step is reached, so the UI updates live.
    Uses the module-level globals `unet`, `n_steps`, `alphabar`, `alpha`,
    `sigma` defined above.
    """
    model = unet
    sz = (1, 1, 32, 32)  # batch=1, 1 channel, 32x32 — matches FashionMNIST-sized training

    # Reference parameter: .to(ps) moves tensors to the model's device/dtype.
    ps = next(model.parameters())
    x_t = torch.randn(sz).to(ps)
    # Steps at which the UNet is actually evaluated: dense near t=0, sparse
    # near t=n_steps (the spacing (t+101)//100 grows with t). t = n_steps-1
    # satisfies the condition, so `noise` is always bound on the first pass.
    sample_at = {t for t in range(n_steps) if (t+101)%((t+101)//100)==0}
    preds = []
    # Seed all three display slots with the initial noise image.
    img_final = img_799 = img_399 = (x_t[0].float().cpu()+0.5).squeeze().clamp(-1,1).detach().numpy()
    for t in reversed(range(n_steps)):
        t_batch = torch.full((x_t.shape[0],), t, device=ps.device, dtype=torch.long)
        # Fresh Gaussian noise each step, except the final step (t == 0) which is deterministic.
        z = (torch.randn(x_t.shape) if t > 0 else torch.zeros(x_t.shape)).to(ps)
        # alpha-bar at the previous step; defined as 1 at the boundary t == 0.
        ᾱ_t1 = alphabar[t-1] if t > 0 else torch.tensor(1)
        b̄_t = 1-alphabar[t]    # 1 - alpha-bar_t
        b̄_t1 = 1-ᾱ_t1          # 1 - alpha-bar_{t-1}
        # Speed trick: evaluate the UNet only at sample_at steps; in between,
        # the most recent noise prediction is reused.
        if t in sample_at: noise = model(x_t, t_batch).sample
        # Estimate of the clean image x_0 from the current x_t and predicted noise.
        x_0_hat = ((x_t - b̄_t.sqrt() * noise)/alphabar[t].sqrt())
        # DDPM posterior mean (weighted combination of x_0_hat and x_t) plus noise.
        x_t = x_0_hat * ᾱ_t1.sqrt()*(1-alpha[t])/b̄_t + x_t * alpha[t].sqrt()*b̄_t1/b̄_t + sigma[t]*z
        if t in sample_at:
            preds.append(x_t.float().cpu())
            # Convert to a displayable numpy image (shift then clamp, as trained).
            img = (x_t[0].float().cpu()+0.5).squeeze().clamp(-1,1).detach().numpy()

            # Freeze the step-800 and step-50 snapshots once those steps pass.
            if t >= 799:
                img_final = img_799 = img_399 = img
            elif t >= 50:
                img_final = img_399 = img
            else:
                img_final = img

            yield(img_799,img_399,img_final)
# Gradio UI: one button triggers unconditional generation; three image
# panes show snapshots of the reverse-diffusion trajectory.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">UNet with DDPM</h1>""")  # fixed typo: "DPPM" -> "DDPM"
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    session_data = gr.State([])  # currently unused; kept for future per-session state

    sampling_button = gr.Button("Unconditional image generation")

    with gr.Row():
        with gr.Column(scale=2):
            # Closing tags fixed: the originals closed <h3> with </h1>.
            gr.HTML("""<h3 align="left">image at step 800</h3>""")
            step_800_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">image at step 50</h3>""")
            step_50_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">final image</h3>""")
            step_final_image = gr.Image(height=250, width=200)

    # generate is a generator function, so Gradio streams each yielded
    # (step-800, step-50, latest) triple into the three image panes.
    sampling_button.click(
        generate,
        [],
        [step_800_image, step_50_image, step_final_image],
    )

demo.queue().launch(share=False, inbrowser=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio
2
+ torchvision
unet_checkpoint13.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22929146d48ed7a302cb8ac6166dd7c4ff3dd36a686441a8be01caaa59c786c8
3
+ size 27239114