Image Classification
English
TTA
ReservoirTTA / methods /source.py
GuillaumeVray
Uploading files
02ba886
from copy import deepcopy
from methods.base import TTAMethod, forward_decorator
from utils.registry import ADAPTATION_REGISTRY
@ADAPTATION_REGISTRY.register()
class Source(TTAMethod):
    """TTA method that runs the wrapped model unchanged.

    ``configure_model`` switches the model to eval mode and disables
    gradients for all of its parameters, and no optimizer state is kept
    (``copy_model_and_optimizer`` returns ``None`` for it).
    """

    def __init__(self, cfg, model, num_classes):
        """Forward all arguments unchanged to the ``TTAMethod`` constructor."""
        super().__init__(cfg, model, num_classes)

    @forward_decorator
    def forward_and_adapt(self, x):
        """Return the model output for the first element of ``x``.

        NOTE(review): ``x`` is indexed with ``x[0]``, so it is presumably a
        list/tuple whose first entry is the test image batch - confirm
        against the caller.
        """
        return self.model(x[0])

    def copy_model_and_optimizer(self):
        """Snapshot every model's state dict so it can be restored later.

        Returns:
            A ``(model_states, optimizer_state)`` pair: a list holding one
            deep-copied state dict per entry of ``self.models``, and ``None``
            in place of an optimizer state.
        """
        snapshots = []
        for net in self.models:
            # Deep copy so the snapshot is a separate object from the
            # state dict returned by the model.
            snapshots.append(deepcopy(net.state_dict()))
        return snapshots, None

    def reset(self):
        """Load each saved state dict back into its corresponding model.

        ``self.models`` and ``self.model_states`` are not assigned in this
        class; presumably the base class sets them up - verify there.
        """
        for net, saved_state in zip(self.models, self.model_states):
            net.load_state_dict(saved_state, strict=True)

    def configure_model(self):
        """Put the model in eval mode and turn off gradients for it."""
        net = self.model
        net.eval()
        net.requires_grad_(False)