wb-droid committed on
Commit
b5c4f7a
·
1 Parent(s): 2623128

initial commit.

Browse files
Files changed (4) hide show
  1. app.py +71 -0
  2. requirements.txt +4 -0
  3. unconditional01.pt +3 -0
  4. unconditional_emb_01.pt +3 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ import torchvision
5
+ from diffusers import UNet2DModel, UNet2DConditionModel, DDPMScheduler, DDPMPipeline, DDIMScheduler
6
+ from fastprogress.fastprogress import progress_bar
7
+
8
# Human-readable names for the ten FashionMNIST class indices.
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

# Inverse lookup table: label string -> class index.
l2i = {label: idx for idx, label in labels_map.items()}


def label2idx(l):
    """Return the FashionMNIST class index for the label string *l*.

    Raises KeyError if *l* is not one of the labels in ``labels_map``.
    """
    return l2i[l]
25
+
26
+
27
# Load the pre-trained conditional UNet and the class-embedding module from
# local checkpoint files, forcing everything onto CPU (no GPU on the host).
# NOTE(review): torch.load unpickles arbitrary objects — acceptable only
# because these .pt files ship with the app; never point this at untrusted files.
unet = torch.load("unconditional01.pt", map_location=torch.device('cpu')).to("cpu")
Emb = torch.load("unconditional_emb_01.pt", map_location=torch.device('cpu')).to("cpu")
unet.eval()  # inference mode: freezes dropout / batch-norm statistics

# DDIM sampler with 20 inference steps.
# NOTE(review): beta_end=0.01 presumably matches the training noise schedule — confirm.
sched = DDIMScheduler(beta_end=0.01)
sched.set_timesteps(20)
33
+
34
@torch.no_grad()
def diff_sample(model, sz, sched, hidden, **kwargs):
    """Run the reverse-diffusion sampling loop and collect intermediate samples.

    Args:
        model: noise-prediction UNet, called as ``model(x_t, t, hidden)``;
            its output is expected to expose a ``.sample`` attribute.
        sz: shape of the initial noise tensor, e.g. ``(batch, channels, H, W)``.
        sched: diffusers scheduler providing ``.timesteps`` and ``.step()``.
        hidden: conditioning embedding forwarded to the model at every step.
        **kwargs: extra keyword args forwarded to ``sched.step()``
            (e.g. ``eta`` for DDIM stochasticity).

    Returns:
        List of CPU float tensors — the partially denoised sample after each
        timestep; ``preds[-1]`` is the final generated batch.
    """
    x_t = torch.randn(sz)  # start from pure Gaussian noise
    preds = []
    for t in progress_bar(sched.timesteps):
        # The @torch.no_grad() decorator already disables autograd here, so
        # the original's redundant inner `with torch.no_grad():` was removed.
        noise = model(x_t, t, hidden).sample
        x_t = sched.step(noise, t, x_t, **kwargs).prev_sample
        preds.append(x_t.float().cpu())
    return preds
43
+
44
+
45
@torch.no_grad()
def generate(classChoice):
    """Generate one 32x32 FashionMNIST-style image for the chosen class.

    Args:
        classChoice: label string from ``labels_map`` (e.g. "T-Shirt"),
            as supplied by the gr.Radio widget.

    Returns:
        numpy array of shape (32, 32) with values in [0, 1], suitable for
        display in a ``gr.Image`` component.
    """
    sz = (1, 1, 32, 32)  # batch of one single-channel 32x32 image
    # Embed the class index as the conditioning signal for the UNet.
    hidden = Emb(torch.tensor([label2idx(classChoice)] * 1)[:, None]).detach().to("cpu")
    preds = diff_sample(unet, sz, sched, hidden, eta=1.)
    # The model output is roughly centred on zero; shift by +0.5 into [0, 1]
    # and clamp to that range. (The original clamped to (-1, 1) AFTER the
    # shift, which could pass negative pixel values through to gradio.)
    return (preds[-1][0] + 0.5).squeeze().clamp(0, 1).detach().numpy()
53
+
54
# Build the Gradio UI: a radio selector for the target class, a button that
# triggers conditional sampling, and an image widget for the result.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">Conditional Diffusion with DDIM</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    # NOTE(review): session_data is never read or written below — appears unused.
    session_data = gr.State([])

    classChoice = gr.Radio(list(labels_map.values()), value="T-Shirt", label="Select the type of image to generate", info="")
    sampling_button = gr.Button("Conditional image generation")
    final_image = gr.Image(height=250,width=200)

    # Wire the button: generate(classChoice) -> final_image.
    sampling_button.click(
        generate,
        [classChoice],
        [final_image],
    )

# Queue requests (sampling takes several seconds) and serve locally.
demo.queue().launch(share=False, inbrowser=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torchvision
3
+ diffusers
4
+ fastprogress
unconditional01.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd4dfcc2296cd3f28ed358bceecb1e53ebc541162ce4c88fb510735f229e32e4
3
+ size 15277734
unconditional_emb_01.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51a515855b4c2e1490e9f9be7c89e23f61da5ffdd496ad23e2a99e633aa02578
3
+ size 53277