wb-droid committed on
Commit
493aa40
·
1 Parent(s): 8a38d1f

Initial commit.

Browse files
Files changed (3) hide show
  1. app.py +205 -0
  2. requirements.txt +5 -0
  3. vit01.pt +3 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from einops import rearrange
3
+ import torch
4
+ from torch import nn
5
+ import torchvision
6
+ from torchvision import transforms
7
+ from torchvision.transforms import ToTensor, Pad
8
+
9
# FashionMNIST class index -> human-readable label (10 classes).
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
# All inference in this app runs on CPU.
device = "cpu"
22
+
23
class Transformer_dummy(nn.Module):
    """Identity placeholder mirroring Transformer's constructor signature.

    Does no computation; forward returns its input unchanged.
    NOTE(review): appears unused by the app itself — possibly kept so an
    older pickled checkpoint can still be unpickled. Confirm before removing.
    """

    def __init__(self, dim, mlp_hidden_dim=4098, attention_heads=8, depth=2 ):
        super().__init__()

    def forward(self, x):
        # Pure pass-through.
        return x
29
+
30
class MyViT(nn.Module):
    """Minimal Vision Transformer classifier.

    Cuts a square image into square patches, embeds each patch with
    LayerNorm -> Linear -> LayerNorm, adds a learned positional embedding,
    runs a Transformer encoder, and classifies from the mean patch embedding.
    """

    def __init__(self, image_size, patch_size, dim, n_classes = len(labels_map), device = device, depth=5):
        super().__init__()
        # Geometry — both image and patch are square (height == width).
        self.image_size = image_size
        self.patch_size = patch_size
        self.dim = dim  # latent width of each patch embedding
        self.n_classes = n_classes

        # Patch grid: nh x nw patches; their product is the sequence length.
        self.nh = self.nw = image_size // patch_size
        self.n_patches = self.nh * self.nw

        # Patch embedding stack.
        # NOTE(review): normalizing over patch_size**2 features assumes a
        # single input channel (c == 1) — confirm for multi-channel images.
        self.layernorm1 = nn.LayerNorm(self.patch_size**2)
        self.ln = nn.Linear(self.patch_size**2, dim)
        self.layernorm2 = nn.LayerNorm(dim)

        # One learned position vector per patch slot.
        self.pos_encoding = nn.Embedding(self.n_patches, self.dim)
        self.transformer = Transformer(dim=self.dim, depth=depth)

        # Classification head applied to the mean-pooled patch embedding.
        self.proj = nn.Linear(self.dim, self.n_classes)

    def forward(self, x):
        """Classify a batch of images.

        Input (b, c, H, W); returns logits of shape (b, n_classes).
        """
        # Split into patches and flatten each patch in one step — equivalent
        # to splitting into (b, nh, nw, c*ph*pw) and merging nh/nw afterwards.
        patches = rearrange(
            x, 'b c (nh ph) (nw pw) -> b (nh nw) (c ph pw)',
            nh=self.nh, nw=self.nw,
        )

        # Embed: (b, n_patches, ph*pw) -> (b, n_patches, dim).
        tokens = self.layernorm2(self.ln(self.layernorm1(patches)))

        # Add learned positional information (uses the module-level device).
        positions = self.pos_encoding(torch.arange(0, self.n_patches).to(device))
        tokens = tokens + positions

        tokens = self.transformer(tokens)

        # Mean-pool over patches, then project to class logits.
        return self.proj(tokens.mean(dim=1))
72
+
73
class MLPBlock(nn.Module):
    """Pre-norm transformer feed-forward sub-block.

    Computation: LayerNorm -> Linear(dim, hidden) -> GELU -> Dropout
    -> Linear(hidden, dim) -> Dropout. Input/output shape (..., dim).
    """

    def __init__(self, dim, mlp_hidden_dim=4096, dropout=0.):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)
        self.proj1 = nn.Linear(dim, mlp_hidden_dim)
        self.proj2 = nn.Linear(mlp_hidden_dim, dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # Expand, activate, and regularize ...
        hidden = self.dropout(self.activation(self.proj1(self.layernorm(x))))
        # ... then contract back to the model width.
        return self.dropout2(self.proj2(hidden))
93
+
94
class AttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention sub-block.

    Projects the normalized input to Q/K/V in one linear layer, computes
    scaled dot-product attention per head, and merges the heads back.
    Input/output shape (b, seq, dim); dim must be divisible by
    attention_heads. The `depth` argument is accepted but unused.
    """

    def __init__(self, dim, attention_heads = 8, depth=2, dropout=0.):
        super().__init__()
        self.dim = dim
        self.attention_heads = attention_heads

        self.layernorm = nn.LayerNorm(dim)
        # Single projection producing Q, K and V stacked on the last axis.
        self.proj = nn.Linear(dim, 3*dim)
        self.attention = nn.Softmax(dim = -1)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        normed = self.layernorm(x)
        q, k, v = self.proj(normed).chunk(3, dim=-1)

        heads = self.attention_heads
        # Split into heads; K is laid out pre-transposed so q @ k yields
        # the (seq x seq) score matrix directly.
        q = rearrange(q, 'b s (nh hs) -> b nh s hs', nh=heads)
        k = rearrange(k, 'b s (nh hs) -> b nh hs s', nh=heads)
        v = rearrange(v, 'b s (nh hs) -> b nh s hs', nh=heads)

        # NOTE(review): k is transposed here, so k.shape[-1] is the sequence
        # length — this scales by seq_len**-0.5 rather than the conventional
        # head_dim**-0.5. Preserved as-is because the checkpoint was trained
        # with this scaling.
        scores = (q @ k) * (k.shape[-1] ** -0.5)

        # No causal mask: every patch may attend to every other patch.
        weights = self.drop(self.attention(scores))
        out = weights @ v

        # Merge heads back: (b, nh, s, hs) -> (b, s, nh*hs).
        return rearrange(out, 'b nh s hs -> b s (nh hs)', nh=heads)
136
+
137
+
138
class Transformer(nn.Module):
    """Pre-norm transformer encoder: `depth` pairs of (attention, MLP)
    blocks, each applied with a residual connection, then a final LayerNorm.

    Args:
        dim: token embedding width.
        mlp_hidden_dim: currently unused — MLPBlock is constructed with its
            own default (4096); kept for signature compatibility.
        attention_heads: currently unused — AttentionBlock uses its own
            default (8); kept for signature compatibility.
        depth: number of (attention, MLP) layer pairs.
    """

    def __init__(self, dim, mlp_hidden_dim=4098, attention_heads=8, depth=5 ):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)
        # BUG FIX: the original `[AttentionBlock(dim=dim), MLPBlock(dim=dim)] * depth`
        # repeated the SAME two module instances `depth` times, so every layer
        # shared one set of weights. Build a fresh pair per layer instead.
        # (Loading the existing pickled checkpoint is unaffected: unpickling
        # restores each module's state directly without calling __init__.)
        layers = []
        for _ in range(depth):
            layers.append(AttentionBlock(dim=dim))
            layers.append(MLPBlock(dim=dim))
        self.net = nn.ModuleList(layers)

    def forward(self, x):
        """Apply all residual sub-blocks, then the final LayerNorm.

        Input/output shape (b, seq, dim).
        """
        for block in self.net:
            x = x + block(x)
        return self.layernorm(x)
150
+
151
+
152
# FashionMNIST test split: pad the 28x28 images to 32x32 (so the image size
# divides evenly into ViT patches) and convert to [0, 1] float tensors.
data_test = torchvision.datasets.FashionMNIST(root='./data/', train=False, download=True, transform=transforms.Compose([Pad([2,2,2,2]), ToTensor()]))


# Load the full pickled model onto CPU and switch to inference mode.
# NOTE(review): torch.load on a fully-pickled model executes arbitrary code
# during unpickling — only ship checkpoints from a trusted source.
model = torch.load("vit01.pt", map_location=torch.device('cpu')).to("cpu")
model.eval()
158
+
159
@torch.no_grad()
def generate():
    """Draw one random FashionMNIST test image and classify it.

    Returns:
        tuple: (HxW numpy array in [0, 1] for display,
                multi-line string listing the top-3 class probabilities).
    """
    # Single-sample loader. FIX: num_workers=0 — the original spawned four
    # worker processes per button click just to fetch one image, which is
    # slow and can hang in some deployment environments.
    dl_test = torch.utils.data.DataLoader(data_test, batch_size=1, shuffle=True, num_workers=0)

    image_eval, label_eval = next(iter(dl_test))
    # The model expects inputs centered around 0.
    image_eval = image_eval - 0.5
    logits = model(image_eval)
    # Batch size is 1, so take the single row of softmax probabilities.
    probability = torch.nn.functional.softmax(logits, dim=1)[-1]

    # Format the top-3 predictions as "Label:       12.34%" lines.
    # (Debug print statements from the original have been removed.)
    n_topk = 3
    topk = probability.topk(n_topk, dim=-1)
    result = "Predictions (top 3):\n"
    for idx in range(n_topk):
        label = labels_map[topk.indices[idx].item()]
        prob = topk.values[idx].item()
        label = label + ":"
        label = f'{label: <12}'
        result = result + label + " " + f'{prob*100:.2f}' + "%\n"

    # Undo the centering so the displayed image is back in [0, 1].
    return (image_eval + 0.5)[0].squeeze().detach().numpy(), result
182
+
183
# Gradio UI: one button samples a random FashionMNIST test image and shows
# the model's top-3 classification beside it.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ViT (Vision Transformer) Model</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    # Per-session state placeholder — currently not read or written by any
    # callback.
    session_data = gr.State([])

    sampling_button = gr.Button("Random image and zero-shot classification")

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("""<h3 align="left">Random image</h1>""")
            gr_image = gr.Image(height=250,width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Classification</h1>""")
            gr_text = gr.Text(label="Classification")


    # Wire the button: no inputs; generate() returns (image, text).
    sampling_button.click(
        generate,
        [],
        [gr_image, gr_text],
    )

# Queue requests and start the server locally (no public share link).
demo.queue().launch(share=False, inbrowser=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ torchvision
3
+ diffusers
4
+ einops
5
+ torch
vit01.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd9dba8b6c75f7573f6c720b7c950d1ef3ad064c7009ac2517a1328ed7e7dc94
3
+ size 9308389