# classification/utils.py

import os
import yaml
import torch
import math
import numpy as np
import clip
from datasets.imagenet import ImageNet
from datasets import build_dataset
from datasets.utils import build_data_loader, AugMixAugmenter
import torchvision.transforms as transforms
from PIL import Image
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # Older torchvision versions: fall back to the PIL resampling constant.
    BICUBIC = Image.BICUBIC

def get_entropy(loss, clip_weights):
    # Normalize by the maximum possible entropy, log(num_classes). Natural log,
    # to match the natural-log entropies returned by the functions below.
    max_entropy = math.log(clip_weights.size(1))
    return float(loss / max_entropy)

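# Sketch: with the natural-log normalizer, a uniform prediction over all
# classes normalizes to ~1.0 and a confident one-hot prediction to ~0.
#   uniform_logits = torch.zeros(1, clip_weights.size(1))
#   get_entropy(softmax_entropy(uniform_logits).item(), clip_weights)  # -> ~1.0
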
def softmax_entropy(x):
    # Per-sample Shannon entropy (in nats) of the softmax distribution over logits.
    return -(x.softmax(1) * x.log_softmax(1)).sum(1)

def avg_entropy(outputs):
    # Entropy (in nats) of the marginal distribution obtained by averaging the
    # softmax probabilities across the batch dimension (e.g. augmented views).
    logits = outputs - outputs.logsumexp(dim=-1, keepdim=True)       # log-probabilities
    avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0])   # log of the mean probability
    min_real = torch.finfo(avg_logits.dtype).min
    avg_logits = torch.clamp(avg_logits, min=min_real)               # guard against -inf
    return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)

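# Minimal sanity check (sketch; shapes are illustrative): avg_entropy should
# match the entropy of the row-averaged softmax distribution.
#   x = torch.randn(8, 10)                 # e.g. 8 augmented views, 10 classes
#   p = x.softmax(dim=-1).mean(dim=0)      # marginal probability distribution
#   assert torch.allclose(avg_entropy(x), -(p * p.log()).sum(), atol=1e-5)
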
def cls_acc(output, target, topk=1):
    # Top-k classification accuracy (as a percentage) for a batch of logits.
    pred = output.topk(topk, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    acc = float(correct[:topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
    acc = 100 * acc / target.shape[0]
    return acc

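# Usage sketch (illustrative tensors, not part of the pipeline above):
#   logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
#   labels = torch.tensor([1, 0])
#   cls_acc(logits, labels)  # -> 100.0
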
def clip_classifier(classnames, template, clip_model):
    with torch.no_grad():
        clip_weights = []
        for classname in classnames:
            # Tokenize the prompts
            classname = classname.replace('_', ' ')
            texts = [t.format(classname) for t in template]
            texts = clip.tokenize(texts).cuda()
            # prompt ensemble: encode every template, normalize, then average
            class_embeddings = clip_model.encode_text(texts)
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            clip_weights.append(class_embedding)
        clip_weights = torch.stack(clip_weights, dim=1).cuda()
    return clip_weights

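# Usage sketch, assuming the OpenAI CLIP package and a CUDA device (the model
# name here is an example, not fixed by this file):
#   clip_model, _ = clip.load("RN50")
#   clip_weights = clip_classifier(dataset.classnames, dataset.template, clip_model)
#   # clip_weights: [feature_dim, num_classes]; image_features @ clip_weights then
#   # gives per-class cosine similarities, since both sides are L2-normalized.
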
def get_clip_logits(images, clip_model, clip_weights):
    with torch.no_grad():
        if isinstance(images, list):
            images = torch.cat(images, dim=0).cuda()
        else:
            images = images.cuda()

        image_features = clip_model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        clip_logits = 100. * image_features @ clip_weights

        if image_features.size(0) > 1:
            # Keep the 10% most confident (lowest-entropy) augmented views,
            # then average their features, logits, and probability maps.
            batch_entropy = softmax_entropy(clip_logits)
            selected_idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * 0.1)]
            output = clip_logits[selected_idx]
            image_features = image_features[selected_idx].mean(0).unsqueeze(0)
            clip_logits = output.mean(0).unsqueeze(0)
            loss = avg_entropy(output)
            prob_map = output.softmax(1).mean(0).unsqueeze(0)
            pred = int(output.mean(0).unsqueeze(0).topk(1, 1, True, True)[1].t())
        else:
            loss = softmax_entropy(clip_logits)
            prob_map = clip_logits.softmax(1)
            pred = int(clip_logits.topk(1, 1, True, True)[1].t()[0])

    return image_features, clip_logits, loss, prob_map, pred

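# Expected call pattern (sketch; loader names follow build_test_data_loader below):
#   for images, target in test_loader:
#       # With the AugMix pipeline, `images` is a list of views; otherwise a batch.
#       features, logits, loss, prob_map, pred = get_clip_logits(images, clip_model, clip_weights)
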
def get_clip_logits_aug(images, clip_model, clip_weights):
    # Variant of get_clip_logits that keeps the per-view features and
    # probability maps of the selected augmentations instead of averaging them.
    with torch.no_grad():
        if isinstance(images, list):
            images = torch.cat(images, dim=0).cuda()
        else:
            images = images.cuda()

        image_features = clip_model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        clip_logits = 100. * image_features @ clip_weights

        if image_features.size(0) > 1:
            # Keep the 10% most confident (lowest-entropy) augmented views.
            batch_entropy = softmax_entropy(clip_logits)
            selected_idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * 0.1)]
            output = clip_logits[selected_idx]
            image_features = image_features[selected_idx]
            clip_logits = output.mean(0).unsqueeze(0)
            loss = avg_entropy(output)
            prob_map = output.softmax(1)
            pred = int(output.mean(0).unsqueeze(0).topk(1, 1, True, True)[1].t())
        else:
            loss = softmax_entropy(clip_logits)
            prob_map = clip_logits.softmax(1)
            pred = int(clip_logits.topk(1, 1, True, True)[1].t()[0])

    return image_features, clip_logits, loss, prob_map, pred

def get_ood_preprocess():
    # CLIP's normalization statistics.
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])
    base_transform = transforms.Compose([
        transforms.Resize(224, interpolation=BICUBIC),
        transforms.CenterCrop(224)])
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        normalize])
    # The original image plus 63 AugMix views -> 64 views per test sample.
    aug_preprocess = AugMixAugmenter(base_transform, preprocess, n_views=63, augmix=True)
    return aug_preprocess

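# Sketch: applied to one PIL image, the augmenter is expected to return a list
# of n_views + 1 = 64 tensors (original plus 63 AugMix views), each [3, 224, 224].
# The file name is purely illustrative.
#   views = get_ood_preprocess()(Image.open("example.jpg").convert("RGB"))
#   len(views)  # -> 64
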
def get_config_file(config_path, dataset_name):
    if dataset_name == "I":
        config_name = "imagenet.yaml"
    elif dataset_name in ["A", "V", "R", "S"]:
        config_name = f"imagenet_{dataset_name.lower()}.yaml"
    else:
        config_name = f"{dataset_name}.yaml"

    config_file = os.path.join(config_path, config_name)
    # Check for the file before trying to open it.
    if not os.path.exists(config_file):
        raise FileNotFoundError(f"The configuration file {config_file} was not found.")

    with open(config_file, 'r', encoding='utf-8-sig') as file:
        cfg = yaml.load(file, Loader=yaml.SafeLoader)

    return cfg

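# Usage sketch (the "configs" directory is an assumed layout):
#   cfg = get_config_file("configs", "I")  # loads configs/imagenet.yaml
#   cfg = get_config_file("configs", "A")  # loads configs/imagenet_a.yaml
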
def build_test_data_loader(dataset_name, root_path, preprocess):
    if dataset_name == 'I':
        dataset = ImageNet(root_path, preprocess)
        test_loader = torch.utils.data.DataLoader(dataset.test, batch_size=1, num_workers=4, shuffle=True)
    elif dataset_name in ['A', 'V', 'R', 'S']:
        # The OOD ImageNet variants use the AugMix pipeline instead of the given preprocess.
        preprocess = get_ood_preprocess()
        dataset = build_dataset(f"imagenet-{dataset_name.lower()}", root_path)
        test_loader = build_data_loader(data_source=dataset.test, batch_size=1, is_train=False, tfm=preprocess, shuffle=True)
    elif dataset_name in ['caltech101', 'dtd', 'eurosat', 'fgvc', 'food101', 'oxford_flowers',
                          'oxford_pets', 'stanford_cars', 'sun397', 'ucf101']:
        dataset = build_dataset(dataset_name, root_path)
        test_loader = build_data_loader(data_source=dataset.test, batch_size=1, is_train=False, tfm=preprocess, shuffle=True)
    else:
        # Raising a bare string is invalid in Python 3; raise a proper exception.
        raise ValueError("Dataset is not from the chosen list")

    return test_loader, dataset.classnames, dataset.template
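
# End-to-end sketch tying the helpers together (model name and data root are
# assumptions for illustration):
#   clip_model, preprocess = clip.load("RN50")
#   test_loader, classnames, template = build_test_data_loader("I", "./data", preprocess)
#   clip_weights = clip_classifier(classnames, template, clip_model)
#   for images, target in test_loader:
#       _, logits, loss, _, pred = get_clip_logits(images, clip_model, clip_weights)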