# Robust_vlm / utils.py
# Uploaded by Yaning1001 via the upload-large-folder tool (commit 64470b3, verified).
import shutil
import os
import pickle
import torch
import numpy as np
import torchvision.transforms as transforms
from torchvision.datasets import *
from typing import Any, Callable, Optional, Tuple
from PIL import Image
def reset_log_file(file_path):
    """Truncate the log file so a new run starts from an empty log."""
    # Opening in 'w' mode truncates the file; nothing needs to be written.
    with open(file_path, 'w'):
        pass
def log_record(message, file_path):
    """Echo *message* to stdout and append it, newline-terminated, to *file_path*."""
    print(message)
    with open(file_path, 'a') as log_file:
        log_file.write('%s\n' % message)
def convert_models_to_fp32(model):
    """Cast every model parameter (and its gradient, if present) to float32 in place.

    Args:
        model: a torch.nn.Module whose parameters may be in fp16/fp64.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # Bug fix: `if p.grad:` evaluates tensor truthiness, which raises
        # RuntimeError for multi-element gradients and skips a scalar zero
        # gradient. The correct existence check is `is not None`.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()
def str2bool(v):
    """Parse a boolean value from a string (argparse-friendly).

    Booleans pass through unchanged. Case-insensitively,
    'yes'/'true'/'t'/'y'/'1' -> True and 'no'/'false'/'f'/'n'/'0' -> False.

    Raises:
        ValueError: for any unrecognized input. Previously the function fell
            off the end and returned None (falsy), silently masking typos;
            argparse treats ValueError from a `type` callable as a clean
            "invalid value" error.
    """
    if isinstance(v, bool):
        return v
    s = v.lower()
    if s in ('yes', 'true', 't', 'y', '1'):
        return True
    if s in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ValueError(f'Boolean value expected, got {v!r}')
def refine_classname(class_names):
    """Normalize class names in place: lowercase, with '_', '-', '/' turned into spaces.

    Returns the same (mutated) list for call-chaining convenience.
    """
    for idx in range(len(class_names)):
        name = class_names[idx].lower()
        for separator in ('_', '-', '/'):
            name = name.replace(separator, ' ')
        class_names[idx] = name
    return class_names
def save_checkpoint(state, args, is_best=False, filename='checkpoint.pth.tar'):
    """Serialize *state* under args.model_folder; mirror to model_best.pth.tar when best.

    Args:
        state: any picklable object (typically a dict of model/optimizer state).
        args: namespace providing `model_folder`, the output directory.
        is_best: when True, also copy the checkpoint to 'model_best.pth.tar'.
        filename: checkpoint file name inside args.model_folder.
    """
    checkpoint_path = os.path.join(args.model_folder, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        best_path = os.path.join(args.model_folder, 'model_best.pth.tar')
        shutil.copyfile(checkpoint_path, best_path)
        print('saved best file')
def assign_learning_rate(optimizer, new_lr):
    """Overwrite the learning rate of every param group in *optimizer*."""
    for group in optimizer.param_groups:
        group["lr"] = new_lr
def _warmup_lr(base_lr, warmup_length, step):
return base_lr * (step + 1) / warmup_length
def cosine_lr(optimizer, base_lr, warmup_length, steps):
    """Return a step -> lr schedule: linear warmup then cosine decay to zero.

    The returned callable writes the new learning rate into every param
    group of *optimizer* and also returns it.
    """
    def _lr_adjuster(step):
        if step < warmup_length:
            # Linear warmup toward base_lr (inlined _warmup_lr).
            lr = base_lr * (step + 1) / warmup_length
        else:
            # Cosine decay over the remaining steps (inlined helper-free form).
            progress = (step - warmup_length) / (steps - warmup_length)
            lr = 0.5 * (1 + np.cos(np.pi * progress)) * base_lr
        for group in optimizer.param_groups:
            group["lr"] = lr
        return lr
    return _lr_adjuster
def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) of *output* scores against *target* labels.

    Args:
        output: (N, C) score/logit tensor.
        target: (N,) ground-truth label tensor.
        topk: iterable of k values to report.

    Returns:
        list of 1-element tensors, one accuracy value per requested k.
    """
    with torch.no_grad():
        k_max = max(topk)
        n_samples = target.size(0)
        # Indices of the k_max highest scores per row, transposed to (k_max, N).
        _, top_indices = output.topk(k_max, 1, True, True)
        top_indices = top_indices.t()
        hits = top_indices.eq(target.view(1, -1).expand_as(top_indices))
        results = []
        for k in topk:
            n_correct = hits[:k].reshape(-1).float().sum(0, keepdim=True)
            results.append(n_correct.mul_(100.0 / n_samples))
        return results
class AverageMeter(object):
    """Tracks the latest value plus running sum, count, and average of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero out all statistics."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

    def __str__(self):
        # Render as e.g. 'loss 0.25 (0.30)' using the configured format spec.
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
class ProgressMeter(object):
    """Formats and prints a '[batch/total]' header followed by meter summaries."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print one tab-separated progress line for *batch*."""
        parts = [self.prefix + self.batch_fmtstr.format(batch)]
        parts.extend(str(meter) for meter in self.meters)
        print('\t'.join(parts))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch index to the width of the total count,
        # e.g. 100 batches -> '[{:3d}/100]'.
        width = len(str(num_batches))
        field = '{:%dd}' % width
        return '[' + field + '/' + field.format(num_batches) + ']'
def load_imagenet_folder2name(path):
    """Parse a whitespace-separated mapping file into {folder_id: class_name}.

    Each line is expected to look like '<folder_id> <index> <class_name>';
    only the first and third fields are used.
    """
    folder_to_name = {}
    with open(path) as mapping_file:
        for raw_line in mapping_file:
            fields = raw_line.strip().split()
            folder_to_name[fields[0]] = fields[2]
    return folder_to_name
def one_hot_embedding(labels, num_classes):
    """Encode integer class labels as one-hot vectors.

    Args:
        labels: (LongTensor) class labels, sized [N,].
        num_classes: (int) number of classes.

    Returns:
        (tensor) float one-hot encoding, sized [N, num_classes].
    """
    # Row i of the identity matrix is the one-hot vector for class i.
    return torch.eye(num_classes)[labels]
# Minimal transform: PIL -> tensor at native resolution (used for CIFAR/STL10).
preprocess = transforms.Compose([
    transforms.ToTensor()
])
# ImageNet-style eval transform: resize the short side to 256, center-crop to 224.
preprocess224 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
# Direct 224x224 resize (does not preserve aspect ratio) — used e.g. for hateful_memes.
preprocess224_interpolate = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
def load_train_dataset(args):
    """Build the training dataset named by args.dataset.

    Returns:
        CIFAR10/CIFAR100 (downloaded under args.root) or an ImageNet
        ImageFolder rooted at args.imgnet_full/train.

    Raises:
        NotImplementedError: for any other args.dataset value.
    """
    if args.dataset == 'cifar100':
        return CIFAR100(args.root, transform=preprocess, download=True, train=True)
    elif args.dataset == 'cifar10':
        return CIFAR10(args.root, transform=preprocess, download=True, train=True)
    elif args.dataset == 'ImageNet':
        # Bug fix: the guard and log message previously checked
        # args.imagenet_root while the path actually read args.imgnet_full
        # (the attribute load_val_datasets also uses) — the assert could pass
        # while the real path was None, or vice versa.
        assert args.imgnet_full is not None
        print(f"Loading ImageNet from {args.imgnet_full}")
        return ImageFolder(os.path.join(args.imgnet_full, 'train'), transform=preprocess224)
    else:
        print(f"Train dataset {args.dataset} not implemented")
        raise NotImplementedError
def load_val_datasets(args, val_dataset_names):
    """Instantiate one validation dataset per name in *val_dataset_names*.

    Datasets are downloaded (where supported) under args.root; ImageNet is
    read from args.imgnet_full/val.

    Raises:
        NotImplementedError: when a requested name has no loader.
    """
    # Lazy constructors: only the datasets actually requested get built
    # (and potentially downloaded), matching the original elif chain.
    loaders = {
        'cifar10': lambda: CIFAR10(args.root, transform=preprocess, download=True, train=False),
        'cifar100': lambda: CIFAR100(args.root, transform=preprocess, download=True, train=False),
        'Caltech101': lambda: Caltech101(args.root, target_type='category', transform=preprocess224, download=True),
        'PCAM': lambda: PCAM(args.root, split='test', transform=preprocess224, download=True),
        'STL10': lambda: STL10(args.root, split='test', transform=preprocess, download=True),
        'SUN397': lambda: SUN397(args.root, transform=preprocess224, download=True),
        'StanfordCars': lambda: StanfordCars(args.root, split='test', transform=preprocess224, download=True),
        'Food101': lambda: Food101(args.root, split='test', transform=preprocess224, download=True),
        'oxfordpet': lambda: OxfordIIITPet(args.root, split='test', transform=preprocess224, download=True),
        'EuroSAT': lambda: EuroSAT(args.root, transform=preprocess224, download=True),
        'Caltech256': lambda: Caltech256(args.root, transform=preprocess224, download=True),
        'flowers102': lambda: Flowers102(args.root, split='test', transform=preprocess224, download=True),
        'Country211': lambda: Country211(args.root, split='test', transform=preprocess224, download=True),
        'dtd': lambda: DTD(args.root, split='test', transform=preprocess224, download=True),
        'fgvc_aircraft': lambda: FGVCAircraft(args.root, split='test', transform=preprocess224, download=True),
        # NOTE(review): HatefulMemes is not a torchvision dataset — presumably
        # a project-local class pulled in elsewhere; confirm it is importable.
        'hateful_memes': lambda: HatefulMemes(args.root, splits=['test_seen', 'test_unseen'], transform=preprocess224_interpolate),
        'ImageNet': lambda: ImageFolder(os.path.join(args.imgnet_full, 'val'), transform=preprocess224),
    }
    val_dataset_list = []
    for name in val_dataset_names:
        if name not in loaders:
            print(f"Val dataset {name} not implemented")
            raise NotImplementedError
        val_dataset_list.append(loaders[name]())
    return val_dataset_list
def get_text_prompts_train(args, train_dataset, template='This is a photo of a {}'):
    """Build one text prompt per training class.

    Args:
        args: namespace providing `dataset` (the training dataset name).
        train_dataset: dataset object exposing a `classes` list.
        template: prompt format string with a single `{}` placeholder.

    Returns:
        (texts_train, training_original_classnames): the prompt list plus the
        refined class-name list, kept for Text Prompt Tuning.
    """
    class_names = train_dataset.classes
    if args.dataset == 'ImageNet':
        # ImageFolder classes are WordNet folder ids; map them to readable names.
        folder2name = load_imagenet_folder2name('imagenet_classes_names.txt')
        class_names = [folder2name[folder_id] for folder_id in class_names]
    class_names = refine_classname(class_names)
    texts_train = [template.format(name) for name in class_names]
    # Save the refined classnames for Text Prompt Tuning.
    training_original_classnames = class_names
    return texts_train, training_original_classnames
def get_text_prompts_val(val_dataset_list, val_dataset_name, template='This is a photo of a {}'):
    """Build the list of text prompts for each validation dataset.

    Datasets exposing a `clip_prompts` attribute use it verbatim; otherwise
    prompts are generated from the dataset's `classes` via *template*
    (with ImageNet folder ids first mapped to readable names).

    Args:
        val_dataset_list: datasets, in the same order as *val_dataset_name*.
        val_dataset_name: dataset name per entry (drives ImageNet handling).
        template: prompt format string with a single `{}` placeholder.

    Returns:
        list of prompt lists, one per dataset in *val_dataset_list*.
    """
    texts_list = []
    for cnt, each in enumerate(val_dataset_list):
        if hasattr(each, 'clip_prompts'):
            texts_tmp = each.clip_prompts
        else:
            class_names = each.classes
            if val_dataset_name[cnt] == 'ImageNet':
                # Fix: call the module-level helper directly; the previous
                # `from utils import load_imagenet_folder2name` re-imported
                # this very module, which is redundant and breaks if the file
                # is ever renamed or imported under a package path.
                folder2name = load_imagenet_folder2name('imagenet_classes_names.txt')
                class_names = [folder2name[c] for c in class_names]
            class_names = refine_classname(class_names)
            texts_tmp = [template.format(label) for label in class_names]
        texts_list.append(texts_tmp)
    assert len(texts_list) == len(val_dataset_list)
    return texts_list