Spaces:
Sleeping
Sleeping
| import os | |
| import yaml | |
| import torch | |
| import math | |
| import numpy as np | |
| import clip | |
| from datasets.imagenet import ImageNet | |
| from datasets import build_dataset | |
| from datasets.utils import build_data_loader, AugMixAugmenter | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| try: | |
| from torchvision.transforms import InterpolationMode | |
| BICUBIC = InterpolationMode.BICUBIC | |
| except ImportError: | |
| BICUBIC = Image.BICUBIC | |
| def get_entropy(loss, clip_weights): | |
| max_entropy = math.log2(clip_weights.size(1)) | |
| return float(loss / max_entropy) | |
| def softmax_entropy(x): | |
| return -(x.softmax(1) * x.log_softmax(1)).sum(1) | |
| def avg_entropy(outputs): | |
| logits = outputs - outputs.logsumexp(dim=-1, keepdim=True) | |
| avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0]) | |
| min_real = torch.finfo(avg_logits.dtype).min | |
| avg_logits = torch.clamp(avg_logits, min=min_real) | |
| return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1) | |
| def cls_acc(output, target, topk=1): | |
| pred = output.topk(topk, 1, True, True)[1].t() | |
| correct = pred.eq(target.view(1, -1).expand_as(pred)) | |
| acc = float(correct[: topk].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) | |
| acc = 100 * acc / target.shape[0] | |
| return acc | |
| def clip_classifier(classnames, template, clip_model): | |
| with torch.no_grad(): | |
| clip_weights = [] | |
| for classname in classnames: | |
| # Tokenize the prompts | |
| classname = classname.replace('_', ' ') | |
| texts = [t.format(classname) for t in template] | |
| texts = clip.tokenize(texts).cuda() | |
| # prompt ensemble for ImageNet | |
| class_embeddings = clip_model.encode_text(texts) | |
| class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) | |
| class_embedding = class_embeddings.mean(dim=0) | |
| class_embedding /= class_embedding.norm() | |
| clip_weights.append(class_embedding) | |
| clip_weights = torch.stack(clip_weights, dim=1).cuda() | |
| return clip_weights | |
| def get_clip_logits(images, clip_model, clip_weights): | |
| with torch.no_grad(): | |
| if isinstance(images, list): | |
| images = torch.cat(images, dim=0).cuda() | |
| else: | |
| images = images.cuda() | |
| image_features = clip_model.encode_image(images) | |
| image_features /= image_features.norm(dim=-1, keepdim=True) | |
| clip_logits = 100. * image_features @ clip_weights | |
| if image_features.size(0) > 1: | |
| batch_entropy = softmax_entropy(clip_logits) | |
| selected_idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * 0.1)] | |
| output = clip_logits[selected_idx] | |
| image_features = image_features[selected_idx].mean(0).unsqueeze(0) | |
| clip_logits = output.mean(0).unsqueeze(0) | |
| loss = avg_entropy(output) | |
| prob_map = output.softmax(1).mean(0).unsqueeze(0) | |
| pred = int(output.mean(0).unsqueeze(0).topk(1, 1, True, True)[1].t()) | |
| else: | |
| loss = softmax_entropy(clip_logits) | |
| prob_map = clip_logits.softmax(1) | |
| pred = int(clip_logits.topk(1, 1, True, True)[1].t()[0]) | |
| return image_features, clip_logits, loss, prob_map, pred | |
| def get_clip_logits_aug(images, clip_model, clip_weights): | |
| with torch.no_grad(): | |
| if isinstance(images, list): | |
| images = torch.cat(images, dim=0).cuda() | |
| else: | |
| images = images.cuda() | |
| image_features = clip_model.encode_image(images) | |
| image_features /= image_features.norm(dim=-1, keepdim=True) | |
| clip_logits = 100. * image_features @ clip_weights | |
| if image_features.size(0) > 1: | |
| batch_entropy = softmax_entropy(clip_logits) | |
| selected_idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * 0.1)] | |
| output = clip_logits[selected_idx] | |
| image_features = image_features[selected_idx] | |
| clip_logits = output.mean(0).unsqueeze(0) | |
| loss = avg_entropy(output) | |
| prob_map = output.softmax(1) | |
| pred = int(output.mean(0).unsqueeze(0).topk(1, 1, True, True)[1].t()) | |
| else: | |
| loss = softmax_entropy(clip_logits) | |
| prob_map = clip_logits.softmax(1) | |
| pred = int(clip_logits.topk(1, 1, True, True)[1].t()[0]) | |
| return image_features, clip_logits, loss, prob_map, pred | |
| def get_ood_preprocess(): | |
| normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], | |
| std=[0.26862954, 0.26130258, 0.27577711]) | |
| base_transform = transforms.Compose([ | |
| transforms.Resize(224, interpolation=BICUBIC), | |
| transforms.CenterCrop(224)]) | |
| preprocess = transforms.Compose([ | |
| transforms.ToTensor(), | |
| normalize]) | |
| aug_preprocess = AugMixAugmenter(base_transform, preprocess, n_views=63, augmix=True) | |
| return aug_preprocess | |
| def get_config_file(config_path, dataset_name): | |
| if dataset_name == "I": | |
| config_name = "imagenet.yaml" | |
| elif dataset_name in ["A", "V", "R", "S"]: | |
| config_name = f"imagenet_{dataset_name.lower()}.yaml" | |
| else: | |
| config_name = f"{dataset_name}.yaml" | |
| config_file = os.path.join(config_path, config_name) | |
| with open(config_file, 'r', encoding='utf-8-sig') as file: | |
| cfg = yaml.load(file, Loader=yaml.SafeLoader) | |
| if not os.path.exists(config_file): | |
| raise FileNotFoundError(f"The configuration file {config_file} was not found.") | |
| return cfg | |
| def build_test_data_loader(dataset_name, root_path, preprocess): | |
| if dataset_name == 'I': | |
| dataset = ImageNet(root_path, preprocess) | |
| test_loader = torch.utils.data.DataLoader(dataset.test, batch_size=1, num_workers=4, shuffle=True) | |
| elif dataset_name in ['A','V','R','S']: | |
| preprocess = get_ood_preprocess() | |
| dataset = build_dataset(f"imagenet-{dataset_name.lower()}", root_path) | |
| test_loader = build_data_loader(data_source=dataset.test, batch_size=1, is_train=False, tfm=preprocess, shuffle=True) | |
| elif dataset_name in ['caltech101','dtd','eurosat','fgvc','food101','oxford_flowers','oxford_pets','stanford_cars','sun397','ucf101']: | |
| dataset = build_dataset(dataset_name, root_path) | |
| test_loader = build_data_loader(data_source=dataset.test, batch_size=1, is_train=False, tfm=preprocess, shuffle=True) | |
| else: | |
| raise "Dataset is not from the chosen list" | |
| return test_loader, dataset.classnames, dataset.template |