# Ttius's picture
# Upload 192 files
# 998bb30 verified
from contextlib import contextmanager
from typing import List, Union
import torch
# Backbone name -> dotted submodule path of the intermediate layer to hook.
# NOTE(review): presumably these paths are passed as the `layer` argument of
# the hook-registration helpers below (which resolve it via get_submodule);
# confirm against callers.
midlayer_dict = dict(
    vgg16='features.16',
    vgg19='features.18',
    resnet152='layer1',
    densenet169='features.denseblock2',
    inception_v3='Mixed_6c',
    # Other candidate layers: conv1, layer1, layer2, layer3, layer4
    resnet50='conv1',
    resnet50_cl='layer4',
    # Other candidate layers: conv_1_3x3, stage_1, stage_2, stage_3
    resnet32='conv_1_3x3',
    resnet32_cl='conv_1_3x3',
)
def register_collecter(m: torch.nn.Module, layer: str, feat_collecter: List):
    """Attach a forward hook to ``m``'s submodule ``layer`` that records outputs.

    Every forward pass through the submodule appends its output object to
    ``feat_collecter``. The hook returns None, so the output is not modified.

    Args:
        m: Root module that contains the target layer.
        layer: Dotted submodule path, resolved with ``m.get_submodule``.
        feat_collecter: List that receives each captured output (mutated in place).

    Returns:
        ``(handle, feat_collecter)``: the hook handle (call ``.remove()`` to
        detach it) and the same list object that was passed in.
    """
    target = m.get_submodule(layer)

    def _append_output(_module, _inputs, output):
        feat_collecter.append(output)

    return target.register_forward_hook(_append_output), feat_collecter
def register_collecter_cl(m: torch.nn.Module, layer: str, feat_collecter: List, cl_methods: str):
    """Attach an output-collecting forward hook to a layer of a CL model's backbone.

    The backbone attribute used depends on ``cl_methods``:

    * ``'icarl'``, ``'finetune'``, ``'wa'``, ``'replay'``, ``'podnet'``, ``'bic'``
      -> ``m.convnet``
    * ``'foster'``, ``'der'`` -> ``m.convnets[0]``
    * ``'memo'`` -> ``m.TaskAgnosticExtractor``

    Args:
        m: Model that exposes the backbone attribute listed above.
        layer: Dotted submodule path inside that backbone, resolved with
            ``get_submodule``.
        feat_collecter: List that receives each forward output of the layer.
        cl_methods: Name of the continual-learning method (see mapping above).

    Returns:
        ``(handle, feat_collecter)``: the hook handle (call ``.remove()`` to
        detach it) and the same list object that was passed in.

    Raises:
        ValueError: if ``cl_methods`` is not one of the supported names.
            Previously this surfaced as an opaque ``UnboundLocalError``.
    """
    def _hook(m, i, o):
        feat_collecter.append(o)

    if cl_methods in ('icarl', 'finetune', 'wa', 'replay', 'podnet', 'bic'):
        backbone = m.convnet  # single-backbone methods
    elif cl_methods in ('foster', 'der'):
        backbone = m.convnets[0]  # hook the first backbone in the list
    elif cl_methods == 'memo':
        backbone = m.TaskAgnosticExtractor
    else:
        raise ValueError(
            f"Unsupported cl_methods {cl_methods!r}; expected one of "
            "'icarl', 'finetune', 'wa', 'replay', 'podnet', 'bic', "
            "'foster', 'der', 'memo'"
        )
    _handler = backbone.get_submodule(layer).register_forward_hook(_hook)
    return _handler, feat_collecter
@contextmanager
def feat_col(m: Union[torch.nn.Module, List[torch.nn.Module]],
             layer: Union[str, List[str]]):
    """Collect forward outputs of the given layers for the duration of a ``with`` block.

    One forward hook is registered per ``(module, layer)`` pair via
    ``register_collecter``. All hooks append into a single shared list, which
    is yielded to the caller.

    On exit, including when the ``with`` body raises, every hook is removed
    and the list is cleared. Copy any features you need before leaving the
    block.

    Args:
        m: A module or a list of modules.
        layer: A layer path or a list of layer paths, paired positionally with ``m``.

    Yields:
        The shared list that receives captured outputs.

    Raises:
        ValueError: if ``m`` and ``layer`` (after wrapping single values in
            lists) have different lengths.
    """
    if isinstance(m, torch.nn.Module):
        m = [m]
    if isinstance(layer, str):
        layer = [layer]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(m) != len(layer):
        raise ValueError(
            f"feat_col: got {len(m)} module(s) but {len(layer)} layer name(s)"
        )
    handlers = []
    feat_collecter = []
    try:
        # Registration is inside `try` so hooks already attached are removed
        # if a later layer lookup fails.
        for _m, _layer in zip(m, layer):
            handler, feat_collecter = register_collecter(_m, _layer,
                                                         feat_collecter)
            handlers.append(handler)
        yield feat_collecter
    finally:
        # Always detach hooks; otherwise they would keep appending on every
        # future forward pass after an exception in the `with` body.
        for handler in handlers:
            handler.remove()
        feat_collecter.clear()