Spaces:
Build error
Build error
import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision import transforms
from torchvision.transforms import ToTensor, Pad
from types import SimpleNamespace
class MyVAE(nn.Module):
    """Convolutional autoencoder for 32x32 single-channel images.

    The encoder downsamples twice (32 -> 16 -> 8 spatially) into a
    3-channel latent map; the decoder upsamples twice (8 -> 16 -> 32)
    back to a 1-channel image with sigmoid-activated pixel values.

    NOTE(review): despite the name there is no sampling / KL term here --
    this is a plain autoencoder, not a variational one.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 32, 32
            # (down_block_0)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 32, 32
            # (norm2):
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout):
            nn.Dropout(p=0.5, inplace=False),
            # (conv2):
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 32, 32
            # (nonlinearity):
            nn.SiLU(),
            # (downsamplers)(conv): stride-2 conv halves the spatial size
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 16, 16
            # (down_block_1)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 16, 16
            # (norm2):
            nn.GroupNorm(8, 64, eps=1e-06, affine=True),
            # (dropout):
            nn.Dropout(p=0.5, inplace=False),
            # (conv2):
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 16, 16
            # (nonlinearity):
            nn.SiLU(),
            # (conv_shortcut):
            # nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity):
            nn.SiLU(),
            # (downsamplers)(conv): second stride-2 downsample
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 8, 8
            # (conv_norm_out):
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv_act):
            nn.SiLU(),
            # (conv_out): project to the 3-channel latent map
            nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        self.decoder = nn.Sequential(
            # (conv_in):
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm1):
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv1):
            nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2):
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout):
            nn.Dropout(p=0.5, inplace=False),
            # (conv2):
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity):
            nn.SiLU(),
            # (upsamplers): nearest-neighbor doubling
            nn.Upsample(scale_factor=2, mode='nearest'),  # 16, 16
            # (norm1):
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1):
            nn.Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2):
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (dropout):
            nn.Dropout(p=0.5, inplace=False),
            # (conv2):
            nn.Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity):
            nn.SiLU(),
            # (upsamplers): back to the input resolution
            nn.Upsample(scale_factor=2, mode='nearest'),  # 32, 32
            # (norm1):
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (conv1): project back to 1 channel, sigmoid keeps pixels in [0, 1]
            nn.Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.Sigmoid(),
        )

    def forward(self, xb, yb):
        """Encode and reconstruct `xb`.

        Args:
            xb: image batch of shape (B, 1, 32, 32).
            yb: unused; kept for a (images, labels) dataloader interface.

        Returns:
            tuple: (reconstruction with the same shape as `xb`,
                    scalar MSE reconstruction loss).
        """
        x = self.encoder(xb)
        x = self.decoder(x)
        # BUGFIX: the original referenced `F.mse_loss` but `F` was never
        # imported, crashing at runtime with a NameError.
        return x, nn.functional.mse_loss(x, xb)
class MyCLIP(nn.Module):
    """CLIP-style zero-shot classifier.

    Class labels are embedded with a learned `nn.Embedding` table; images
    are mapped by an arbitrary image encoder and flattened. The forward
    pass returns the label-vs-image similarity matrix.
    """

    def __init__(self, n_classes, emb_dim, img_encoder):
        super().__init__()
        self.n_classes = n_classes
        self.emb_dim = emb_dim
        self.text_encoder = nn.Embedding(self.n_classes, self.emb_dim)
        self.img_encoder = img_encoder

    def forward(self, img, label):
        """Return similarity logits of shape (len(label), batch_size)."""
        batch_size = img.shape[0]
        # One embedding vector per requested label.
        label_features = self.text_encoder(label)
        # Encode images, then flatten each sample to a single vector.
        image_features = self.img_encoder(img).view(batch_size, -1)
        # Dot-product similarity between every label and every image.
        return label_features.matmul(image_features.t())
# FashionMNIST test split, padded from 28x28 to 32x32 so that two stride-2
# downsamples in the conv encoder produce clean 16x16 -> 8x8 feature maps.
# download=True fetches the dataset into ./data/ on first run.
data_test = torchvision.datasets.FashionMNIST(root='./data/', train=False, download=True, transform=transforms.Compose([Pad([2,2,2,2]), ToTensor()]))

# FashionMNIST class index -> human-readable label.
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

# Load the full pickled model onto CPU. Unpickling needs the model classes
# (MyCLIP and its image encoder) to be importable in this module.
# NOTE(review): torch.load unpickles arbitrary code -- only load checkpoints
# from a trusted source.
clip = torch.load("myclip.pt", map_location=torch.device('cpu')).to("cpu")
clip.eval()  # inference mode: disables the dropout layers
def generate():
    """Pick a random test image and zero-shot classify it with the CLIP model.

    Returns:
        tuple: (HxW numpy array of the sampled image for display,
                formatted top-3 prediction string).
    """
    # Fresh shuffled loader per click so each press samples a new image.
    # num_workers=0: spawning 4 worker processes (original code) is pure
    # overhead -- and can hang in constrained containers -- when we only
    # ever fetch a single batch of size 1.
    dl_test = torch.utils.data.DataLoader(data_test, batch_size=1, shuffle=True, num_workers=0)
    image_eval, label_eval = next(iter(dl_test))

    # Score the image against every class embedding (zero-shot).
    logits = clip(image_eval, torch.arange(len(labels_map)))
    # logits is (n_classes, 1); transpose, softmax over classes, and take
    # the single row of per-class probabilities.
    probability = torch.nn.functional.softmax(logits.T, dim=1)[-1]

    n_topk = 3
    topk = probability.topk(n_topk, dim=-1)
    # One "Label:       12.34%" line per prediction (label column padded
    # to 12 characters); debug prints from the original removed.
    lines = [
        f'{labels_map[idx.item()] + ":": <12} {val.item() * 100:.2f}%\n'
        for idx, val in zip(topk.indices, topk.values)
    ]
    result = "Predictions (top 3):\n" + "".join(lines)
    return image_eval[0].squeeze().detach().numpy(), result
# Gradio UI: a single button samples a random FashionMNIST image and shows
# the CLIP model's zero-shot top-3 classification next to it.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">CLIP Model</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    session_data = gr.State([])  # currently unused placeholder for per-session state
    sampling_button = gr.Button("Random image and zero-shot classification")
    with gr.Row():
        with gr.Column(scale=1):
            # BUGFIX: the original closed these <h3> headings with </h1>.
            gr.HTML("""<h3 align="left">Random image</h3>""")
            gr_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Classification</h3>""")
            gr_text = gr.Text(label="Classification")
    # Button click takes no inputs; outputs fill the image and text widgets.
    sampling_button.click(
        generate,
        [],
        [gr_image, gr_text],
    )

demo.queue().launch(share=False, inbrowser=True)