| import numpy as np
|
| import torch
|
| import os
|
| from IPython.display import Image
|
| import PIL
|
| import pickle
|
| import clip
|
| import glob
|
|
|
| import argparse
|
|
|
| def parse_args():
|
|
|
| parser = argparse.ArgumentParser(description="")
|
| parser.add_argument(
|
| "--path",
|
| type=str,
|
| default="",
|
| )
|
| return parser.parse_args()
|
|
|
|
|
| args = parse_args()
|
|
|
| class ClipWrapper(torch.nn.Module):
|
| def __init__(self, device, model_name='ViT-L/14'):
|
| super(ClipWrapper, self).__init__()
|
| self.clip_model, self.preprocess = clip.load(model_name,
|
| device,
|
| jit=False)
|
| self.clip_model.eval()
|
|
|
| def forward(self, x):
|
| return self.clip_model.encode_image(x)
|
|
|
|
|
| class SimClassifier(torch.nn.Module):
|
| def __init__(self, embeddings, device):
|
| super(SimClassifier, self).__init__()
|
| self.embeddings = torch.nn.parameter.Parameter(embeddings)
|
|
|
| def forward(self, x):
|
| embeddings_norm = self.embeddings / self.embeddings.norm(dim=-1,
|
| keepdim=True)
|
|
|
| image_features_norm = x / x.norm(dim=-1, keepdim=True)
|
|
|
| similarity = (100.0 * image_features_norm @ embeddings_norm.T)
|
|
|
| return similarity.squeeze()
|
|
|
|
|
| def initialize_prompts(clip_model, text_prompts, device):
|
| text = clip.tokenize(text_prompts).to(device)
|
| return clip_model.encode_text(text)
|
|
|
|
|
| def save_prompts(classifier, save_path):
|
| prompts = classifier.embeddings.detach().cpu().numpy()
|
| pickle.dump(prompts, open(save_path, 'wb'))
|
|
|
|
|
| def load_prompts(file_path, device):
|
| return torch.HalfTensor(pickle.load(open(file_path, 'rb'))).to(device)
|
|
|
| def compute_embeddings(image_paths):
|
| images = [clip.preprocess(PIL.Image.open(image_path)) for image_path in image_paths]
|
| images = torch.stack(images).to(device)
|
| return clip(images).half()
|
|
|
|
|
| device='cuda'
|
| prompt_path = 'Eval/prompts.p'
|
| trained_prompts = load_prompts(prompt_path, device=device)
|
|
|
| clip = ClipWrapper(device)
|
| classifier = SimClassifier(trained_prompts, device)
|
|
|
| image_paths = glob.glob(args.path + '/*.png')
|
| print(len(image_paths))
|
|
|
|
|
|
|
| result = []
|
|
|
| for img in image_paths:
|
|
|
|
|
| x = compute_embeddings([img])
|
| y = classifier(x)
|
| y = torch.argmax(y, dim=0)
|
| result.append(y.detach().cpu().numpy())
|
|
|
| result = np.array(result)
|
|
|
| print(len(np.where(result==1)[0])/len(image_paths))
|
|
|
|
|
|
|