# NOTE(review): removed non-Python web-page residue that preceded the imports
# ("xmutly's picture" / "Upload 294 files" / commit hash "e1aaaac verified");
# it was file-hosting page chrome, not source code, and broke module import.
import logging
from contextlib import suppress
import torch
import torch.nn.functional as F
from tqdm import tqdm
def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=None):
    """Evaluate a CLIP-like model on an image / multi-caption matching task.

    Each image comes with a list of candidate captions whose first entry
    (index 0) is the true caption; the metric is the fraction of images for
    which the true caption receives the highest image-text similarity.

    Parameters
    ----------
    model: torch.nn.Module
        CLIP-like model exposing ``encode_image`` and ``encode_text``.
    dataloader: torch.utils.data.DataLoader
        Yields ``(batch_images, batch_texts)`` where ``batch_texts[i]`` is the
        list of candidate captions for image ``i``.
    tokenizer:
        Callable converting a list of strings to a ``torch.Tensor`` of integers.
    device: cpu/cuda
        Device the model runs on.
    amp: bool
        Whether to use automatic mixed precision (CUDA autocast).
    recall_k_list:
        Unused by this metric; kept for signature compatibility with sibling
        retrieval evaluators. Defaults to ``[5]``.

    Returns
    -------
    dict of accuracy metric (single ``"acc"`` key).
    """
    if recall_k_list is None:
        recall_k_list = [5]  # avoid a shared mutable default argument
    autocast = torch.cuda.amp.autocast if amp else suppress
    preds = []
    for batch_images, batch_texts in tqdm(dataloader):
        batch_images = batch_images.to(device)
        # Flatten every candidate caption in the batch into one tokenizer call.
        batch_texts_tok = tokenizer([text for texts in batch_texts for text in texts]).to(device)
        nb_texts_for_each_image = [len(texts) for texts in batch_texts]
        # Embed images and texts; move to CPU so the scoring below is device-free.
        with torch.no_grad(), autocast():
            batch_images_emb = F.normalize(model.encode_image(batch_images), dim=-1).cpu()
            batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok), dim=-1).cpu()
        # For each image, predict the caption with the highest cosine similarity.
        start = 0
        for i, nb in enumerate(nb_texts_for_each_image):
            end = start + nb
            image_emb = batch_images_emb[i:i + 1]    # (1, dim)
            texts_emb = batch_texts_emb[start:end]   # (nb, dim)
            scores = (image_emb @ texts_emb.t())[0]  # (nb,)
            preds.append(scores.argmax().item())
            start = end
    pred = torch.tensor(preds, dtype=torch.long)
    # 0 is the index of the caption, the rest (>0) are considered negative captions
    acc = (pred == 0).float().mean().item()
    return {"acc": acc}