import random import os import argparse from datetime import datetime import torch from tqdm import tqdm import clip from utils import * from torch import nn import logging import bisect from sortedcontainers import SortedList import numpy as np import torch.backends.cudnn as cudnn import torch.nn.functional as F class ConfidenceChecker: def __init__(self, gamma=0): """ Initialize the unconfident detector :param gamma: Floating-point numbers, lower bound percentiles (e.g. 0.05 for 5%), the smallest 5% are considered unconfident samples """ self.gamma = gamma self.sorted_values = SortedList() def add_value(self, value): """ Add a new value to the sorted list :param value: floating-point number, a new value """ self.sorted_values.add(value) def is_last_element_unconfident(self, last_value): """ Detect if the last element is not confident :param last_value: floating-point number, the value of the last element :return: Boolean, whether the last element deviates significantly from the primary data, is an overly unconfident sample """ if len(self.sorted_values) == 0 or self.gamma == 0: return False # If there are no remaining elements or if the confidence level is 0, it cannot be judged # Calculate the lower bound percentile lower_bound = self.sorted_values[int(len(self.sorted_values) * self.gamma)] # Determines if the last element is smaller than the nether return last_value < lower_bound def setup_seeds(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) cudnn.benchmark = False cudnn.deterministic = True class DOTA(nn.Module): def __init__(self, cfg, input_shape, num_classes, clip_weights, streaming_update_Sigma=True): super(DOTA, self).__init__() self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.input_shape = input_shape self.num_classes = num_classes self.streaming_update_Sigma = streaming_update_Sigma self.epsilon = cfg['epsilon'] self.mu = clip_weights.T.to(self.device) # initialize mu with clip_weights self.c = torch.ones(num_classes, dtype=torch.float32).to(self.device) self.Sigma = cfg['sigma'] * torch.eye(input_shape, dtype=torch.float32).repeat(num_classes, 1, 1).to(self.device) self.overall_Sigma = torch.mean(self.Sigma, dim=0) self.Lambda = torch.pinverse(self.overall_Sigma.double()).to(self.device).half() # Update the covariance and the mean for the corresponding category def fit(self, x, y): x = x.to(self.device) y = y.to(self.device) # y is now a probability distribution (soft labels) with torch.no_grad(): sum_weights = torch.sum(y, dim=0) weighted_x = torch.matmul(y.T, x) new_mu = (weighted_x + self.c.unsqueeze(1) * self.mu) / (sum_weights.unsqueeze(1) + self.c.unsqueeze(1)) new_c = self.c + sum_weights # Update the covariance matrix for each category if self.streaming_update_Sigma: x_minus_mu = x.unsqueeze(1) - self.mu.unsqueeze(0) # Shape: (batch_size, num_classes, input_shape) weighted_x_minus_mu = y.unsqueeze(2) * x_minus_mu # Shape: (batch_size, num_classes, input_shape) delta = torch.einsum('bji,bjk->jik', weighted_x_minus_mu, x_minus_mu) # Shape: (num_classes, input_shape, input_shape) self.Sigma = (self.c[:, None, None] * self.Sigma + delta) / (self.c[:, None, None] + sum_weights[:, None, None]) # Update the total covariance matrix, mean matrix, and count sections self.overall_Sigma = torch.mean(self.Sigma, dim=0) self.mu = new_mu self.c = new_c # Update the inverse matrix to include a small identity matrix when calculating the inverse matrix to ensure that it is full rank def update(self): self.Lambda = torch.inverse( (1 - self.epsilon) * self.overall_Sigma + self.epsilon * torch.eye(self.input_shape).to( self.device)).half() # Calculate the results of Dota predictions def predict(self, X): X = X.to(self.device) with torch.no_grad(): Lambda = self.Lambda M = self.mu.transpose(1, 0).half() W = torch.matmul(Lambda, M) c = 0.5 * torch.sum(M * W, dim=0) scores = torch.matmul(X, W) - c return scores def get_arguments(): """Get arguments of the test-time adaptation.""" parser = argparse.ArgumentParser() parser.add_argument('--config', dest='config', default='configs', help='settings of TDA on specific dataset in yaml format.') parser.add_argument('--datasets', dest='datasets', default='I', type=str, help="Datasets to process, separated by a slash (/). Example: I/A/V/R/S") parser.add_argument('--data-root', dest='data_root', type=str, default='./data/', help='Path to the datasets directory. Default is ./dataset/') parser.add_argument('--backbone', dest='backbone', type=str, default='ViT-B/16', choices=['RN50', 'ViT-B/16'], help='CLIP model backbone to use: RN50 or ViT-B/16.') parser.add_argument('--log-path', dest='log_path', type=str, default='./log', help='Path to the log file.') args = parser.parse_args() return args def run_test_dota(params, loader, clip_model, clip_weights, dota_model, logger): recent_sample_count = 1000 unconfident_num = 0 fusion_accuracies, dota_accuracies, clip_accuracies = [], [], [] # It is used to store the maximum value of each sample feature and sort it to determine whether it is a sample with high uncertainty # The higher the value, the lower the uncertainty we assume here entropy_list = [] # Initialize the unconfident detector, gamma is the proportion of the true label obtained checker = ConfidenceChecker(params['gamma']) with torch.no_grad(): for i, (images, target) in enumerate(tqdm(loader, desc='Processed test images: ')): # When data augmentation is used, the top 10% of enhanced images are selected to train the model image_features, clip_logits, loss, prob_map, pred = get_clip_logits_aug(images, clip_model, clip_weights) pred, target, prop_entropy = torch.tensor(pred).cuda(), target.cuda(), get_entropy(loss, clip_weights) dota_logits = dota_model.predict(image_features.mean(0).unsqueeze(0)) # Put the maximum value in the prop_entropy into the checker list after performing the softmax operation entropy_list.append(prop_entropy) softmax_output = F.softmax(clip_logits[0], dim=-1) max_logit = torch.max(softmax_output) checker.add_value(max_logit) # Choose a smaller weight, so that model relies more on the original clip initially dota_weights = torch.clamp(params['rho'] * dota_model.c.mean() / image_features.size(0), max=params['eta']) # Clip and Dota prediction weights are added to form the final prediction final_logits = clip_logits + dota_weights*dota_logits # Calculate the prediction accuracy of mixed weights, dota weights, clip weights and add them to the list fusion_acc, dota_acc, clip_acc = cls_acc(final_logits, target), cls_acc(dota_logits, target), cls_acc(clip_logits, target) fusion_accuracies.append(fusion_acc) dota_accuracies.append(dota_acc) clip_accuracies.append(clip_acc) # Determine whether the sample is unconfident unconfident = checker.is_last_element_unconfident(max_logit) if unconfident: # If it is unconfident, it is fitted with a real label unconfident_num = unconfident_num+1 one_hot_target = torch.nn.functional.one_hot(target, num_classes=prob_map.shape[1]).repeat_interleave(prob_map.shape[0], dim=0).half() dota_model.fit(image_features, one_hot_target) # For samples that did not use real labels, we used the weights predicted by the clip as the updated weights for the corresponding categories else : dota_model.fit(image_features, prob_map) # Update the inverse matrix dota_model.update() # Print the information if (i + 1) % recent_sample_count == 0: recent_fusion_accuracy = sum(fusion_accuracies[-recent_sample_count:]) / recent_sample_count recent_dota_accuracy = sum(dota_accuracies[-recent_sample_count:]) / recent_sample_count recent_clip_accuracy = sum(clip_accuracies[-recent_sample_count:]) / recent_sample_count logger.info( "Last {} samples' accuracies - Fusion: {:.2f}%, DOTA: {:.2f}%, CLIP: {:.2f}% | " "Overall accuracies - Fusion: {:.2f}%, DOTA: {:.2f}%, CLIP: {:.2f}%, unconfident sample number: {:.2f}".format( recent_sample_count, recent_fusion_accuracy, recent_dota_accuracy, recent_clip_accuracy, sum(fusion_accuracies) / len(fusion_accuracies), sum(dota_accuracies) / len(dota_accuracies), sum(clip_accuracies) / len(clip_accuracies), unconfident_num ) ) return { 'overall_fusion_accuracy': sum(fusion_accuracies) / len(fusion_accuracies), 'overall_dota_accuracy': sum(dota_accuracies) / len(dota_accuracies), 'overall_clip_accuracy': sum(clip_accuracies) / len(clip_accuracies), 'recent_fusion_accuracy': sum(fusion_accuracies[-recent_sample_count:]) / min(recent_sample_count, len(fusion_accuracies)), 'recent_dota_accuracy': sum(dota_accuracies[-recent_sample_count:]) / min(recent_sample_count, len(dota_accuracies)), 'recent_clip_accuracy': sum(clip_accuracies[-recent_sample_count:]) / min(recent_sample_count, len(clip_accuracies)), 'unconfident sample number': unconfident_num, } def main(): args = get_arguments() config_path = args.config clip_model, preprocess = clip.load(args.backbone) clip_model.eval() datasets = args.datasets.split('/') for dataset_name in datasets: # Set random seed setup_seeds(1) # Prepare logs and other content to facilitate printout information date = datetime.now().strftime("%b%d_%H-%M-%S") backbone_safe = args.backbone.replace('/', '_') group_name = f"{backbone_safe}_{dataset_name}_{date}" logging.basicConfig(filename=os.path.join(args.log_path, group_name), level=logging.INFO, format='%(asctime)s %(message)s') logger = logging.getLogger() logger.info(f"Processing {dataset_name} dataset.") # Obtain the hyperparameter information of the dataset cfg = get_config_file(config_path, dataset_name) logger.info("\nRunning dataset configurations:") logger.info(cfg) test_loader, classnames, template = build_test_data_loader(dataset_name, args.data_root, preprocess) clip_weights = clip_classifier(classnames, template, clip_model) dota_model = DOTA(cfg, input_shape=clip_weights.shape[0], num_classes=clip_weights.shape[1], clip_weights=clip_weights.clone()) dota_model.eval() acc = run_test_dota(cfg, test_loader, clip_model, clip_weights, dota_model, logger) logger.info(acc) if __name__ == "__main__": main()