# Spaces:
# Sleeping
# Sleeping
| import gradio as gr | |
| from einops import rearrange | |
| import torch | |
| from torch import nn | |
| import torchvision | |
| from torchvision import transforms | |
| from torchvision.transforms import ToTensor, Pad | |
# Human-readable names for the ten FashionMNIST class indices 0-9; the
# model's predicted index is looked up here to produce a display label.
labels_map = dict(
    enumerate(
        [
            "T-Shirt",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle Boot",
        ]
    )
)

# Compute device name for this demo (CPU only).
device = "cpu"
class Transformer_dummy(nn.Module):
    """Identity module with a Transformer-like constructor.

    Every constructor argument is accepted and ignored, and no submodules or
    parameters are created; ``forward`` hands back its input object as-is.
    """

    def __init__(self, dim, mlp_hidden_dim=4098, attention_heads=8, depth=2):
        # Arguments intentionally unused: nothing is built.
        super().__init__()

    def forward(self, x):
        # Identity: the very same object is returned, not a copy.
        output = x
        return output
class MyViT(nn.Module):
    """Minimal Vision Transformer classifier.

    The input image is cut into non-overlapping ``patch_size x patch_size``
    patches. Each patch is flattened, LayerNorm-ed, linearly embedded into a
    ``dim``-dimensional token and LayerNorm-ed again; a learned positional
    embedding is added; the token sequence goes through :class:`Transformer`;
    finally the tokens are mean-pooled and projected to ``n_classes`` logits.

    Args:
        image_size: height == width of the (square) input image.
        patch_size: height == width of each square patch; must divide
            ``image_size``.
        dim: size of the latent vector of each patch token.
        n_classes: number of output logits (defaults to ``len(labels_map)``).
        device: accepted for backward compatibility but not used; positional
            indices are created on the input tensor's device.
        depth: number of attention+MLP layers in the transformer.
        channels: number of input image channels. Defaults to 1, which is the
            single-channel layout the original implementation hard-coded.

    Raises:
        ValueError: if ``image_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, image_size, patch_size, dim, n_classes=len(labels_map), device=device, depth=5, channels=1):
        super().__init__()
        # Fail at construction instead of with an opaque einops error in forward.
        if image_size % patch_size != 0:
            raise ValueError(
                f"image_size ({image_size}) must be divisible by patch_size ({patch_size})"
            )
        self.image_size = image_size  # height == width
        self.patch_size = patch_size  # height == width
        self.dim = dim  # dim of latent space for each patch
        self.n_classes = n_classes
        self.nh = self.nw = image_size // patch_size
        self.n_patches = self.nh * self.nw  # number of patches, i.e. NLP's seq len
        # Each flattened patch holds channels * patch_size**2 values.
        patch_dim = channels * patch_size ** 2
        self.layernorm1 = nn.LayerNorm(patch_dim)
        self.ln = nn.Linear(patch_dim, dim)
        self.layernorm2 = nn.LayerNorm(dim)
        self.pos_encoding = nn.Embedding(self.n_patches, self.dim)
        self.transformer = Transformer(dim=self.dim, depth=depth)
        self.proj = nn.Linear(self.dim, self.n_classes)

    def forward(self, x):
        """Return logits ``(b, n_classes)`` for images ``(b, c, H, W)``.

        ``H`` and ``W`` must each split into ``nh`` / ``nw`` equal patches.
        """
        # Cut into patches and flatten each one into a token:
        # (b, c, nh*ph, nw*pw) -> (b, nh*nw, c*ph*pw)
        x = rearrange(x, 'b c (nh ph) (nw pw) -> b (nh nw) (c ph pw)', nh=self.nh, nw=self.nw)
        x = self.layernorm1(x)
        x = self.ln(x)  # (b, n_patches, patch_dim) -> (b, n_patches, dim)
        x = self.layernorm2(x)
        # Create position indices on the input's device (not the module-level
        # `device` global) so the model still works after e.g. `.to("cuda")`.
        positions = torch.arange(self.n_patches, device=x.device)
        x = x + self.pos_encoding(positions)
        x = self.transformer(x)
        # Mean-pool over patch tokens, then classify.
        return self.proj(x.mean(dim=1))
class MLPBlock(nn.Module):
    """Pre-norm feed-forward branch.

    Computes LayerNorm -> Linear(dim, mlp_hidden_dim) -> GELU -> Dropout ->
    Linear(mlp_hidden_dim, dim) -> Dropout. No residual connection is added
    inside this block; the last dimension of the output equals ``dim``.

    Args:
        dim: size of the last input/output dimension.
        mlp_hidden_dim: width of the hidden layer.
        dropout: probability for both dropout layers.
    """

    def __init__(self, dim, mlp_hidden_dim=4096, dropout=0.0):
        super().__init__()
        # Attribute names double as state_dict keys / pickled fields, and
        # registration order fixes parameter order, so both stay as they were.
        self.layernorm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.proj1 = nn.Linear(dim, mlp_hidden_dim)
        self.proj2 = nn.Linear(mlp_hidden_dim, dim)
        self.activation = nn.GELU()

    def forward(self, x):
        hidden = self.dropout(self.activation(self.proj1(self.layernorm(x))))
        return self.dropout2(self.proj2(hidden))
class AttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention branch.

    LayerNorm, a single Linear producing q/k/v, per-head scaled dot-product
    attention with dropout on the attention weights, then heads are merged
    back. There is no output projection and no mask, and no residual
    connection is added inside this block.

    Args:
        dim: token embedding size; must split evenly into ``attention_heads``
            (einops raises otherwise).
        attention_heads: number of attention heads.
        depth: unused; kept for signature compatibility.
        dropout: dropout probability on the attention weights.
        legacy_scale: if True (default), scores are scaled by
            ``1/sqrt(seq_len)``, which is what this block historically computed
            and what existing checkpoints were trained with. If False, the
            standard ``1/sqrt(head_size)`` scaling is used.
    """

    def __init__(self, dim, attention_heads=8, depth=2, dropout=0., legacy_scale=True):
        super().__init__()
        self.dim = dim
        self.attention_heads = attention_heads
        self.layernorm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, 3 * dim)
        self.attention = nn.Softmax(dim=-1)
        self.drop = nn.Dropout(dropout)
        self.legacy_scale = legacy_scale

    def forward(self, x):
        x = self.layernorm(x)
        q, k, v = self.proj(x).chunk(3, dim=-1)
        # Split heads: (b, s, nh*hs) -> (b, nh, s, hs); k is laid out
        # pre-transposed as (b, nh, hs, s) so that q @ k gives q @ k^T.
        q = rearrange(q, 'b s (nh hs) -> b nh s hs', nh=self.attention_heads)
        k = rearrange(k, 'b s (nh hs) -> b nh hs s', nh=self.attention_heads)
        v = rearrange(v, 'b s (nh hs) -> b nh s hs', nh=self.attention_heads)
        scores = q @ k  # (b, nh, s, s)
        # Instances unpickled from checkpoints saved before `legacy_scale`
        # existed lack the attribute; default them to the historical behaviour.
        if getattr(self, "legacy_scale", True):
            # k.shape[-1] is the sequence length (k is b nh hs s), not the
            # head size -- kept so trained checkpoints produce the same output.
            scale = k.shape[-1] ** -0.5
        else:
            scale = q.shape[-1] ** -0.5  # standard 1/sqrt(head_size)
        weights = self.drop(self.attention(scores * scale))
        out = weights @ v  # (b, nh, s, hs)
        # Merge heads back: (b, nh, s, hs) -> (b, s, nh*hs)
        return rearrange(out, 'b nh s hs -> b s (nh hs)', nh=self.attention_heads)
class Transformer(nn.Module):
    """Stack of pre-norm residual layers followed by a final LayerNorm.

    Each of the ``depth`` layers is an :class:`AttentionBlock` followed by an
    :class:`MLPBlock`; every block's output is added back to its input.

    Args:
        dim: token embedding size.
        mlp_hidden_dim: hidden width of every MLP block. (This was previously
            declared as 4098 but never forwarded, so the blocks were always
            built with 4096; 4096 keeps existing parameter shapes.)
        attention_heads: number of heads in every attention block.
        depth: number of attention+MLP layers.
    """

    def __init__(self, dim, mlp_hidden_dim=4096, attention_heads=8, depth=5):
        super().__init__()
        self.layernorm = nn.LayerNorm(dim)
        # Build fresh blocks per layer. The former `[attn, mlp] * depth` repeated
        # the same two module objects, so all layers silently shared weights.
        layers = []
        for _ in range(depth):
            layers.append(AttentionBlock(dim=dim, attention_heads=attention_heads))
            layers.append(MLPBlock(dim=dim, mlp_hidden_dim=mlp_hidden_dim))
        self.net = nn.ModuleList(layers)

    def forward(self, x):
        for block in self.net:
            x = x + block(x)  # residual connection around every block
        return self.layernorm(x)
# FashionMNIST test split. Pad adds 2 pixels on every side (28x28 -> 32x32)
# before ToTensor converts each image to a float tensor in [0, 1].
data_test = torchvision.datasets.FashionMNIST(
    root='./data/',
    train=False,
    download=True,
    transform=transforms.Compose([Pad([2, 2, 2, 2]), ToTensor()]),
)
# vit01.pt stores a whole pickled nn.Module (presumably a MyViT instance),
# not just a state_dict. Unpickling therefore needs the class definitions in
# this module and weights_only=False (torch>=2.6 defaults to weights_only=True
# and refuses custom classes). NOTE: unpickling can execute arbitrary code --
# only ever load a trusted checkpoint here.
model = torch.load("vit01.pt", map_location=device, weights_only=False).to(device)
model.eval()  # inference mode for dropout/normalization layers
def generate(n_topk=3):
    """Classify one random FashionMNIST test image with the loaded model.

    Args:
        n_topk: number of highest-probability classes to report, clamped to
            the number of classes. The Gradio button passes no inputs, so the
            default of 3 is what the UI shows.

    Returns:
        ``(image, text)`` for the Gradio outputs: the un-shifted image as a
        NumPy array with singleton dimensions squeezed out, and a multi-line
        string listing the top predictions with probabilities in percent.
    """
    # Index the dataset directly at a random position. The previous version
    # built a shuffled DataLoader with 4 worker processes on every click just
    # to fetch a single sample.
    index = torch.randint(len(data_test), (1,)).item()
    image, _ = data_test[index]  # label unused; ToTensor yields (C, H, W)
    # Shift pixels by -0.5 (same preprocessing as before) and add a batch dim.
    batch = image.unsqueeze(0) - 0.5
    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        logits = model(batch)
    probability = torch.nn.functional.softmax(logits, dim=1)[0]

    n_topk = min(n_topk, probability.numel())
    topk = probability.topk(n_topk, dim=-1)

    lines = [f"Predictions (top {n_topk}):\n"]
    for idx, prob in zip(topk.indices.tolist(), topk.values.tolist()):
        label = f"{labels_map[idx]}:"
        lines.append(f"{label:<12} {prob * 100:.2f}%\n")
    return image.squeeze().numpy(), "".join(lines)
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ViT (Vision Transformer) Model</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    # Images come from the FashionMNIST test split, i.e. the same classes the
    # model was trained on, so this is ordinary classification, not zero-shot.
    sampling_button = gr.Button("Random image and classification")
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("""<h3 align="left">Random image</h3>""")
            gr_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Classification</h3>""")
            gr_text = gr.Text(label="Classification")
    # No inputs: generate() draws its own random sample on every click.
    sampling_button.click(
        generate,
        [],
        [gr_image, gr_text],
    )

# Launch only when run as a script, so importing this module (e.g. Gradio's
# reload mode, which looks up `demo`) does not start a server.
if __name__ == "__main__":
    demo.queue().launch(share=False, inbrowser=True)