|
|
import shutil |
|
|
import os |
|
|
import pickle |
|
|
import torch |
|
|
import numpy as np |
|
|
import torchvision.transforms as transforms |
|
|
from torchvision.datasets import * |
|
|
from typing import Any, Callable, Optional, Tuple |
|
|
from PIL import Image |
|
|
|
|
|
def reset_log_file(file_path):
    """Truncate the log file so a new run starts with an empty log."""
    # Opening in 'w' mode truncates the file; nothing needs to be written.
    with open(file_path, 'w'):
        pass
|
|
|
|
|
def log_record(message, file_path):
    """Echo *message* to stdout and append it (newline-terminated) to the log file."""
    print(message)
    with open(file_path, 'a') as log_file:
        log_file.write(message + '\n')
|
|
|
|
|
def convert_models_to_fp32(model):
    """Cast all of *model*'s parameters (and any existing gradients) to float32 in place.

    Typical use: fp16 checkpoints whose weights must be fp32 for the optimizer step.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Bug fix: `if p.grad:` evaluates the truthiness of a Tensor, which
        # raises "Boolean value of Tensor with more than one element is
        # ambiguous" for any multi-element gradient. Test for presence instead.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
|
|
|
|
|
def str2bool(v):
    """Parse a boolean-ish value (for e.g. argparse `type=`).

    Accepts actual bools, or the strings yes/true/t/y/1 and no/false/f/n/0
    (case-insensitive).

    Raises:
        ValueError: if the string is not a recognized boolean spelling.
        (The original silently returned None here, which callers would
        treat as falsy without any error.)
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ValueError(f"Cannot interpret {v!r} as a boolean")
|
|
|
|
|
def refine_classname(class_names):
    """Normalize class names in place: lowercase, separators ('_', '-', '/') become spaces.

    Mutates and returns the same list.
    """
    separators = str.maketrans({'_': ' ', '-': ' ', '/': ' '})
    for idx in range(len(class_names)):
        class_names[idx] = class_names[idx].lower().translate(separators)
    return class_names
|
|
|
|
|
def save_checkpoint(state, args, is_best=False, filename='checkpoint.pth.tar'):
    """Serialize *state* into args.model_folder; duplicate it as the best model when asked.

    Args:
        state: anything torch.save can serialize (usually a dict of tensors/metadata).
        args: namespace providing `model_folder`.
        is_best: when True, also copy the checkpoint to 'model_best.pth.tar'.
        filename: checkpoint file name inside the model folder.
    """
    checkpoint_path = os.path.join(args.model_folder, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        best_path = os.path.join(args.model_folder, 'model_best.pth.tar')
        shutil.copyfile(checkpoint_path, best_path)
        print ('saved best file')
|
|
|
|
|
def assign_learning_rate(optimizer, new_lr):
    """Overwrite the learning rate of every param group on *optimizer* with *new_lr*."""
    for group in optimizer.param_groups:
        group["lr"] = new_lr
|
|
|
|
|
def _warmup_lr(base_lr, warmup_length, step): |
|
|
return base_lr * (step + 1) / warmup_length |
|
|
|
|
|
|
|
|
def cosine_lr(optimizer, base_lr, warmup_length, steps):
    """Build a step -> lr schedule: linear warmup, then cosine decay to zero.

    The returned callable sets the new rate on *optimizer* (via
    assign_learning_rate) and returns it.
    """
    def _lr_adjuster(step):
        if step < warmup_length:
            new_lr = _warmup_lr(base_lr, warmup_length, step)
        else:
            progressed = step - warmup_length
            decay_span = steps - warmup_length
            new_lr = 0.5 * (1 + np.cos(np.pi * progressed / decay_span)) * base_lr
        assign_learning_rate(optimizer, new_lr)
        return new_lr
    return _lr_adjuster
|
|
|
|
|
|
|
|
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Shape [maxk, batch]: row r holds the (r+1)-th best class per sample.
        pred = output.topk(maxk, dim=1, largest=True, sorted=True).indices.t()
        hits = pred.eq(target.view(1, -1).expand_as(pred))

        results = []
        for k in topk:
            # A sample counts for top-k if the target appears in its first k rows.
            num_correct = hits[:k].reshape(-1).float().sum(0, keepdim=True)
            results.append(num_correct.mul_(100.0 / batch_size))
        return results
|
|
|
|
|
|
|
|
class AverageMeter(object):
    """Tracks the latest value of a metric plus its running (weighted) average."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt  # format spec used by __str__, e.g. ':.3f'
        self.reset()

    def reset(self):
        """Zero out all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(**self.__dict__)
|
|
|
|
|
|
|
|
class ProgressMeter(object):
    """Prints one tab-separated progress line: '[batch/total]' plus each meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for the given batch index."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        print('\t'.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current-batch field to the width of the total batch count.
        width = len(str(num_batches // 1))
        fmt = '{:' + str(width) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
|
|
|
|
|
|
|
def load_imagenet_folder2name(path):
    """Parse a whitespace-delimited mapping file into {folder_id: class_name}.

    Each non-blank line is expected to carry at least three columns, with the
    ImageNet synset folder id (e.g. 'n01440764') in column 0 and the class
    name in column 2.

    Args:
        path: path to the mapping file (e.g. 'imagenet_classes_names.txt').

    Returns:
        dict mapping folder id -> class name.
    """
    folder2name = {}
    with open(path) as f:
        # Iterate the file directly instead of the original readline loop;
        # also avoids shadowing the builtin `id`.
        for line in f:
            columns = line.strip().split()
            if not columns:
                continue  # blank lines made the original crash with IndexError
            folder2name[columns[0]] = columns[2]
    return folder2name
|
|
|
|
|
|
|
|
|
|
|
def one_hot_embedding(labels, num_classes):
    """Embedding labels to one-hot form.

    Args:
        labels: (LongTensor) class labels, sized [N,].
        num_classes: (int) number of classes.

    Returns:
        (tensor) encoded labels, sized [N, #classes].
    """
    # Row i of the identity matrix is the one-hot vector for class i.
    return torch.eye(num_classes)[labels]
|
|
|
|
|
|
|
|
# Shared image transform pipelines for the dataset loaders below.

# No resizing — keeps the dataset's native resolution (used for CIFAR/STL).
preprocess = transforms.Compose([
    transforms.ToTensor()
])
# ImageNet-style eval pipeline: shorter side to 256, then center-crop to 224.
preprocess224 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
# Direct resize to exactly 224x224 (does not preserve aspect ratio).
preprocess224_interpolate = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
|
|
|
|
|
def load_train_dataset(args):
    """Build the training split of the dataset selected by args.dataset.

    Supported: 'cifar100', 'cifar10', 'ImageNet'. ImageNet requires
    args.imgnet_full to point at an ImageFolder-style root containing a
    'train' subdirectory.

    Raises:
        ValueError: if ImageNet is requested without args.imgnet_full set.
        NotImplementedError: for any other dataset name.
    """
    if args.dataset == 'cifar100':
        return CIFAR100(args.root, transform=preprocess, download=True, train=True)
    elif args.dataset == 'cifar10':
        return CIFAR10(args.root, transform=preprocess, download=True, train=True)
    elif args.dataset == 'ImageNet':
        # Bug fix: the original asserted args.imagenet_root but then read
        # args.imgnet_full for the actual path. Guard (and report) the
        # attribute that is really used — matching load_val_datasets. Also
        # raise instead of assert so the check survives `python -O`.
        if args.imgnet_full is None:
            raise ValueError('args.imgnet_full must be set for ImageNet training')
        print(f"Loading ImageNet from {args.imgnet_full}")
        return ImageFolder(os.path.join(args.imgnet_full, 'train'), transform=preprocess224)
    else:
        print(f"Train dataset {args.dataset} not implemented")
        raise NotImplementedError
|
|
|
|
|
def load_val_datasets(args, val_dataset_names):
    """Instantiate the test split of each named validation dataset.

    Args:
        args: namespace providing `root` (and `imgnet_full` for ImageNet).
        val_dataset_names: iterable of dataset name strings.

    Returns:
        list of dataset objects in the same order as val_dataset_names.

    Raises:
        NotImplementedError: on an unrecognized dataset name.
    """
    # Lazy dispatch table: each entry is a zero-arg constructor, so datasets
    # are only built (and possibly downloaded) when actually requested.
    builders = {
        'cifar10': lambda: CIFAR10(args.root, transform=preprocess,
                                   download=True, train=False),
        'cifar100': lambda: CIFAR100(args.root, transform=preprocess,
                                     download=True, train=False),
        'Caltech101': lambda: Caltech101(args.root, target_type='category',
                                         transform=preprocess224, download=True),
        'PCAM': lambda: PCAM(args.root, split='test', transform=preprocess224,
                             download=True),
        'STL10': lambda: STL10(args.root, split='test', transform=preprocess,
                               download=True),
        'SUN397': lambda: SUN397(args.root, transform=preprocess224,
                                 download=True),
        'StanfordCars': lambda: StanfordCars(args.root, split='test',
                                             transform=preprocess224, download=True),
        'Food101': lambda: Food101(args.root, split='test',
                                   transform=preprocess224, download=True),
        'oxfordpet': lambda: OxfordIIITPet(args.root, split='test',
                                           transform=preprocess224, download=True),
        'EuroSAT': lambda: EuroSAT(args.root, transform=preprocess224,
                                   download=True),
        'Caltech256': lambda: Caltech256(args.root, transform=preprocess224,
                                         download=True),
        'flowers102': lambda: Flowers102(args.root, split='test',
                                         transform=preprocess224, download=True),
        'Country211': lambda: Country211(args.root, split='test',
                                         transform=preprocess224, download=True),
        'dtd': lambda: DTD(args.root, split='test',
                           transform=preprocess224, download=True),
        'fgvc_aircraft': lambda: FGVCAircraft(args.root, split='test',
                                              transform=preprocess224, download=True),
        'hateful_memes': lambda: HatefulMemes(args.root, splits=['test_seen', 'test_unseen'],
                                              transform=preprocess224_interpolate),
        'ImageNet': lambda: ImageFolder(os.path.join(args.imgnet_full, 'val'),
                                        transform=preprocess224),
    }

    val_dataset_list = []
    for name in val_dataset_names:
        build = builders.get(name)
        if build is None:
            print(f"Val dataset {name} not implemented")
            raise NotImplementedError
        val_dataset_list.append(build())
    return val_dataset_list
|
|
|
|
|
def get_text_prompts_train(args, train_dataset, template='This is a photo of a {}'):
    """Build one text prompt per training class.

    For ImageNet, synset folder ids (e.g. 'n01440764') are first mapped to
    human-readable names via imagenet_classes_names.txt. Names are then
    normalized with refine_classname before being placed into *template*.

    Returns:
        (texts_train, class_names): the formatted prompts and the refined
        class-name list they were built from.
    """
    class_names = train_dataset.classes
    if args.dataset == 'ImageNet':
        folder2name = load_imagenet_folder2name('imagenet_classes_names.txt')
        class_names = [folder2name[folder] for folder in class_names]

    class_names = refine_classname(class_names)
    texts_train = [template.format(name) for name in class_names]
    return texts_train, class_names
|
|
|
|
|
def get_text_prompts_val(val_dataset_list, val_dataset_name, template='This is a photo of a {}'):
    """Build a prompt list for every validation dataset.

    Datasets exposing a `clip_prompts` attribute supply their prompts
    verbatim; otherwise prompts are generated from the dataset's class names
    (ImageNet folder ids are first mapped to readable names).

    Returns:
        list of prompt lists, aligned with val_dataset_list.
    """
    texts_list = []
    for idx, dataset in enumerate(val_dataset_list):
        if hasattr(dataset, 'clip_prompts'):
            # Dataset ships curated prompts — use them as-is.
            texts_list.append(dataset.clip_prompts)
            continue

        class_names = dataset.classes
        if val_dataset_name[idx] == 'ImageNet':
            from utils import load_imagenet_folder2name
            folder2name = load_imagenet_folder2name('imagenet_classes_names.txt')
            class_names = [folder2name[class_name] for class_name in class_names]

        class_names = refine_classname(class_names)
        texts_list.append([template.format(label) for label in class_names])

    assert len(texts_list) == len(val_dataset_list)
    return texts_list
|
|
|