from copy import deepcopy

from methods.base import TTAMethod, forward_decorator
from utils.registry import ADAPTATION_REGISTRY


@ADAPTATION_REGISTRY.register()
class Source(TTAMethod):
    """Method that only runs the wrapped model on the test images.

    ``forward_and_adapt`` performs a plain forward pass and
    ``configure_model`` puts the model in eval mode with gradients disabled,
    so nothing in this class updates the model's parameters.
    """

    def __init__(self, cfg, model, num_classes):
        """Hand all arguments unchanged to the ``TTAMethod`` constructor."""
        super().__init__(cfg, model, num_classes)

    @forward_decorator
    def forward_and_adapt(self, x):
        """Return the model output for the first element of ``x``.

        Only ``x[0]`` is read; any further elements of ``x`` are ignored.
        """
        return self.model(x[0])

    def copy_model_and_optimizer(self):
        """Snapshot every model's state dict so ``reset`` can restore it.

        Returns:
            A ``(model_states, optimizer_state)`` tuple. ``model_states`` is
            a list with one deep-copied state dict per entry of
            ``self.models``; ``optimizer_state`` is always ``None`` here.
        """
        snapshots = []
        for net in self.models:
            # Deep copy so the snapshot does not share storage with the
            # live model.
            snapshots.append(deepcopy(net.state_dict()))
        return snapshots, None

    def reset(self):
        """Load the saved state dicts back into the models, pairwise."""
        for net, saved_state in zip(self.models, self.model_states):
            net.load_state_dict(saved_state, strict=True)

    def configure_model(self):
        """Switch the model to eval mode and disable gradient tracking."""
        self.model.eval()
        self.model.requires_grad_(False)