Image Classification
English
TTA
GuillaumeVray
Uploading files
02ba886
import logging
import torch
import torch.nn as nn
from torchvision import transforms
from copy import deepcopy
from functools import wraps
logger = logging.getLogger(__name__)
class TTAMethod(nn.Module):
    """Base class for test-time adaptation (TTA) methods.

    The class wraps a model, collects the parameters to adapt, builds the
    optimizer selected by ``cfg.OPTIM.METHOD`` and keeps a copy of the initial
    model/optimizer state so that :meth:`reset` can restore it. :meth:`forward`
    handles regular batches as well as single-sample inputs; the latter are
    accumulated in a sliding-window buffer of ``cfg.TEST.WINDOW_LENGTH`` samples.

    Subclasses have to implement :meth:`configure_model` (called during
    construction) and :meth:`forward_and_adapt` (called by :meth:`forward`).
    """

    def __init__(self, cfg, model, num_classes):
        """Set up the model, the optimizer and the adaptation bookkeeping.

        :param cfg: Configuration object. The attributes read here are
            ``MODEL.EPISODIC``, ``MODEL.RESET_AFTER_NUM_UPDATES``,
            ``CORRUPTION.DATASET``, ``OPTIM.STEPS``, ``TEST.WINDOW_LENGTH`` and
            ``MIXED_PRECISION``; ``setup_optimizer`` reads further ``OPTIM`` entries
        :param model: The model to adapt; it is passed through ``decorate_model``
        :param num_classes: Number of classes, stored as ``self.num_classes``
        :raises ValueError: If the dataset is ``imagenet_c`` or ``ccc`` and the
            model pre-processing specifies a size larger than 224
        """
        super().__init__()
        self.cfg = cfg
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.decorate_model(model)
        self.num_classes = num_classes
        self.episodic = cfg.MODEL.EPISODIC
        self.dataset_name = cfg.CORRUPTION.DATASET
        self.steps = cfg.OPTIM.STEPS
        self.current_grad_norm = 0.0
        assert self.steps > 0, "requires >= 1 step(s) to forward and update"

        # variables for resetting the model after a certain amount of performed update steps
        self.performed_updates = 0
        self.reset_after_num_updates = cfg.MODEL.RESET_AFTER_NUM_UPDATES

        # restore the image input size from the model pre-processing if it is defined
        # this is required for methods relying on test-time augmentation
        # default first, so that `img_size` always exists, even for dataset
        # names that match none of the checks below
        self.img_size = (224, 224)
        if "cifar" in self.dataset_name:
            self.img_size = (32, 32)
        if "imagenet" in self.dataset_name or "ccc" in self.dataset_name:
            self.img_size = (224, 224)

        preprocess = getattr(self.model, "model_preprocess", None)
        if isinstance(preprocess, transforms.Compose):
            # walk the transforms back to front and use the last one that defines a size
            for transf in preprocess.transforms[::-1]:
                if hasattr(transf, "size"):
                    self.img_size = getattr(transf, "size")
                    if self.dataset_name in ["imagenet_c", "ccc"] and max(self.img_size) > 224:
                        raise ValueError(f"The specified model with pre-processing {preprocess} "
                                         f"is not suited in combination with ImageNet-C and CCC! "
                                         f"These datasets are already resized and center cropped to 224")
                    break

        # configure model and optimizer
        self.configure_model()
        self.params, self.param_names = self.collect_params()
        self.optimizer = self.setup_optimizer() if len(self.params) > 0 or len(self.param_names) > 0 else None
        self.num_trainable_params, self.num_total_params = self.get_number_trainable_params()

        # variables needed for single sample test-time adaptation (sstta) using a sliding window (buffer) approach
        self.input_buffer = None
        self.window_length = cfg.TEST.WINDOW_LENGTH
        # index of the buffer slot that the next single sample is written to
        self.pointer = torch.tensor([0], dtype=torch.long).to(self.device)
        # sstta: if the model has no batchnorm layers, we do not need to forward the whole buffer when not performing any updates
        self.has_bn = any(isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)) for m in model.modules())

        # note: if the self.model is never reset, like for continual adaptation,
        # then skipping the state copy would save memory
        self.models = [self.model]
        self.model_states, self.optimizer_state = self.copy_model_and_optimizer()

        # setup for mixed-precision or single precision
        self.mixed_precision = cfg.MIXED_PRECISION
        self.scaler = torch.cuda.amp.GradScaler() if cfg.MIXED_PRECISION else None

    def decorate_model(self, model):
        """Hook to wrap or modify the model before it is stored; returns it unchanged here."""
        return model

    def forward(self, x):
        """Adapt the model on the input and return the predictions.

        Inputs whose first tensor has a leading dimension of 1 are treated as
        single-sample adaptation: they are collected in a sliding-window buffer
        and the model is only updated when the pointer reaches the last slot of
        the window. All other inputs are adapted on directly.

        :param x: A tensor or a list of tensors
        :return: The outputs of the last forward pass; in the single-sample
            buffered case only the entry at the current pointer position
        """
        if self.episodic:
            self.reset()

        x = x if isinstance(x, list) else [x]

        if x[0].shape[0] == 1:  # single sample test-time adaptation
            # create the sliding window input
            if self.input_buffer is None:
                self.input_buffer = [x_item for x_item in x]
                # set bn1d layers into eval mode, since no statistics can be extracted from 1 sample
                self.change_mode_of_batchnorm1d(self.models, to_train_mode=False)
            elif self.input_buffer[0].shape[0] < self.window_length:
                self.input_buffer = [torch.cat([self.input_buffer[i], x_item], dim=0) for i, x_item in enumerate(x)]
                # set bn1d layers into train mode
                self.change_mode_of_batchnorm1d(self.models, to_train_mode=True)
            else:
                # buffer is full: overwrite the slot the pointer refers to
                for i, x_item in enumerate(x):
                    self.input_buffer[i][self.pointer] = x_item

            if self.pointer == (self.window_length - 1):
                # update the model, since the complete buffer has changed
                for _ in range(self.steps):
                    outputs = self.forward_and_adapt(self.input_buffer)

                    # if specified, reset the model after a certain amount of update steps
                    self.performed_updates += 1
                    if self.reset_after_num_updates > 0 and self.performed_updates % self.reset_after_num_updates == 0:
                        self.reset()

                outputs = outputs[self.pointer.long()]
            else:
                # create the prediction without updating the model
                if self.has_bn:
                    # forward the whole buffer to get good batchnorm statistics
                    outputs = self.forward_sliding_window(self.input_buffer)
                    outputs = outputs[self.pointer.long()]
                else:
                    # only forward the current test sample, since there are no batchnorm layers
                    outputs = self.forward_sliding_window(x)

            # increase the pointer
            self.pointer += 1
            self.pointer %= self.window_length

        else:  # common batch adaptation setting
            for _ in range(self.steps):
                outputs = self.forward_and_adapt(x)

                # if specified, reset the model after a certain amount of update steps
                self.performed_updates += 1
                if self.reset_after_num_updates > 0 and self.performed_updates % self.reset_after_num_updates == 0:
                    logger.info(f"Reset the model after {self.reset_after_num_updates} updates")
                    self.reset()

        return outputs

    def loss_calculation(self, x):
        """
        Loss calculation.
        """
        raise NotImplementedError

    def forward_and_adapt(self, x):
        """
        Forward and adapt the model on a batch of data.
        """
        raise NotImplementedError

    @torch.no_grad()
    def forward_sliding_window(self, x):
        """
        Create the prediction for single sample test-time adaptation with a sliding window
        :param x: The buffered data created with a sliding window
        :return: Model predictions
        """
        imgs_test = x[0]
        return self.model(imgs_test)

    def configure_model(self):
        """Prepare the model for adaptation; has to be implemented by subclasses."""
        raise NotImplementedError

    def collect_params(self):
        """Collect all trainable parameters.
        Walk the model's modules and collect all parameters.
        Return the parameters and their names.
        Note: other choices of parameterization are possible!
        """
        params = []
        names = []
        for module_name, module in self.model.named_modules():
            for param_name, param in module.named_parameters():
                # only names without a module prefix match, i.e. parameters
                # owned directly by `module`
                if param_name in ['weight', 'bias', 'prompts'] and param.requires_grad:
                    params.append(param)
                    names.append(f"{module_name}.{param_name}")
        return params, names

    def setup_optimizer(self):
        """Create the optimizer selected by ``cfg.OPTIM.METHOD`` for ``self.params``.

        :return: A ``torch.optim`` optimizer ('Adam', 'AdamW' or 'SGD')
        :raises NotImplementedError: For any other value of ``cfg.OPTIM.METHOD``
        """
        if self.cfg.OPTIM.METHOD == 'Adam':
            return torch.optim.Adam(self.params,
                                    lr=self.cfg.OPTIM.LR,
                                    betas=(self.cfg.OPTIM.BETA, 0.999),
                                    weight_decay=self.cfg.OPTIM.WD)
        elif self.cfg.OPTIM.METHOD == 'AdamW':
            return torch.optim.AdamW(self.params,
                                     lr=self.cfg.OPTIM.LR,
                                     betas=(self.cfg.OPTIM.BETA, 0.999),
                                     weight_decay=self.cfg.OPTIM.WD)
        elif self.cfg.OPTIM.METHOD == 'SGD':
            return torch.optim.SGD(self.params,
                                   lr=self.cfg.OPTIM.LR,
                                   momentum=self.cfg.OPTIM.MOMENTUM,
                                   dampening=self.cfg.OPTIM.DAMPENING,
                                   weight_decay=self.cfg.OPTIM.WD,
                                   nesterov=self.cfg.OPTIM.NESTEROV)
        else:
            raise NotImplementedError

    def get_number_trainable_params(self):
        """Count the collected (trainable) and the total model parameters and log them.

        ``self.params`` may be a flat collection of parameters or a dict that
        maps to collections of parameters.

        :return: Tuple ``(trainable, total)`` of parameter counts
        """
        if isinstance(self.params, dict):
            trainable = sum(p.numel() for group in self.params.values() for p in group)
        else:
            # lists and any other iterable of parameters
            trainable = sum(p.numel() for p in self.params)
        total = sum(p.numel() for p in self.model.parameters())
        # a model without parameters must not cause a division by zero
        ratio = trainable / total * 100 if total > 0 else 0.0
        logger.info(f"#Trainable/total parameters: {trainable:,}/{total:,} \t Ratio: {ratio:.3f}% ")
        return trainable, total

    def reset(self):
        """Reset the model and optimizer state to the initial source state

        :raises Exception: If no model state was saved, or if an optimizer
            exists but no optimizer state was saved
        """
        missing_optimizer_state = self.optimizer is not None and self.optimizer_state is None
        if self.model_states is None or missing_optimizer_state:
            raise Exception("cannot reset without saved model/optimizer state")
        self.load_model_and_optimizer()

    def copy_model_and_optimizer(self):
        """Copy the model and optimizer states for resetting after adaptation.

        :return: Tuple of the list of copied model state dicts and the copied
            optimizer state dict (``None`` when there is no optimizer)
        """
        model_states = [deepcopy(model.state_dict()) for model in self.models]
        # `self.optimizer` is None when no parameters were collected
        optimizer_state = deepcopy(self.optimizer.state_dict()) if self.optimizer is not None else None
        return model_states, optimizer_state

    def load_model_and_optimizer(self):
        """Restore the model and optimizer states from copies."""
        for model, model_state in zip(self.models, self.model_states):
            model.load_state_dict(model_state, strict=True)
        if self.optimizer is not None:
            self.optimizer.load_state_dict(self.optimizer_state)

    def save_model(self, save_path, r, errs_5, errs, current_domain_step, current_global_step, accuracy_buffer):
        """Does nothing in the base class."""
        pass

    def load_model(self, ckpt):
        """Not implemented in the base class."""
        raise NotImplementedError

    def average_grad_norm(self):
        """Not implemented in the base class."""
        raise NotImplementedError

    @staticmethod
    def copy_model(model):
        """Return a deep copy of the model with detached parameters.

        A ``nn.DataParallel`` wrapper is removed before copying and the copy is
        wrapped in a new ``nn.DataParallel``; the passed model is left untouched.

        :param model: The model to copy
        :return: The copied model
        """
        is_parallel = isinstance(model, nn.DataParallel)
        if is_parallel:
            model = model.module
        coppied_model = deepcopy(model)
        for param in coppied_model.parameters():
            param.detach_()
        if is_parallel:
            coppied_model = nn.DataParallel(coppied_model)
        return coppied_model

    @staticmethod
    def change_mode_of_batchnorm1d(model_list, to_train_mode=True):
        """Switch all ``nn.BatchNorm1d`` layers of the given models to train or eval mode.

        :param model_list: Iterable of models
        :param to_train_mode: Use train mode if True, eval mode otherwise
        """
        # batchnorm1d layers do not work with single sample inputs
        for model in model_list:
            for m in model.modules():
                if isinstance(m, nn.BatchNorm1d):
                    if to_train_mode:
                        m.train()
                    else:
                        m.eval()
def forward_decorator(fn):
    """Decorate a method so that it runs under CUDA autocast when requested.

    The wrapped method is executed inside ``torch.cuda.amp.autocast()`` if the
    instance's ``mixed_precision`` attribute is truthy and is called directly
    otherwise.

    :param fn: The method to wrap; its first argument is the instance
    :return: The wrapping function, carrying the metadata of ``fn``
    """
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        # single precision: plain call
        if not self.mixed_precision:
            return fn(self, *args, **kwargs)
        # mixed precision: run the call in an autocast context
        with torch.cuda.amp.autocast():
            return fn(self, *args, **kwargs)

    return wrapper