"""
Code adapted from https://github.com/mlfoundations/open_clip/blob/main/src/training/zero_shot.py
Thanks to the authors of OpenCLIP
"""
import logging
from contextlib import suppress
import torch
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import classification_report, balanced_accuracy_score
from autoattack import AutoAttack
def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=True):
"""
    Build a zero-shot classifier: one text embedding per class, obtained by
    encoding and averaging that class's prompts.

    Parameters
    ----------
model:
CLIP-like model with `encode_text`
    tokenizer:
        text tokenizer, i.e. converts a list of strings to a torch.Tensor of token ids
    classnames: list of str
        names of the classes
    templates: list of str or dict
        prompt templates; either a list of generic templates containing `{c}`, or a
        dict mapping each class name to its own list of prompts
    device:
        torch device on which to run the text encoder
    amp: bool
        whether to encode the prompts under automatic mixed precision
Returns
-------
    torch.Tensor of shape (D, C) where D is the embedding dimension and C is the
    number of classes.
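
    Example
    -------
    Illustrative usage; `model` and `tokenizer` are placeholders for a CLIP-like
    model and its matching tokenizer supplied by the caller::

        >>> templates = ["a photo of a {c}", "a drawing of a {c}"]
        >>> clf = zero_shot_classifier(model, tokenizer, ["cat", "dog"], templates, "cuda")
        >>> clf.shape  # (embedding_dim, 2)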
"""
autocast = torch.cuda.amp.autocast if amp else suppress
with torch.no_grad(), autocast():
zeroshot_weights = []
for classname in tqdm(classnames):
            if isinstance(templates, dict):
                # class-specific prompts (e.g., CuPL https://arxiv.org/abs/2209.03320)
                texts = templates[classname]
            elif isinstance(templates, list):
                # generic prompts that are specialized for each class by replacing {c} with the class name
                texts = [template.format(c=classname) for template in templates]
else:
raise ValueError("templates must be a list or a dict")
texts = tokenizer(texts).to(device) # tokenize
class_embeddings = model.encode_text(texts)
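            # average the per-template embeddings (each normalized to unit length),
            # then re-normalize the mean so every class vector has unit norm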
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
class_embedding /= class_embedding.norm()
zeroshot_weights.append(class_embedding)
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
return zeroshot_weights
def accuracy(output, target, topk=(1,)):
"""
Compute top-k accuracy
output: torch.Tensor
shape (N, C) where N is the number of examples, C the number of classes.
these are the logits.
target: torch.Tensor
shape (N,) where N is the number of examples. Groundtruth class id of each example.
topk: tuple
which topk to compute, e.g., topk=(1,5) will compute top-1 and top-5 accuracies
Returns
-------
list of top-k accuracies in the same order as `topk`
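
    Example
    -------
    Small self-contained example (the first prediction is correct, the second is not):

        >>> logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
        >>> target = torch.tensor([1, 1])
        >>> accuracy(logits, target, topk=(1,))
        [0.5]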
"""
pred = output.topk(max(topk), 1, True, True)[1].t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
n = len(target)
    return [float(correct[:k].reshape(-1).float().sum().item()) / n for k in topk]
def run_classification(model, classifier, dataloader, device, normalize=None, resize=None, amp=True,
attack_config=None):
"""
    Run zero-shot classification over a dataloader, optionally under adversarial attack.

    Parameters
    ----------
    model: torch.nn.Module
        CLIP-like model with `encode_image` and `encode_text`
    classifier: torch.Tensor
        zero-shot weights obtained from the function `zero_shot_classifier`
    dataloader: torch.utils.data.DataLoader
    device: cpu/cuda
    normalize:
        input normalization transform, applied after the (optional) attack
    resize:
        optional resize transform applied before normalization
    amp: bool
        whether to use automatic mixed precision for the clean forward passes
    attack_config: dict
        attack settings; see the illustrative dict inside the function body
Returns
-------
(pred, true) where
- pred (N, C) are the logits
- true (N,) are the actual classes
"""
assert normalize is not None
autocast = torch.cuda.amp.autocast if amp else suppress
pred = []
true = []
    max_samples = attack_config['n_samples'] if attack_config is not None else -1
def _forward_unnorm(data_unnorm):
if resize is not None:
data_unnorm = resize(data_unnorm)
data_norm = normalize(data_unnorm)
features = model.encode_image(data_norm)
features = F.normalize(features, dim=-1)
logits = 100. * features @ classifier
return logits
    attack_str = attack_config['attack'] if attack_config is not None else 'none'
    adv = attack_str != 'none'
if adv:
bs = attack_config['bs']
norm = attack_config['norm']
eps = attack_config['eps'] / 255.
# iterations = attack_config['iterations']
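        # Illustrative shape of `attack_config` (only keys that are read in this
        # function are shown; the concrete values are assumptions, not defaults):
        # attack_config = {
        #     'attack': 'aa',     # 'aa' for AutoAttack, 'none' to skip the attack
        #     'norm': 'Linf',     # perturbation norm passed to AutoAttack
        #     'eps': 4,           # budget on the 0-255 scale, divided by 255 above
        #     'bs': 64,           # batch size for the attack and the logit pass
        #     'n_samples': 1000,  # cap on evaluated samples; <= 0 means use all
        # }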
if attack_str.lower() == 'aa':
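            # targeted APGD needs several alternative target classes per sample, so it
            # is only added when the dataset has more than two classes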
attacks_to_run = ('apgd-ce', 'apgd-t') if len(dataloader.dataset.classes) > 2 else ('apgd-ce',)
attack = AutoAttack(
_forward_unnorm, norm=norm, eps=eps,
attacks_to_run=attacks_to_run,
version='custom',
verbose=True,
device=device
)
all_images, all_targets = [], []
for i, batch in enumerate(dataloader):
all_images.append(batch[0])
all_targets.append(batch[1])
if (max_samples > 0) and (i >= max_samples // bs + 2):
break
all_images = torch.cat(all_images, dim=0)
all_targets = torch.cat(all_targets, dim=0)
if max_samples > 0:
all_images = all_images[:max_samples]
all_targets = all_targets[:max_samples]
assert 0. <= all_images.min() and all_images.max() <= 1., f'{all_images.min()} {all_images.max()}'
print(f'[n samples] {len(all_images)}')
print(f'starting autoattack..')
images = attack.run_standard_evaluation(all_images, all_targets, bs=bs)
print('getting logits..')
with torch.no_grad():
for i in range(0, len(images), bs):
batch = images[i:i + bs]
logits = _forward_unnorm(batch.to(device))
pred.append(logits.float().cpu())
true.append(all_targets[i:i + bs].cpu())
return torch.cat(pred), torch.cat(true)
with torch.no_grad():
n = 0
for images, target in tqdm(dataloader):
if (max_samples > 0) and (n >= max_samples):
break
images = images.to(device)
target = target.to(device)
n += images.shape[0]
with autocast():
logits = _forward_unnorm(images)
true.append(target.cpu())
pred.append(logits.float().cpu())
pred = torch.cat(pred)
true = torch.cat(true)
if max_samples > 0:
pred = pred[:max_samples]
true = true[:max_samples]
print(f'[n samples] {len(pred)}')
return pred, true
def average_precision_per_class(scores, targets):
"""
    Compute the average precision for each class.
    This metric is used for multi-label classification; see
    https://fangdahan.medium.com/calculate-mean-average-precision-map-for-multi-label-classification-b082679d31be
    for an explanation.
    Code is adapted from https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py, thanks to the authors of `tnt`.
Parameters
----------
scores: torch.Tensor
logits, of shape (N,C) where N is the number of examples, C the number of classes
targets: torch.Tensor
one-hot vectors of groundtruth targets (N, C), where N is the number of examples, C is the
number of classes
Returns
-------
    torch.Tensor of shape (C,) of average precision for each class, where C is
    the number of classes.
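
    Example
    -------
    Small self-contained example with 3 examples and 2 classes (values shown are
    rounded by torch's default printing):

        >>> scores = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])
        >>> targets = torch.tensor([[0, 1], [1, 0], [1, 0]])
        >>> average_precision_per_class(scores, targets)
        tensor([0.5833, 0.3333])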
"""
ap = torch.zeros(scores.size(1))
rg = torch.arange(1, scores.size(0) + 1).float()
# compute average precision for each class
for k in range(scores.size(1)):
# sort scores
scores_k = scores[:, k]
targets_k = targets[:, k]
_, sortind = torch.sort(scores_k, 0, True)
truth = targets_k[sortind]
tp = truth.float().cumsum(0)
# compute precision curve
precision = tp.div(rg)
# compute average precision
ap[k] = precision[truth.bool()].sum() / max(float(truth.sum()), 1)
return ap
def evaluate(model, dataloader, tokenizer, classnames, templates, device, normalize=None, resize=None,
amp=True, verbose=False, save_clf=None, load_clfs=[], attack_config=None):
"""
Run zero-shot classification and evaluate the metrics
Parameters
----------
model: torch.nn.Module
CLIP-like model with `encode_image` and `encode_text`
dataloader: torch.utils.data.Dataloader
tokenizer: text tokenizer
classnames: list of str
class names
templates: list of str
templates to use for zero-shot classification
    device: cpu/cuda
    normalize: normalization transform
    resize: optional resize transform applied before normalization
    amp: whether to use automatic mixed precision
    verbose: whether to print per-class results
    save_clf: optional path where the zero-shot classifier weights are saved
    load_clfs: optional list of paths; the loaded classifiers are averaged and used
        instead of recomputing the zero-shot weights
    attack_config: dict of attack settings, forwarded to `run_classification`
Returns
-------
dict of classification metrics
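
    Example
    -------
    Illustrative call; `model`, `val_loader`, `tokenize` and `preprocess_norm` are
    placeholders for objects supplied by the surrounding codebase::

        >>> metrics = evaluate(
        ...     model, val_loader, tokenize,
        ...     classnames=["cat", "dog"], templates=["a photo of a {c}"],
        ...     device="cuda", normalize=preprocess_norm,
        ...     attack_config={'attack': 'none', 'n_samples': -1},
        ... )
        >>> metrics["acc1"]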
"""
assert normalize is not None
if len(load_clfs) > 0:
n = len(load_clfs)
classifier = torch.load(load_clfs[0], map_location='cpu') / n
for i in range(1, n):
classifier = classifier + torch.load(load_clfs[i], map_location='cpu') / n
classifier = classifier.to(device)
else:
classifier = zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=amp)
if save_clf is not None:
torch.save(classifier, save_clf)
# exit() - not sure if we want to exit here or not.
logits, target = run_classification(model, classifier, dataloader, device,
normalize=normalize, resize=resize, amp=amp,
attack_config=attack_config)
is_multilabel = (len(target.shape) == 2)
if is_multilabel:
if verbose:
print("Detected a multi-label classification dataset")
# Multiple labels per image, multiple classes on the dataset
ap_per_class = average_precision_per_class(logits, target)
if verbose:
for class_name, ap in zip(dataloader.dataset.classes, ap_per_class.tolist()):
print(f"Class: {class_name}, AveragePrecision: {ap}")
return {"mean_average_precision": ap_per_class.mean().item()}
else:
# Single label per image, multiple classes on the dataset
# just compute accuracy and mean_per_class_recall
pred = logits.argmax(axis=1)
# measure accuracy
if len(dataloader.dataset.classes) >= 5:
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
else:
acc1, = accuracy(logits, target, topk=(1,))
acc5 = float("nan")
mean_per_class_recall = balanced_accuracy_score(target, pred)
if verbose:
pass
# print(classification_report(target, pred, digits=3))
print(f"[acc1] {acc1*100:.2f}")
return {"acc1": acc1, "acc5": acc5, "mean_per_class_recall": mean_per_class_recall}