import math
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.weight_norm import WeightNorm

from methods.base import TTAMethod
from augmentations.transforms_cotta import get_tta_transforms
from utils.registry import ADAPTATION_REGISTRY
from utils.losses import Entropy, SymmetricCrossEntropy, SoftLikelihoodRatio
from utils.misc import ema_update_model


@torch.no_grad()
def update_model_probs(x_ema, x, momentum=0.9):
    return momentum * x_ema + (1 - momentum) * x


@ADAPTATION_REGISTRY.register()
class ROID(TTAMethod):
    def __init__(self, cfg, model, num_classes):
        super().__init__(cfg, model, num_classes)

        self.use_weighting = cfg.ROID.USE_WEIGHTING
        self.use_prior_correction = cfg.ROID.USE_PRIOR_CORRECTION
        self.use_consistency = cfg.ROID.USE_CONSISTENCY
        self.momentum_src = cfg.ROID.MOMENTUM_SRC
        self.momentum_probs = cfg.ROID.MOMENTUM_PROBS
        self.temperature = cfg.ROID.TEMPERATURE
        self.batch_size = cfg.TEST.BATCH_SIZE
        self.class_probs_ema = 1 / self.num_classes * torch.ones(self.num_classes).to(self.device)
        self.e_margin = cfg.EATA.MARGIN_E0 * math.log(num_classes)  # hyper-parameter E_0 (Eqn. 3)
        self.tta_transform = get_tta_transforms(self.img_size, padding_mode="reflect", cotta_augs=False)

        # setup loss functions
        self.slr = SoftLikelihoodRatio()
        self.symmetric_cross_entropy = SymmetricCrossEntropy()
        self.softmax_entropy = Entropy()  # not used as loss

        # note: reduce memory consumption by only saving normalization parameters
        self.src_model = deepcopy(self.model).cpu()
        for param in self.src_model.parameters():
            param.detach_()

        # note: if the model is never reset, like for continual adaptation,
        # then skipping the state copy would save memory
        self.models = [self.src_model, self.model]
        self.model_states, self.optimizer_state = self.copy_model_and_optimizer()

    def loss_calculation(self, x):
        imgs_test = x[0]

        outputs = self.model(imgs_test)
        perform_update = True

        if self.use_weighting:
            with torch.no_grad():
                # calculate diversity based weight
                weights_div = 1 - F.cosine_similarity(self.class_probs_ema.unsqueeze(dim=0), outputs.softmax(1), dim=1)
                weights_div = (weights_div - weights_div.min()) / (weights_div.max() - weights_div.min())
                mask = weights_div < weights_div.mean()

                # calculate certainty based weight
                weights_cert = - self.softmax_entropy(logits=outputs)
                weights_cert = (weights_cert - weights_cert.min()) / (weights_cert.max() - weights_cert.min())
                if self.cfg.MODEL.ARCH == "Standard_VITB":
                    mask &= (-weights_cert >= self.e_margin)

                # calculate the final weights
                weights = torch.exp(weights_div * weights_cert / self.temperature)
                weights[mask] = 0.
                # skip the model update if every sample in the batch was masked out
                perform_update = weights.sum() > 0

                # update the estimate of the class prior probabilities
                self.class_probs_ema = update_model_probs(x_ema=self.class_probs_ema, x=outputs.softmax(1).mean(0), momentum=self.momentum_probs)

        # calculate the soft likelihood ratio loss
        if perform_update:
            if self.cfg.MODEL.ARCH == "Standard_VITB":
                loss_out = self.softmax_entropy(logits=outputs)
            else:
                loss_out = self.slr(logits=outputs)

            # weight the loss
            if self.use_weighting:
                loss_out = loss_out * weights
                loss_out = loss_out[~mask]
            loss = loss_out.sum() / self.batch_size

            # calculate the consistency loss
            if self.use_consistency:
                outputs_aug = self.model(self.tta_transform(imgs_test[~mask]))
                loss += (self.symmetric_cross_entropy(x=outputs_aug, x_ema=outputs[~mask]) * weights[~mask]).sum() / self.batch_size

        return outputs, loss if perform_update else torch.Tensor([0.]), perform_update

    @torch.enable_grad()
    def forward_and_adapt(self, x):
        if self.mixed_precision and self.device == "cuda":
            with torch.cuda.amp.autocast():
                outputs, loss, perform_update = self.loss_calculation(x)
            if perform_update:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
        else:
            outputs, loss, perform_update = self.loss_calculation(x)
            if perform_update:
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

        if perform_update:
            self.model = ema_update_model(
                model_to_update=self.model,
                model_to_merge=self.src_model,
                momentum=self.momentum_src,
                device=self.device
            )

        with torch.no_grad():
            if self.use_prior_correction:
                # re-weight the predictions with a smoothed estimate of the class prior
                prior = outputs.softmax(1).mean(0)
                smooth = max(1 / outputs.shape[0], 1 / outputs.shape[1]) / torch.max(prior)
                smoothed_prior = (prior + smooth) / (1 + smooth * outputs.shape[1])
                outputs *= smoothed_prior

        return {'output': outputs, 'loss': loss.item()}

    def reset(self):
        if self.model_states is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        self.load_model_and_optimizer()
        self.class_probs_ema = 1 / self.num_classes * torch.ones(self.num_classes).to(self.device)

    def collect_params(self):
        """Collect the affine scale + shift parameters from normalization layers.
        Walk the model's modules and collect all normalization parameters.
        Return the parameters and their names.
        Note: other choices of parameterization are possible!
        """
        params = []
        names = []
        for nm, m in self.model.named_modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
                for np, p in m.named_parameters():
                    if np in ['weight', 'bias'] and p.requires_grad:
                        params.append(p)
                        names.append(f"{nm}.{np}")
        return params, names

    def configure_model(self):
        """Configure model."""
        self.model.eval()  # eval mode to avoid stochastic depth in swin. test-time normalization is still applied
        self.model.requires_grad_(False)  # disable grad, to (re-)enable only necessary parts
        # re-enable gradient for normalization layers
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.requires_grad_(True)
                # force use of batch stats in train and eval modes
                m.track_running_stats = False
                m.running_mean = None
                m.running_var = None
            elif isinstance(m, nn.BatchNorm1d):
                m.train()  # always forcing train mode in bn1d will cause problems for single sample tta
                m.requires_grad_(True)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                m.requires_grad_(True)
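
if __name__ == "__main__":
    # Illustrative sketch only, not part of the ROID method itself: it reproduces the
    # sample-weighting arithmetic from `loss_calculation` on random logits so the
    # masking and weighting behaviour can be inspected in isolation. The batch size,
    # class count, temperature, and the uniform class-prior estimate are placeholder
    # values, not values prescribed by this repository's configs.
    torch.manual_seed(0)
    num_classes, batch_size, temperature = 10, 8, 1 / 3
    logits = torch.randn(batch_size, num_classes)
    class_probs_ema = torch.ones(num_classes) / num_classes  # stand-in for the running prior

    probs = logits.softmax(1)
    # diversity weight: distance of each prediction from the (placeholder) class prior
    weights_div = 1 - F.cosine_similarity(class_probs_ema.unsqueeze(0), probs, dim=1)
    weights_div = (weights_div - weights_div.min()) / (weights_div.max() - weights_div.min())
    mask = weights_div < weights_div.mean()

    # certainty weight: negative softmax entropy, normalized to [0, 1]
    weights_cert = (probs * probs.log()).sum(1)
    weights_cert = (weights_cert - weights_cert.min()) / (weights_cert.max() - weights_cert.min())

    # combine both weights and zero out the masked (low-diversity) samples
    weights = torch.exp(weights_div * weights_cert / temperature)
    weights[mask] = 0.
    print("masked samples:", mask.nonzero().squeeze(1).tolist())
    print("sample weights:", [round(w, 3) for w in weights.tolist()])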