import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.weight_norm import WeightNorm
import math
from copy import deepcopy
from methods.base import TTAMethod
from augmentations.transforms_cotta import get_tta_transforms
from utils.registry import ADAPTATION_REGISTRY
from utils.losses import Entropy, SymmetricCrossEntropy, SoftLikelihoodRatio
from utils.misc import ema_update_model


@torch.no_grad()
def update_model_probs(x_ema, x, momentum=0.9):
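    # exponential moving average: keep `momentum` of the previous estimate and blend in (1 - momentum) of the new value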
    return momentum * x_ema + (1 - momentum) * x


@ADAPTATION_REGISTRY.register()
class ROID(TTAMethod):
    def __init__(self, cfg, model, num_classes):
        super().__init__(cfg, model, num_classes)

        self.use_weighting = cfg.ROID.USE_WEIGHTING
        self.use_prior_correction = cfg.ROID.USE_PRIOR_CORRECTION
        self.use_consistency = cfg.ROID.USE_CONSISTENCY
        self.momentum_src = cfg.ROID.MOMENTUM_SRC
        self.momentum_probs = cfg.ROID.MOMENTUM_PROBS
        self.temperature = cfg.ROID.TEMPERATURE
        self.batch_size = cfg.TEST.BATCH_SIZE
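        # running estimate of the marginal class distribution (initialized uniformly),
        # used for the diversity-based sample weighting in loss_calculation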
        self.class_probs_ema = 1 / self.num_classes * torch.ones(self.num_classes).to(self.device)
        self.e_margin = cfg.EATA.MARGIN_E0 * math.log(num_classes)  # hyper-parameter E_0 (Eqn. 3)
        self.tta_transform = get_tta_transforms(self.img_size, padding_mode="reflect", cotta_augs=False)

        # setup loss functions
        self.slr = SoftLikelihoodRatio()
        self.symmetric_cross_entropy = SymmetricCrossEntropy()
        self.softmax_entropy = Entropy()  # not used as loss

        # note: reduce memory consumption by only saving normalization parameters
        self.src_model = deepcopy(self.model).cpu()
        for param in self.src_model.parameters():
            param.detach_()

        # note: if the model is never reset, like for continual adaptation,
        # then skipping the state copy would save memory
        self.models = [self.src_model, self.model]
        self.model_states, self.optimizer_state = self.copy_model_and_optimizer()

    def loss_calculation(self, x):
        imgs_test = x[0]
        outputs = self.model(imgs_test)
        perform_update = True

        if self.use_weighting:
            with torch.no_grad():
                # calculate diversity based weight
                weights_div = 1 - F.cosine_similarity(self.class_probs_ema.unsqueeze(dim=0), outputs.softmax(1), dim=1)
                weights_div = (weights_div - weights_div.min()) / (weights_div.max() - weights_div.min())
                mask = weights_div < weights_div.mean()
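                # mask marks samples with below-average diversity; they are excluded
                # from the loss and their weights are zeroed out below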

                # calculate certainty based weight
                weights_cert = -self.softmax_entropy(logits=outputs)
                weights_cert = (weights_cert - weights_cert.min()) / (weights_cert.max() - weights_cert.min())
                if self.cfg.MODEL.ARCH == "Standard_VITB":
                    mask &= (-weights_cert >= self.e_margin)

                # calculate the final weights
                weights = torch.exp(weights_div * weights_cert / self.temperature)
                weights[mask] = 0.
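                # if every sample in the batch was filtered out, skip the parameter update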
                perform_update = weights.sum() > 0

                self.class_probs_ema = update_model_probs(x_ema=self.class_probs_ema, x=outputs.softmax(1).mean(0), momentum=self.momentum_probs)

        # calculate the soft likelihood ratio loss
        if perform_update:
            if self.cfg.MODEL.ARCH == "Standard_VITB":
                loss_out = self.softmax_entropy(logits=outputs)
            else:
                loss_out = self.slr(logits=outputs)

            # weight the loss
            if self.use_weighting:
                loss_out = loss_out * weights
                loss_out = loss_out[~mask]
            loss = loss_out.sum() / self.batch_size

            # calculate the consistency loss
            if self.use_consistency:
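                # symmetric cross-entropy between predictions on augmented and original images,
                # computed only for the unmasked samples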
                outputs_aug = self.model(self.tta_transform(imgs_test[~mask]))
                loss += (self.symmetric_cross_entropy(x=outputs_aug, x_ema=outputs[~mask]) * weights[~mask]).sum() / self.batch_size

        return outputs, (loss if perform_update else torch.Tensor([0.])), perform_update

    @torch.enable_grad()
    def forward_and_adapt(self, x):
        if self.mixed_precision and self.device == "cuda":
            with torch.cuda.amp.autocast():
                outputs, loss, perform_update = self.loss_calculation(x)
            if perform_update:
                self.scaler.scale(loss).backward()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()
        else:
            outputs, loss, perform_update = self.loss_calculation(x)
            if perform_update:
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
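
        # continually merge the adapted weights back towards the frozen source model
        # (exponential moving average controlled by momentum_src)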
        if perform_update:
            self.model = ema_update_model(
                model_to_update=self.model,
                model_to_merge=self.src_model,
                momentum=self.momentum_src,
                device=self.device
            )

        with torch.no_grad():
            if self.use_prior_correction:
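                # rescale the predictions with a smoothed estimate of the batch's label prior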
                prior = outputs.softmax(1).mean(0)
                smooth = max(1 / outputs.shape[0], 1 / outputs.shape[1]) / torch.max(prior)
                smoothed_prior = (prior + smooth) / (1 + smooth * outputs.shape[1])
                outputs *= smoothed_prior

        return {'output': outputs, 'loss': loss.item()}

    def reset(self):
        if self.model_states is None or self.optimizer_state is None:
            raise Exception("cannot reset without saved model/optimizer state")
        self.load_model_and_optimizer()
        self.class_probs_ema = 1 / self.num_classes * torch.ones(self.num_classes).to(self.device)

    def collect_params(self):
        """Collect the affine scale + shift parameters from normalization layers.
        Walk the model's modules and collect all normalization parameters.
        Return the parameters and their names.
        Note: other choices of parameterization are possible!
        """
        params = []
        names = []
        for nm, m in self.model.named_modules():
            if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
                for np, p in m.named_parameters():
                    if np in ['weight', 'bias'] and p.requires_grad:
                        params.append(p)
                        names.append(f"{nm}.{np}")
        return params, names

    def configure_model(self):
        """Configure model."""
        self.model.eval()  # eval mode to avoid stochastic depth in swin. test-time normalization is still applied
        self.model.requires_grad_(False)  # disable grad, to (re-)enable only necessary parts
        # re-enable gradient for normalization layers
        for m in self.model.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.requires_grad_(True)
                # force use of batch stats in train and eval modes
                m.track_running_stats = False
                m.running_mean = None
                m.running_var = None
            elif isinstance(m, nn.BatchNorm1d):
                m.train()  # always forcing train mode in bn1d will cause problems for single sample tta
                m.requires_grad_(True)
            elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
                m.requires_grad_(True)