|
|
from contextlib import contextmanager
|
|
|
from typing import List, Union
|
|
|
|
|
|
import torch
|
|
|
|
|
|
# Model name -> dotted submodule path (a string accepted by
# ``torch.nn.Module.get_submodule``). Presumably these values are passed as
# the ``layer`` argument of the feature-collecting hooks below; the ``*_cl``
# entries look intended for ``register_collecter_cl`` -- TODO confirm against
# the callers.
midlayer_dict = dict(
    vgg16='features.16',
    vgg19='features.18',
    resnet152='layer1',
    densenet169='features.denseblock2',
    inception_v3='Mixed_6c',
    resnet50='conv1',
    resnet50_cl='layer4',
    resnet32='conv_1_3x3',
    resnet32_cl='conv_1_3x3',
)
|
|
|
|
|
|
|
|
|
def register_collecter(m: torch.nn.Module, layer: str, feat_collecter: List):
    """Attach a forward hook that records a submodule's output.

    Args:
        m: Model that owns the submodule.
        layer: Dotted path of the submodule, resolved with ``m.get_submodule``.
        feat_collecter: List that receives the submodule's output on every
            forward pass through it.

    Returns:
        A ``(handle, feat_collecter)`` tuple. Call ``handle.remove()`` to
        detach the hook. The returned list is the same object passed in.
    """
    target = m.get_submodule(layer)

    def _record_output(_module, _inputs, output):
        feat_collecter.append(output)

    return target.register_forward_hook(_record_output), feat_collecter
|
|
|
|
|
|
def register_collecter_cl(m: torch.nn.Module, layer: str, feat_collecter: List, cl_methods: str):
    """Attach an output-recording forward hook inside a method-specific backbone.

    Works like ``register_collecter``, except that ``layer`` is resolved
    relative to the feature-extractor attribute that ``cl_methods`` selects
    on ``m`` (presumably the name of a continual-learning method):

    * 'icarl', 'finetune', 'wa', 'replay', 'podnet', 'bic' -> ``m.convnet``
    * 'foster', 'der' -> ``m.convnets[0]`` (only the first backbone is hooked)
    * 'memo' -> ``m.TaskAgnosticExtractor``

    Args:
        m: Model wrapper exposing the backbone attribute for ``cl_methods``.
        layer: Dotted submodule path inside the selected backbone.
        feat_collecter: List that receives the submodule's output on every
            forward pass through it.
        cl_methods: Method name used to pick the backbone.

    Returns:
        A ``(handle, feat_collecter)`` tuple. Call ``handle.remove()`` to
        detach the hook. The returned list is the same object passed in.

    Raises:
        ValueError: if ``cl_methods`` is not one of the names above.
    """
    def _hook(module, inputs, output):
        feat_collecter.append(output)

    if cl_methods in ('icarl', 'finetune', 'wa', 'replay', 'podnet', 'bic'):
        backbone = m.convnet
    elif cl_methods in ('foster', 'der'):
        backbone = m.convnets[0]
    elif cl_methods == 'memo':
        backbone = m.TaskAgnosticExtractor
    else:
        # Previously fell through and failed with UnboundLocalError on the return.
        raise ValueError(f"Unsupported cl_methods: {cl_methods!r}")

    _handler = backbone.get_submodule(layer).register_forward_hook(_hook)
    return _handler, feat_collecter
|
|
|
|
|
|
|
|
|
@contextmanager
def feat_col(m: Union[torch.nn.Module, List[torch.nn.Module]],
             layer: Union[str, List[str]]):
    """Temporarily collect intermediate outputs from one or more models.

    Registers one forward hook per ``(model, layer)`` pair via
    ``register_collecter``. All hooks share a single list.

    Args:
        m: A model, or a list of models.
        layer: A dotted submodule path, or a list of paths paired
            positionally with ``m``.

    Yields:
        The shared list. Every forward pass through a hooked layer appends
        that layer's output to it, in call order.

    Raises:
        AssertionError: if the number of models and layer paths differ.

    On exit, whether normal or via an exception, every hook is removed and
    the yielded list is cleared. Copy anything you need before leaving the
    ``with`` block.
    """
    if isinstance(m, torch.nn.Module):
        m = [m]
    if isinstance(layer, str):
        layer = [layer]
    assert len(m) == len(layer), (
        f"got {len(m)} model(s) but {len(layer)} layer path(s)")

    handlers = []
    feat_collecter = []
    try:
        for _m, _layer in zip(m, layer):
            handler, feat_collecter = register_collecter(_m, _layer,
                                                         feat_collecter)
            handlers.append(handler)
        yield feat_collecter
    finally:
        # Runs even if registration or the with-body raises, so no hook is
        # left attached to the caller's model(s).
        for handler in handlers:
            handler.remove()
        feat_collecter.clear()
|
|
|
|