"""Zero-shot CLIP demo that runs the ViT image encoder step by step.

The visual tower is unrolled by hand so that the per-token outputs of the
transformer are available in ``x``. The class token is then pooled into an
image embedding and scored against a few text prompts.
"""

import torch
import clip
from PIL import Image

from pdb import set_trace as st  # kept available for interactive debugging


def encode_visual_tokens(visual, pixels):
    """Return the transformer output for every token of the visual encoder.

    Args:
        visual: The CLIP vision module (``model.visual``); it must expose
            ``conv1``, ``class_embedding``, ``positional_embedding``,
            ``ln_pre`` and ``transformer``.
        pixels: Preprocessed image batch, already cast to the model dtype.

    Returns:
        A tensor laid out as (batch, tokens, channels). Token 0 along the
        second axis is the class token, because the class embedding is
        concatenated in front of the patch tokens.
    """
    tokens = visual.conv1(pixels)
    # Flatten the spatial grid and move it in front of the channel axis:
    # (batch, channels, h, w) -> (batch, h * w, channels).
    tokens = tokens.reshape(tokens.shape[0], tokens.shape[1], -1)
    tokens = tokens.permute(0, 2, 1)

    # Adding the class embedding to a zero tensor broadcasts it to one
    # class token per batch element.
    class_token = visual.class_embedding.to(tokens.dtype) + torch.zeros(
        tokens.shape[0], 1, tokens.shape[-1],
        dtype=tokens.dtype, device=tokens.device,
    )
    tokens = torch.cat([class_token, tokens], dim=1)
    tokens = tokens + visual.positional_embedding.to(tokens.dtype)
    tokens = visual.ln_pre(tokens)

    # The transformer is fed with the first two axes swapped; swap them back
    # afterwards so the result is batch-first again.
    tokens = tokens.permute(1, 0, 2)
    tokens = visual.transformer(tokens)
    tokens = tokens.permute(1, 0, 2)
    return tokens


device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

# Close the image file as soon as it has been converted into a tensor.
with Image.open("utils.torch_utils/CLIP.png") as raw_image:
    image = preprocess(raw_image).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    visual = model.visual
    x = encode_visual_tokens(visual, image.type(model.dtype))

    # Pool the class token into the image embedding, mirroring the tail of
    # the library's own visual forward pass.
    image_features = visual.ln_post(x[:, 0, :])
    if visual.proj is not None:
        image_features = image_features @ visual.proj
    text_features = model.encode_text(text)

    # Cosine similarity scaled by the learned temperature, then softmax over
    # the text prompts.
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)
    logits_per_image = model.logit_scale.exp() * image_features @ text_features.t()
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)