# CLIP_Model / app.py — initial commit (56258fe) by wb-droid
# Standard library
from types import SimpleNamespace

# Third-party
import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.transforms import Pad, ToTensor
class MyVAE(nn.Module):
    """Convolutional autoencoder for 1x32x32 (padded FashionMNIST) images.

    The encoder downsamples 32x32 -> 16x16 -> 8x8 and emits a 3-channel
    latent map; the decoder mirrors it back to a 1-channel 32x32 image via
    nearest-neighbour upsampling, ending in a Sigmoid so outputs lie in
    [0, 1] like the ToTensor() inputs.

    Layer comments mirror the diffusers-style block names this architecture
    was hand-transcribed from.
    """

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 32x32
            # (down_block_0)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 32x32
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 32x32
            # (nonlinearity)
            nn.SiLU(),
            # (downsamplers)(conv)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 16x16
            # (down_block_1)
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 16x16
            # (norm2)
            nn.GroupNorm(8, 64, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),  # 16x16
            # (nonlinearity)
            nn.SiLU(),
            # (nonlinearity) — duplicated SiLU kept to preserve the trained
            # checkpoint's module indices (a conv_shortcut was removed here).
            nn.SiLU(),
            # (downsamplers)(conv)
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),  # 8x8
            # (conv_norm_out)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv_act)
            nn.SiLU(),
            # (conv_out) — 3-channel latent map
            nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        self.decoder = nn.Sequential(
            # (conv_in)
            nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm1)
            nn.GroupNorm(16, 64, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (upsamplers)
            nn.Upsample(scale_factor=2, mode='nearest'),  # 16x16
            # (norm1)
            nn.GroupNorm(8, 32, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (norm2)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (dropout)
            nn.Dropout(p=0.5, inplace=False),
            # (conv2)
            nn.Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            # (nonlinearity)
            nn.SiLU(),
            # (upsamplers)
            nn.Upsample(scale_factor=2, mode='nearest'),  # 32x32
            # (norm1)
            nn.GroupNorm(8, 16, eps=1e-06, affine=True),
            # (conv1)
            nn.Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            nn.Sigmoid(),
        )

    def forward(self, xb, yb):
        """Encode and reconstruct ``xb``.

        Args:
            xb: image batch of shape (B, 1, 32, 32).
            yb: unused; kept so (image, label) dataloader batches can be
                passed directly.

        Returns:
            tuple: (reconstruction with the same shape as ``xb``,
            reconstruction MSE loss).
        """
        x = self.encoder(xb)
        x = self.decoder(x)
        # Bug fix: the original file used F.mse_loss without importing F;
        # torch.nn.functional is now imported at the top of the file.
        return x, F.mse_loss(x, xb)
class MyCLIP(nn.Module):
    """CLIP-style zero-shot classifier.

    A learned per-class embedding acts as the "text" encoder; an arbitrary
    module encodes the image. The forward pass scores every label against
    every image via a dot product in the shared embedding space.
    """

    def __init__(self, n_classes, emb_dim, img_encoder):
        super().__init__()
        self.n_classes = n_classes
        self.emb_dim = emb_dim
        # One trainable embedding vector per class label.
        self.text_encoder = nn.Embedding(self.n_classes, self.emb_dim)
        self.img_encoder = img_encoder

    def forward(self, img, label):
        """Return a (len(label), batch) matrix of label-vs-image scores."""
        batch = img.shape[0]
        # Flatten the encoder's feature map to one vector per image.
        img_features = self.img_encoder(img).view(batch, -1)
        label_features = self.text_encoder(label)
        # Pairwise dot products: rows index labels, columns index images.
        return label_features @ img_features.T
# FashionMNIST test split, zero-padded from 28x28 to 32x32 so images match
# the encoder's expected input size.
_eval_transform = transforms.Compose([Pad([2, 2, 2, 2]), ToTensor()])
data_test = torchvision.datasets.FashionMNIST(
    root='./data/', train=False, download=True, transform=_eval_transform
)

# Class index -> human-readable FashionMNIST label.
labels_map = dict(enumerate([
    "T-Shirt", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle Boot",
]))

# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints. Inference runs on CPU.
clip = torch.load("myclip.pt", map_location=torch.device('cpu')).to("cpu")
clip.eval()
@torch.no_grad()
def generate():
    """Pick one random test image and zero-shot classify it with the CLIP model.

    Returns:
        tuple: (HxW numpy array of the sampled image,
        multi-line string with the top-k class probabilities).
    """
    # num_workers=0 (was 4): spawning worker processes on every button click
    # just to fetch a single sample is pure overhead.
    dl_test = torch.utils.data.DataLoader(data_test, batch_size=1, shuffle=True, num_workers=0)
    image_eval, label_eval = next(iter(dl_test))
    # Score the image against every class embedding at once (zero-shot).
    logits = clip(image_eval, torch.arange(len(labels_map)))
    # logits is (n_classes, 1); transpose so softmax normalizes over classes,
    # then take the (only) row of per-class probabilities.
    probability = torch.nn.functional.softmax(logits.T, dim=1)[-1]
    n_topk = 3
    topk = probability.topk(n_topk, dim=-1)
    # Header derives from n_topk instead of hard-coding "top 3".
    result = f"Predictions (top {n_topk}):\n"
    for idx in range(n_topk):
        label = labels_map[topk.indices[idx].item()]
        prob = topk.values[idx].item()
        # Left-pad the "Label:" column to 12 chars so percentages line up.
        label = f'{label + ":": <12}'
        result = result + label + " " + f'{prob*100:.2f}' + "%\n"
    return image_eval[0].squeeze().detach().numpy(), result
# Gradio UI: one button samples a random test image and shows the model's
# zero-shot classification next to it.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">CLIP Model</h1>""")
    gr.HTML("""<h1 align="center">trained with FashionMNIST</h1>""")
    # NOTE(review): session_data is never read or written anywhere in this
    # file; kept to avoid changing the app's component graph.
    session_data = gr.State([])
    sampling_button = gr.Button("Random image and zero-shot classification")
    with gr.Row():
        with gr.Column(scale=1):
            # Fixed: headings were opened with <h3> but closed with </h1>.
            gr.HTML("""<h3 align="left">Random image</h3>""")
            gr_image = gr.Image(height=250, width=200)
        with gr.Column(scale=2):
            gr.HTML("""<h3 align="left">Classification</h3>""")
            gr_text = gr.Text(label="Classification")
    # Button has no inputs; generate() returns (image, text) for the outputs.
    sampling_button.click(
        generate,
        [],
        [gr_image, gr_text],
    )

demo.queue().launch(share=False, inbrowser=True)