import gradio as gr import torch from torch import nn import torchvision from torchvision.transforms import ToTensor from types import SimpleNamespace import matplotlib.pyplot as plt from torchvision import transforms from torchvision.transforms import ToTensor, Pad class MyVAE(nn.Module): def __init__(self): super().__init__() self.encoder = nn.Sequential( # (conv_in) nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), # 28, 28 # (down_block_0) # (norm1) nn.GroupNorm(8, 32, eps=1e-06, affine=True), # (conv1) nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #28, 28 # (norm2): nn.GroupNorm(8, 32, eps=1e-06, affine=True), # (dropout): nn.Dropout(p=0.5, inplace=False), # (conv2): nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #28, 28 # (nonlinearity): nn.SiLU(), # (downsamplers)(conv): nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), #14, 14 # (down_block_1) # (norm1) nn.GroupNorm(8, 32, eps=1e-06, affine=True), # (conv1) nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #28, 28 # (norm2): nn.GroupNorm(8, 64, eps=1e-06, affine=True), # (dropout): nn.Dropout(p=0.5, inplace=False), # (conv2): nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #28, 28 # (nonlinearity): nn.SiLU(), # (conv_shortcut): #nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #28, 28 # (nonlinearity): nn.SiLU(), # (downsamplers)(conv): nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), #7, 7 # (conv_norm_out): nn.GroupNorm(16, 64, eps=1e-06, affine=True), # (conv_act): nn.SiLU(), # (conv_out): nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=3//2), # 14*14 #nn.ReLU(), #nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=3//2), # 7*7 #nn.ReLU(), ) self.decoder = nn.Sequential( #(conv_in): nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #(norm1): nn.GroupNorm(16, 64, eps=1e-06, affine=True), #(conv1): nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #(norm2): nn.GroupNorm(8, 32, eps=1e-06, affine=True), #(dropout): nn.Dropout(p=0.5, inplace=False), #(conv2): nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #(nonlinearity): nn.SiLU(), #(upsamplers): nn.Upsample(scale_factor=2, mode='nearest'), # 14,14 #(norm1): nn.GroupNorm(8, 32, eps=1e-06, affine=True), #(conv1): nn.Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #(norm2): nn.GroupNorm(8, 16, eps=1e-06, affine=True), #(dropout): nn.Dropout(p=0.5, inplace=False), #(conv2): nn.Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), #(nonlinearity): nn.SiLU(), #(upsamplers): nn.Upsample(scale_factor=2, mode='nearest'), # 16, 28, 28 #(norm1): nn.GroupNorm(8, 16, eps=1e-06, affine=True), #(conv1): nn.Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), nn.Sigmoid() ) def forward(self, xb, yb): x = self.encoder(xb) #print("current:",x.shape) x = self.decoder(x) #print("current decoder:",x.shape) #x = x.flatten(start_dim=1).mean(dim=1, keepdim=True) #print(x.shape, xb.shape) return x, F.mse_loss(x, xb) class MyCLIP(nn.Module): def __init__(self, n_classes, emb_dim, img_encoder): super().__init__() self.n_classes = n_classes self.emb_dim = emb_dim self.text_encoder = nn.Embedding(self.n_classes, self.emb_dim) self.img_encoder = img_encoder def forward(self, img, label): img_bs = img.shape[0] text_emb = self.text_encoder(label) img_emb = self.img_encoder(img).view(img_bs, -1) logits = text_emb @ (img_emb.T) return logits data_test = torchvision.datasets.FashionMNIST(root='./data/', train=False, download=True, transform=transforms.Compose([Pad([2,2,2,2]), ToTensor()])) labels_map = { 0: "T-Shirt", 1: "Trouser", 2: "Pullover", 3: "Dress", 4: "Coat", 5: "Sandal", 6: "Shirt", 7: "Sneaker", 8: "Bag", 9: "Ankle Boot", } clip = torch.load("myclip.pt", map_location=torch.device('cpu')).to("cpu") clip.eval() @torch.no_grad() def generate(): dl_test = torch.utils.data.DataLoader(data_test, batch_size=1, shuffle=True, num_workers=4) image_eval, label_eval = next(iter(dl_test)) logits = clip(image_eval,torch.arange(len(labels_map))) probability = torch.nn.functional.softmax(logits.T, dim=1)[-1] n_topk = 3 topk = probability.topk(n_topk, dim=-1) result = "Predictions (top 3):\n" print(topk.indices) for idx in range(n_topk): print(topk.indices[idx].item()) label = labels_map[topk.indices[idx].item()] prob = topk.values[idx].item() print(prob) label = label + ":" label = f'{label: <12}' result = result + label + " " + f'{prob*100:.2f}' + "%\n" return image_eval[0].squeeze().detach().numpy(), result with gr.Blocks() as demo: gr.HTML("""