File size: 937 Bytes
02ba886 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
from copy import deepcopy
from methods.base import TTAMethod, forward_decorator
from utils.registry import ADAPTATION_REGISTRY
@ADAPTATION_REGISTRY.register()
class Source(TTAMethod):
    """Adaptation method that only runs the model forward without updating it.

    ``configure_model`` puts the model in eval mode and disables gradients
    on it, and no optimizer state is tracked by this class.
    """

    def __init__(self, cfg, model, num_classes):
        # Nothing extra to set up here; construction is left to TTAMethod.
        super().__init__(cfg, model, num_classes)

    @forward_decorator
    def forward_and_adapt(self, x):
        """Return the model output for the first element of ``x``.

        NOTE(review): ``x`` is presumably a sequence whose first entry is
        the batch of test images -- confirm against the caller.
        """
        return self.model(x[0])

    def copy_model_and_optimizer(self):
        """Deep-copy each model's state dict for resetting after adaptation.

        Returns:
            A ``(model_states, optimizer_state)`` pair, where
            ``model_states`` holds one copied state dict per entry of
            ``self.models`` and ``optimizer_state`` is always ``None``.
        """
        snapshots = []
        for net in self.models:
            snapshots.append(deepcopy(net.state_dict()))
        return snapshots, None

    def reset(self):
        """Load the state dicts in ``self.model_states`` into ``self.models``."""
        for net, saved_state in zip(self.models, self.model_states):
            net.load_state_dict(saved_state, strict=True)

    def configure_model(self):
        """Switch the model to eval mode and turn off gradients for it."""
        net = self.model
        net.eval()
        net.requires_grad_(False)
|