"""
Builds upon: https://github.com/mr-eggplant/EATA
Corresponding paper: https://arxiv.org/abs/2204.02610
"""
from copy import deepcopy
import os
import math
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.misc import ema_update_model
from methods.base import TTAMethod
from datasets.data_loading import get_source_loader
from utils.registry import ADAPTATION_REGISTRY
from utils.losses import Entropy
logger = logging.getLogger(__name__)
@ADAPTATION_REGISTRY.register()
class EATA(TTAMethod):
"""EATA adapts a model by entropy minimization during testing.
    Once equipped with EATA, a model adapts itself by updating on every forward pass.
"""
def __init__(self, cfg, model, num_classes):
super().__init__(cfg, model, num_classes)
        self.num_samples_update_1 = 0  # number of samples after the first filtering step (unreliable samples excluded)
        self.num_samples_update_2 = 0  # number of samples after the second filtering step (unreliable and redundant samples excluded)
self.e_margin = cfg.EATA.MARGIN_E0 * math.log(num_classes) # hyper-parameter E_0 (Eqn. 3)
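        # e.g., with MARGIN_E0 = 0.4 (the value used in the EATA paper) and 1000 classes,
        # samples whose prediction entropy exceeds 0.4 * ln(1000) ≈ 2.76 are discarded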
self.d_margin = cfg.EATA.D_MARGIN # hyperparameter \epsilon for cosine similarity thresholding (Eqn. 5)
self.current_model_probs = None # the moving average of probability vector (Eqn. 4)
self.fisher_alpha = cfg.EATA.FISHER_ALPHA # trade-off \beta for two losses (Eqn. 8)
# setup loss function
self.softmax_entropy = Entropy()
if self.fisher_alpha > 0.0 and self.cfg.SOURCE.NUM_SAMPLES > 0:
            # compute the diagonal Fisher information matrix used for EWC regularization
batch_size_src = cfg.TEST.BATCH_SIZE if cfg.TEST.BATCH_SIZE > 1 else cfg.TEST.WINDOW_LENGTH
_, fisher_loader = get_source_loader(dataset_name=cfg.CORRUPTION.DATASET,
data_root_dir=cfg.DATA_DIR,
batch_size=batch_size_src,
train_split=False,
num_samples=cfg.SOURCE.NUM_SAMPLES, # number of samples for ewc reg.
percentage=cfg.SOURCE.PERCENTAGE,
workers=min(cfg.SOURCE.NUM_WORKERS, os.cpu_count()),
preprocess=model.model_preprocess)
ewc_optimizer = torch.optim.SGD(self.params, 0.001)
            self.fishers = {}  # Fisher regularizer terms for anti-forgetting; computed before model adaptation (Eqn. 9)
train_loss_fn = nn.CrossEntropyLoss().to(self.device)
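            # accumulate the squared gradients of the cross-entropy loss w.r.t. the source model's own
            # predictions over the source samples; the running sum is divided by the number of batches
            # on the last iteration, giving a diagonal estimate of the Fisher information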
for iter_, batch in enumerate(fisher_loader, start=1):
images = batch[0].to(self.device, non_blocking=True)
outputs = self.model(images)
_, targets = outputs.max(1)
loss = train_loss_fn(outputs, targets)
loss.backward()
for name, param in model.named_parameters():
if param.grad is not None:
if iter_ > 1:
fisher = param.grad.data.clone().detach() ** 2 + self.fishers[name][0]
else:
fisher = param.grad.data.clone().detach() ** 2
if iter_ == len(fisher_loader):
fisher = fisher / iter_
self.fishers.update({name: [fisher, param.data.clone().detach()]})
ewc_optimizer.zero_grad()
logger.info("Finished computing the fisher matrices...")
del ewc_optimizer
else:
logger.info("Not using EWC regularization. EATA decays to ETA!")
self.fishers = None
# note: reduce memory consumption by only saving normalization parameters
self.src_model = deepcopy(self.model).cpu()
for param in self.src_model.parameters():
param.detach_()
# note: if the model is never reset, like for continual adaptation,
# then skipping the state copy would save memory
self.models = [self.src_model, self.model]
self.model_states, self.optimizer_state = self.copy_model_and_optimizer()
def loss_calculation(self, x):
"""Forward and adapt model on batch of data.
Measure entropy of the model prediction, take gradients, and update params.
"""
imgs_test = x[0]
outputs = self.model(imgs_test)
entropys = self.softmax_entropy(outputs)
# filter unreliable samples
filter_ids_1 = torch.where(entropys < self.e_margin)
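        # ids1/ids2 mirror the indexing of the original EATA code; ids2 selects every remaining index
        # (the > -0.1 check is always true) and is only needed by "implementation version 2" below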
ids1 = filter_ids_1
ids2 = torch.where(ids1[0] > -0.1)
entropys = entropys[filter_ids_1]
# filter redundant samples
if self.current_model_probs is not None:
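            # samples whose softmax output is too similar to the running average prediction are treated
            # as redundant and dropped (Eqn. 5); the average is only tracked over previously selected samples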
cosine_similarities = F.cosine_similarity(self.current_model_probs.unsqueeze(dim=0), outputs[filter_ids_1].softmax(1), dim=1)
filter_ids_2 = torch.where(torch.abs(cosine_similarities) < self.d_margin)
entropys = entropys[filter_ids_2]
updated_probs = update_model_probs(self.current_model_probs, outputs[filter_ids_1][filter_ids_2].softmax(1))
else:
updated_probs = update_model_probs(self.current_model_probs, outputs[filter_ids_1].softmax(1))
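        # entropy-based sample weights (detached so the weights themselves carry no gradient):
        # the lower a sample's entropy relative to the margin E_0, the larger its weight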
coeff = 1 / (torch.exp(entropys.clone().detach() - self.e_margin))
# implementation version 1, compute loss, all samples backward (some unselected are masked)
entropys = entropys.mul(coeff) # reweight entropy losses for diff. samples
loss = entropys.mean(0)
"""
# implementation version 2, compute loss, forward all batch, forward and backward selected samples again.
# if x[ids1][ids2].size(0) != 0:
# loss = self.softmax_entropy(model(x[ids1][ids2])).mul(coeff).mean(0) # reweight entropy losses for diff. samples
"""
if self.fishers is not None:
ewc_loss = 0
for name, param in self.model.named_parameters():
if name in self.fishers:
ewc_loss += self.fisher_alpha * (self.fishers[name][0] * (param - self.fishers[name][1]) ** 2).sum()
loss += ewc_loss
self.num_samples_update_1 += filter_ids_1[0].size(0)
self.num_samples_update_2 += entropys.size(0)
self.current_model_probs = updated_probs
perform_update = len(entropys) != 0
return outputs, loss, perform_update
@torch.enable_grad()
def forward_and_adapt(self, x):
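        # two paths: mixed precision via the gradient scaler held by the base class on CUDA, plain fp32
        # otherwise; in both cases the optimizer step is skipped when every sample was filtered out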
if self.mixed_precision and self.device == "cuda":
with torch.cuda.amp.autocast():
outputs, loss, perform_update = self.loss_calculation(x)
# update model only if not all instances have been filtered
if perform_update:
self.scaler.scale(loss).backward()
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
else:
outputs, loss, perform_update = self.loss_calculation(x)
# update model only if not all instances have been filtered
if perform_update:
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
return outputs
def reset(self):
if self.model_states is None or self.optimizer_state is None:
raise Exception("cannot reset without saved model/optimizer state")
self.load_model_and_optimizer()
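        # also reset the running average of the prediction vector so sample filtering starts fresh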
self.current_model_probs = None
def collect_params(self):
        Collect the affine scale + shift parameters from normalization layers.
        Walk the model's modules and collect all batch norm, layer norm, and group norm parameters.
Return the parameters and their names.
Note: other choices of parameterization are possible!
"""
params = []
names = []
for nm, m in self.model.named_modules():
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
for np, p in m.named_parameters():
if np in ['weight', 'bias']: # weight is scale, bias is shift
params.append(p)
names.append(f"{nm}.{np}")
return params, names
def configure_model(self):
"""Configure model for use with eata."""
# train mode, because eata optimizes the model to minimize entropy
# self.model.train()
self.model.eval() # eval mode to avoid stochastic depth in swin. test-time normalization is still applied
# disable grad, to (re-)enable only what eata updates
self.model.requires_grad_(False)
        # configure norm layers for eata updates: enable grad + force the use of batch statistics
for m in self.model.modules():
if isinstance(m, nn.BatchNorm2d):
m.requires_grad_(True)
# force use of batch stats in train and eval modes
m.track_running_stats = False
m.running_mean = None
m.running_var = None
elif isinstance(m, nn.BatchNorm1d):
m.train() # always forcing train mode in bn1d will cause problems for single sample tta
m.requires_grad_(True)
elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
m.requires_grad_(True)
def update_model_probs(current_model_probs, new_probs):
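    # exponential moving average (momentum 0.9) of the mean softmax prediction vector (Eqn. 4);
    # returns None until the first non-empty batch of selected samples has been seen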
if current_model_probs is None:
if new_probs.size(0) == 0:
return None
else:
with torch.no_grad():
return new_probs.mean(0)
else:
if new_probs.size(0) == 0:
with torch.no_grad():
return current_model_probs
else:
with torch.no_grad():
return 0.9 * current_model_probs + (1 - 0.9) * new_probs.mean(0)