""" Builds upon: https://github.com/mr-eggplant/EATA Corresponding paper: https://arxiv.org/abs/2204.02610 """ from copy import deepcopy import os import math import logging import torch import torch.nn as nn import torch.nn.functional as F from utils.misc import ema_update_model from methods.base import TTAMethod from datasets.data_loading import get_source_loader from utils.registry import ADAPTATION_REGISTRY from utils.losses import Entropy logger = logging.getLogger(__name__) @ADAPTATION_REGISTRY.register() class EATA(TTAMethod): """EATA adapts a model by entropy minimization during testing. Once EATAed, a model adapts itself by updating on every forward. """ def __init__(self, cfg, model, num_classes): super().__init__(cfg, model, num_classes) self.num_samples_update_1 = 0 # number of samples after first filtering, exclude unreliable samples self.num_samples_update_2 = 0 # number of samples after second filtering, exclude both unreliable and redundant samples self.e_margin = cfg.EATA.MARGIN_E0 * math.log(num_classes) # hyper-parameter E_0 (Eqn. 3) self.d_margin = cfg.EATA.D_MARGIN # hyperparameter \epsilon for cosine similarity thresholding (Eqn. 5) self.current_model_probs = None # the moving average of probability vector (Eqn. 4) self.fisher_alpha = cfg.EATA.FISHER_ALPHA # trade-off \beta for two losses (Eqn. 8) # setup loss function self.softmax_entropy = Entropy() if self.fisher_alpha > 0.0 and self.cfg.SOURCE.NUM_SAMPLES > 0: # compute fisher informatrix batch_size_src = cfg.TEST.BATCH_SIZE if cfg.TEST.BATCH_SIZE > 1 else cfg.TEST.WINDOW_LENGTH _, fisher_loader = get_source_loader(dataset_name=cfg.CORRUPTION.DATASET, data_root_dir=cfg.DATA_DIR, batch_size=batch_size_src, train_split=False, num_samples=cfg.SOURCE.NUM_SAMPLES, # number of samples for ewc reg. percentage=cfg.SOURCE.PERCENTAGE, workers=min(cfg.SOURCE.NUM_WORKERS, os.cpu_count()), preprocess=model.model_preprocess) ewc_optimizer = torch.optim.SGD(self.params, 0.001) self.fishers = {} # fisher regularizer items for anti-forgetting, need to be calculated pre model adaptation (Eqn. 9) train_loss_fn = nn.CrossEntropyLoss().to(self.device) for iter_, batch in enumerate(fisher_loader, start=1): images = batch[0].to(self.device, non_blocking=True) outputs = self.model(images) _, targets = outputs.max(1) loss = train_loss_fn(outputs, targets) loss.backward() for name, param in model.named_parameters(): if param.grad is not None: if iter_ > 1: fisher = param.grad.data.clone().detach() ** 2 + self.fishers[name][0] else: fisher = param.grad.data.clone().detach() ** 2 if iter_ == len(fisher_loader): fisher = fisher / iter_ self.fishers.update({name: [fisher, param.data.clone().detach()]}) ewc_optimizer.zero_grad() logger.info("Finished computing the fisher matrices...") del ewc_optimizer else: logger.info("Not using EWC regularization. EATA decays to ETA!") self.fishers = None # note: reduce memory consumption by only saving normalization parameters self.src_model = deepcopy(self.model).cpu() for param in self.src_model.parameters(): param.detach_() # note: if the model is never reset, like for continual adaptation, # then skipping the state copy would save memory self.models = [self.src_model, self.model] self.model_states, self.optimizer_state = self.copy_model_and_optimizer() def loss_calculation(self, x): """Forward and adapt model on batch of data. Measure entropy of the model prediction, take gradients, and update params. 
""" imgs_test = x[0] outputs = self.model(imgs_test) entropys = self.softmax_entropy(outputs) # filter unreliable samples filter_ids_1 = torch.where(entropys < self.e_margin) ids1 = filter_ids_1 ids2 = torch.where(ids1[0] > -0.1) entropys = entropys[filter_ids_1] # filter redundant samples if self.current_model_probs is not None: cosine_similarities = F.cosine_similarity(self.current_model_probs.unsqueeze(dim=0), outputs[filter_ids_1].softmax(1), dim=1) filter_ids_2 = torch.where(torch.abs(cosine_similarities) < self.d_margin) entropys = entropys[filter_ids_2] updated_probs = update_model_probs(self.current_model_probs, outputs[filter_ids_1][filter_ids_2].softmax(1)) else: updated_probs = update_model_probs(self.current_model_probs, outputs[filter_ids_1].softmax(1)) coeff = 1 / (torch.exp(entropys.clone().detach() - self.e_margin)) # implementation version 1, compute loss, all samples backward (some unselected are masked) entropys = entropys.mul(coeff) # reweight entropy losses for diff. samples loss = entropys.mean(0) """ # implementation version 2, compute loss, forward all batch, forward and backward selected samples again. # if x[ids1][ids2].size(0) != 0: # loss = self.softmax_entropy(model(x[ids1][ids2])).mul(coeff).mean(0) # reweight entropy losses for diff. samples """ if self.fishers is not None: ewc_loss = 0 for name, param in self.model.named_parameters(): if name in self.fishers: ewc_loss += self.fisher_alpha * (self.fishers[name][0] * (param - self.fishers[name][1]) ** 2).sum() loss += ewc_loss self.num_samples_update_1 += filter_ids_1[0].size(0) self.num_samples_update_2 += entropys.size(0) self.current_model_probs = updated_probs perform_update = len(entropys) != 0 return outputs, loss, perform_update @torch.enable_grad() def forward_and_adapt(self, x): if self.mixed_precision and self.device == "cuda": with torch.cuda.amp.autocast(): outputs, loss, perform_update = self.loss_calculation(x) # update model only if not all instances have been filtered if perform_update: self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() self.optimizer.zero_grad() else: outputs, loss, perform_update = self.loss_calculation(x) # update model only if not all instances have been filtered if perform_update: loss.backward() self.optimizer.step() self.optimizer.zero_grad() return outputs def reset(self): if self.model_states is None or self.optimizer_state is None: raise Exception("cannot reset without saved model/optimizer state") self.load_model_and_optimizer() self.current_model_probs = None def collect_params(self): """Collect the affine scale + shift parameters from batch norms. Walk the model's modules and collect all batch normalization parameters. Return the parameters and their names. Note: other choices of parameterization are possible! """ params = [] names = [] for nm, m in self.model.named_modules(): if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): for np, p in m.named_parameters(): if np in ['weight', 'bias']: # weight is scale, bias is shift params.append(p) names.append(f"{nm}.{np}") return params, names def configure_model(self): """Configure model for use with eata.""" # train mode, because eata optimizes the model to minimize entropy # self.model.train() self.model.eval() # eval mode to avoid stochastic depth in swin. 


def update_model_probs(current_model_probs, new_probs):
    """Update the moving average of the model's probability vector (Eqn. 4)."""
    if current_model_probs is None:
        if new_probs.size(0) == 0:
            return None
        else:
            with torch.no_grad():
                return new_probs.mean(0)
    else:
        if new_probs.size(0) == 0:
            with torch.no_grad():
                return current_model_probs
        else:
            with torch.no_grad():
                return 0.9 * current_model_probs + (1 - 0.9) * new_probs.mean(0)
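

if __name__ == "__main__":
    # Minimal sanity check for update_model_probs (illustrative only; random
    # probability vectors stand in for real model outputs).
    torch.manual_seed(0)
    probs = torch.softmax(torch.randn(8, 10), dim=1)
    running = update_model_probs(None, probs)     # initialized with the batch mean
    running = update_model_probs(running, probs)  # EMA update with momentum 0.9
    assert torch.allclose(running.sum(), torch.tensor(1.0))  # still a valid probability vector
    print("running probability vector:", running)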