"""Training/evaluation utilities: logging, LR scheduling, metrics, dataset
loading and text-prompt construction."""

import os
import pickle
import shutil
from typing import Any, Callable, Optional, Tuple

import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision.datasets import *
from PIL import Image


def reset_log_file(file_path):
    """Truncate the log file at ``file_path`` (creating it if necessary)."""
    # Opening in write mode already truncates; the empty write is kept only
    # to make the intent explicit.
    with open(file_path, 'w') as f:
        f.write('')


def log_record(message, file_path):
    """Print ``message`` and append it, followed by a newline, to ``file_path``."""
    print(message)
    with open(file_path, 'a') as f:
        f.write(message + '\n')


def convert_models_to_fp32(model):
    """Cast every parameter of ``model`` (and its gradient, if any) to float32 in place."""
    for p in model.parameters():
        p.data = p.data.float()
        # Compare against None explicitly: the truth value of a tensor with
        # more than one element is ambiguous and raises RuntimeError.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def str2bool(v):
    """Parse a command-line style boolean.

    Args:
        v: a ``bool`` (returned unchanged) or a string such as ``'yes'``,
            ``'true'``, ``'t'``, ``'y'``, ``'1'`` / ``'no'``, ``'false'``,
            ``'f'``, ``'n'``, ``'0'`` (case-insensitive).

    Returns:
        The corresponding ``bool``.

    Raises:
        ValueError: if the string is not one of the recognised spellings.
            ``argparse`` turns a ``ValueError`` raised by a ``type=`` callable
            into a regular usage error.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ValueError(f"Boolean value expected, got {v!r}")


def refine_classname(class_names):
    """Lower-case each class name and replace ``_``, ``-`` and ``/`` with spaces.

    NOTE: the list is modified in place (and also returned), so the caller's
    list -- e.g. a dataset's ``classes`` attribute -- is updated as well.
    """
    for i, class_name in enumerate(class_names):
        class_names[i] = (
            class_name.lower().replace('_', ' ').replace('-', ' ').replace('/', ' ')
        )
    return class_names


def save_checkpoint(state, args, is_best=False, filename='checkpoint.pth.tar'):
    """Save ``state`` under ``args.model_folder``.

    When ``is_best`` is true the saved file is additionally copied to
    ``model_best.pth.tar`` in the same folder.
    """
    savefile = os.path.join(args.model_folder, filename)
    bestfile = os.path.join(args.model_folder, 'model_best.pth.tar')
    torch.save(state, savefile)
    if is_best:
        shutil.copyfile(savefile, bestfile)
        print('saved best file')


def assign_learning_rate(optimizer, new_lr):
    """Set the learning rate of every parameter group of ``optimizer`` to ``new_lr``."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    """Linear warm-up: reaches ``base_lr`` at ``step == warmup_length - 1``."""
    return base_lr * (step + 1) / warmup_length


def cosine_lr(optimizer, base_lr, warmup_length, steps):
    """Build a per-step LR scheduler: linear warm-up, then half-cosine decay.

    Args:
        optimizer: optimizer whose parameter groups are updated.
        base_lr: peak learning rate reached at the end of warm-up.
        warmup_length: number of warm-up steps.
        steps: total number of steps (warm-up included).

    Returns:
        A function ``f(step)`` that sets the optimizer's learning rate for
        ``step`` and returns the value it set.
    """
    def _lr_adjuster(step):
        if step < warmup_length:
            lr = _warmup_lr(base_lr, warmup_length, step)
        else:
            e = step - warmup_length
            es = steps - warmup_length
            lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
        assign_learning_rate(optimizer, lr)
        return lr

    return _lr_adjuster


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor; the top-k is taken along dimension 1.
        target: label tensor whose first dimension is the batch size.
        topk: iterable of k values.

    Returns:
        A list with one single-element tensor per entry of ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the current value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Prints a tab-separated progress line made of several meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, ``[batch/num_batches]`` and every meter."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the batch index to the width of the total so columns line up.
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def load_imagenet_folder2name(path):
    """Read a whitespace-separated mapping file into a dict.

    For every non-blank line, the first field becomes the key and the third
    field the value (presumably ``<folder id> <index> <class name>`` --
    confirm against the actual file). Blank lines are skipped.

    Args:
        path: path of the mapping file.

    Returns:
        dict mapping first field -> third field.

    Raises:
        IndexError: if a non-blank line has fewer than three fields.
    """
    dict_imagenet_folder2name = {}
    with open(path) as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate empty lines (e.g. a trailing newline at EOF).
                continue
            dict_imagenet_folder2name[fields[0]] = fields[2]
    return dict_imagenet_folder2name


def one_hot_embedding(labels, num_classes):
    """Embedding labels to one-hot form.

    Args:
        labels: (LongTensor) class labels, sized [N,].
        num_classes: (int) number of classes.

    Returns:
        (tensor) encoded labels, sized [N, #classes].
    """
    y = torch.eye(num_classes)
    return y[labels]


# Image pipelines shared by the dataset loaders below.
preprocess = transforms.Compose([
    transforms.ToTensor()
])
preprocess224 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])
preprocess224_interpolate = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])


def load_train_dataset(args):
    """Build the training dataset selected by ``args.dataset``.

    Supported values: ``'cifar100'``, ``'cifar10'`` (stored under
    ``args.root``) and ``'ImageNet'`` (read from ``<args.imgnet_full>/train``).

    Raises:
        ValueError: if ImageNet is requested but ``args.imgnet_full`` is unset.
        NotImplementedError: for any other dataset name.
    """
    if args.dataset == 'cifar100':
        return CIFAR100(args.root, transform=preprocess, download=True, train=True)
    if args.dataset == 'cifar10':
        return CIFAR10(args.root, transform=preprocess, download=True, train=True)
    if args.dataset == 'ImageNet':
        # Validate and report the path that is actually used for loading.
        imagenet_dir = getattr(args, 'imgnet_full', None)
        if imagenet_dir is None:
            raise ValueError("args.imgnet_full must be set to load ImageNet")
        print(f"Loading ImageNet from {imagenet_dir}")
        return ImageFolder(os.path.join(imagenet_dir, 'train'), transform=preprocess224)

    print(f"Train dataset {args.dataset} not implemented")
    raise NotImplementedError(f"Train dataset {args.dataset} not implemented")


def _val_dataset_factories(args):
    """Map each supported validation-set name to a zero-argument constructor.

    The constructors are lambdas so that a dataset class is only looked up
    (and its data only downloaded) when that dataset is actually requested.
    """
    root = args.root
    return {
        'cifar10': lambda: CIFAR10(root, transform=preprocess, download=True, train=False),
        'cifar100': lambda: CIFAR100(root, transform=preprocess, download=True, train=False),
        'Caltech101': lambda: Caltech101(root, target_type='category',
                                         transform=preprocess224, download=True),
        'PCAM': lambda: PCAM(root, split='test', transform=preprocess224, download=True),
        'STL10': lambda: STL10(root, split='test', transform=preprocess, download=True),
        'SUN397': lambda: SUN397(root, transform=preprocess224, download=True),
        'StanfordCars': lambda: StanfordCars(root, split='test',
                                             transform=preprocess224, download=True),
        'Food101': lambda: Food101(root, split='test', transform=preprocess224, download=True),
        'oxfordpet': lambda: OxfordIIITPet(root, split='test',
                                           transform=preprocess224, download=True),
        'EuroSAT': lambda: EuroSAT(root, transform=preprocess224, download=True),
        'Caltech256': lambda: Caltech256(root, transform=preprocess224, download=True),
        'flowers102': lambda: Flowers102(root, split='test',
                                         transform=preprocess224, download=True),
        'Country211': lambda: Country211(root, split='test',
                                         transform=preprocess224, download=True),
        'dtd': lambda: DTD(root, split='test', transform=preprocess224, download=True),
        'fgvc_aircraft': lambda: FGVCAircraft(root, split='test',
                                              transform=preprocess224, download=True),
        'hateful_memes': lambda: HatefulMemes(root, splits=['test_seen', 'test_unseen'],
                                              transform=preprocess224_interpolate),
        'ImageNet': lambda: ImageFolder(os.path.join(args.imgnet_full, 'val'),
                                        transform=preprocess224),
    }


def load_val_datasets(args, val_dataset_names):
    """Build one validation dataset per entry of ``val_dataset_names``.

    Args:
        args: namespace providing ``root`` (and ``imgnet_full`` for ImageNet).
        val_dataset_names: iterable of dataset names; see
            ``_val_dataset_factories`` for the accepted spellings.

    Returns:
        List of datasets, in the same order as ``val_dataset_names``.

    Raises:
        NotImplementedError: if a name is not supported.
    """
    factories = _val_dataset_factories(args)
    val_dataset_list = []
    for each in val_dataset_names:
        factory = factories.get(each)
        if factory is None:
            print(f"Val dataset {each} not implemented")
            raise NotImplementedError(f"Val dataset {each} not implemented")
        val_dataset_list.append(factory())
    return val_dataset_list


def _imagenet_readable_names(class_names):
    """Translate ImageNet folder ids into class names via the mapping file.

    The mapping is read from ``imagenet_classes_names.txt``, resolved relative
    to the current working directory.
    """
    folder2name = load_imagenet_folder2name('imagenet_classes_names.txt')
    return [folder2name[each] for each in class_names]


def get_text_prompts_train(args, train_dataset, template='This is a photo of a {}'):
    """Build one text prompt per class of the training dataset.

    Args:
        args: namespace; ``args.dataset == 'ImageNet'`` triggers the
            folder-id -> name translation.
        train_dataset: dataset exposing a ``classes`` list.
        template: format string with a single ``{}`` placeholder.

    Returns:
        ``(texts_train, class_names)`` where ``class_names`` are the refined
        class names the prompts were built from.
    """
    class_names = train_dataset.classes
    if args.dataset == 'ImageNet':
        class_names = _imagenet_readable_names(class_names)

    class_names = refine_classname(class_names)
    texts_train = [template.format(label) for label in class_names]

    # Keep the refined class names around for text prompt tuning.
    training_original_classnames = class_names
    return texts_train, training_original_classnames


def get_text_prompts_val(val_dataset_list, val_dataset_name,
                         template='This is a photo of a {}'):
    """Build the list of text prompts for every validation dataset.

    A dataset exposing ``clip_prompts`` contributes that attribute as is;
    otherwise prompts are formatted from its (refined) ``classes``.

    Args:
        val_dataset_list: validation datasets.
        val_dataset_name: dataset names, aligned with ``val_dataset_list``.
        template: format string with a single ``{}`` placeholder.

    Returns:
        List (one entry per dataset) of prompt lists.
    """
    texts_list = []
    for cnt, each in enumerate(val_dataset_list):
        if hasattr(each, 'clip_prompts'):
            texts_tmp = each.clip_prompts
        else:
            class_names = each.classes
            if val_dataset_name[cnt] == 'ImageNet':
                class_names = _imagenet_readable_names(class_names)
            class_names = refine_classname(class_names)
            texts_tmp = [template.format(label) for label in class_names]
        texts_list.append(texts_tmp)
    assert len(texts_list) == len(val_dataset_list)
    return texts_list