wb-droid committed on
Commit
56258fe
·
1 Parent(s): 201f31e

initial commit.

Browse files
Files changed (3) hide show
  1. app.py +201 -0
  2. myclip.pt +3 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ import torchvision
5
+ from torchvision.transforms import ToTensor
6
+ from types import SimpleNamespace
7
+ import matplotlib.pyplot as plt
8
+ from torchvision import transforms
9
+ from torchvision.transforms import ToTensor, Pad
10
+
11
class MyVAE(nn.Module):
    """Convolutional autoencoder used as the image encoder for MyCLIP.

    The encoder maps a 1-channel image to a 3-channel latent at 1/4 the
    spatial resolution (two stride-2 convs); the decoder mirrors it back
    to a 1-channel image through two nearest-neighbour upsamples and a
    final Sigmoid. Layer layout mirrors a diffusers-style VAE block
    (the parenthesised names in comments are the original layer labels).
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),

            # (down_block_0)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (downsamplers)(conv): halves the spatial size, e.g. 32 -> 16
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),

            # (down_block_1)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 64, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (nonlinearity) again -- duplicate kept intentionally so the
            # Sequential layer count matches the trained checkpoint.
            nn.SiLU(),
            # (downsamplers)(conv): halves the spatial size again, e.g. 16 -> 8
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),

            # (conv_norm_out)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv_act)
            nn.SiLU(),
            # (conv_out): project to the 3-channel latent
            nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        )

        self.decoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),

            # (norm1)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),

            # (upsamplers): doubles spatial size, e.g. 8 -> 16
            nn.Upsample(scale_factor=2, mode='nearest'),

            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),

            # (upsamplers): doubles spatial size again, e.g. 16 -> 32
            nn.Upsample(scale_factor=2, mode='nearest'),

            # (norm1)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (conv1): back to a single channel
            nn.Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),

            # squash output into [0, 1] to match ToTensor()-scaled inputs
            nn.Sigmoid()
        )

    def forward(self, xb, yb):
        """Encode then decode *xb*; return the reconstruction and its loss.

        Args:
            xb: image batch of shape (N, 1, H, W); H and W must be
                divisible by 4 (two stride-2 downsamples).
            yb: unused -- kept so the signature matches (image, label)
                batches yielded by the DataLoader.

        Returns:
            tuple: (reconstruction with the same shape as ``xb``,
            scalar MSE reconstruction loss).
        """
        x = self.encoder(xb)
        x = self.decoder(x)
        # Bug fix: the original called F.mse_loss, but `F`
        # (torch.nn.functional) was never imported -> NameError at runtime.
        return x, nn.functional.mse_loss(x, xb)
121
+
122
class MyCLIP(nn.Module):
    """Minimal CLIP-style scorer: a learned embedding per class label on
    the text side, an arbitrary image encoder on the vision side, and a
    plain dot product between the two as the similarity logit.
    """

    def __init__(self, n_classes, emb_dim, img_encoder):
        super().__init__()
        self.n_classes = n_classes
        self.emb_dim = emb_dim
        # One emb_dim-sized vector per class index acts as the "text" branch.
        self.text_encoder = nn.Embedding(self.n_classes, self.emb_dim)
        self.img_encoder = img_encoder

    def forward(self, img, label):
        """Return a (num_labels, batch) matrix of label/image dot products."""
        batch = img.shape[0]
        label_vectors = self.text_encoder(label)
        # Flatten whatever the image encoder emits to (batch, features).
        image_vectors = self.img_encoder(img).view(batch, -1)
        return label_vectors @ image_vectors.T
136
+
137
+
138
# FashionMNIST test split. Pad([2, 2, 2, 2]) grows each 28x28 image to
# 32x32 -- presumably so the encoder's two stride-2 convs divide evenly
# (32 -> 16 -> 8); TODO confirm against the training setup.
data_test = torchvision.datasets.FashionMNIST(root='./data/', train=False, download=True, transform=transforms.Compose([Pad([2,2,2,2]), ToTensor()]))

# Human-readable names for the 10 FashionMNIST class indices.
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

# Load the whole pickled model (not a state_dict) onto CPU and switch to
# eval mode (disables the Dropout layers).
# NOTE(review): torch.load unpickles arbitrary objects -- acceptable for
# this repo-bundled checkpoint, never for untrusted files.
clip = torch.load("myclip.pt", map_location=torch.device('cpu')).to("cpu")
clip.eval()
155
+
156
@torch.no_grad()
def generate():
    """Sample one random test image and zero-shot classify it with CLIP.

    Returns:
        tuple: (2-D numpy image array for ``gr.Image``, formatted top-3
        prediction string for ``gr.Text``).
    """
    # A fresh shuffled DataLoader per click so each call yields a random
    # image. NOTE(review): num_workers=4 is heavyweight for batch_size=1
    # and restarts worker processes on every click.
    dl_test = torch.utils.data.DataLoader(data_test, batch_size=1, shuffle=True, num_workers=4)

    image_eval, label_eval = next(iter(dl_test))
    # Score the single image against all class indices at once; logits is
    # (n_classes, batch_size) per MyCLIP.forward.
    logits = clip(image_eval,torch.arange(len(labels_map)))
    # Transpose so softmax runs across the class axis; [-1] picks the
    # (only) image's probability vector.
    probability = torch.nn.functional.softmax(logits.T, dim=1)[-1]
    n_topk = 3
    topk = probability.topk(n_topk, dim=-1)
    result = "Predictions (top 3):\n"
    print(topk.indices)  # debug output
    for idx in range(n_topk):
        print(topk.indices[idx].item())  # debug output
        label = labels_map[topk.indices[idx].item()]
        prob = topk.values[idx].item()
        print(prob)  # debug output
        # Pad "Label:" to a 12-char column so the percentages line up.
        label = label + ":"
        label = f'{label: <12}'
        result = result + label + " " + f'{prob*100:.2f}' + "%\n"


    return image_eval[0].squeeze().detach().numpy(), result
178
+
179
# --- Gradio UI -------------------------------------------------------------
# One button samples a random FashionMNIST image and shows the zero-shot
# classification next to it.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">CLIP Model</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    # Per-session state; currently unused but kept so existing sessions /
    # callers see the same component tree.
    session_data = gr.State([])

    sampling_button = gr.Button("Random image and zero-shot classification")

    with gr.Row():
        with gr.Column(scale=1):
            # Bug fix: these <h3> headings were closed with </h1>,
            # producing invalid HTML.
            gr.HTML("""<h3 align="left">Random image</h3>""")
            gr_image = gr.Image(height=250,width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Classification</h3>""")
            gr_text = gr.Text(label="Classification")

    # generate() takes no inputs (it samples its own image) and fills
    # both output panes.
    sampling_button.click(
        generate,
        [],
        [gr_image, gr_text],
    )

demo.queue().launch(share=False, inbrowser=True)
myclip.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69c2914eabb705f6bb12581d8cef9bc4a1ac8c291d3da0a4ca1248323b132729
3
+ size 511526
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
gradio
torchvision
matplotlib
diffusers