from contextlib import contextmanager
from typing import Iterator, List, Tuple, Union

import torch

# Maps a model identifier to the dotted submodule name (as accepted by
# ``torch.nn.Module.get_submodule``) whose forward output is to be collected.
# NOTE(review): presumably these are the default intermediate layers used for
# feature extraction per architecture; the trailing comments appear to list
# other layer names that are valid for that model -- confirm.
midlayer_dict = {
    'vgg16': 'features.16',
    'vgg19': 'features.18',
    'resnet152': 'layer1',
    'densenet169': 'features.denseblock2',
    'inception_v3': 'Mixed_6c',
    'resnet50': 'conv1',  # conv1, layer1, layer2, layer3, layer4
    'resnet50_cl': 'layer4',
    'resnet32': 'conv_1_3x3',  # conv_1_3x3, stage_1, stage_2, stage_3
    'resnet32_cl': 'conv_1_3x3',
}

# Continual-learning methods whose network exposes its backbone as ``m.convnet``.
_CL_SINGLE_BACKBONE = frozenset({'icarl', 'finetune', 'wa', 'replay', 'podnet', 'bic'})
# Methods whose network keeps an indexable collection of backbones in
# ``m.convnets``; only the first one (``m.convnets[0]``) is hooked.
_CL_MULTI_BACKBONE = frozenset({'foster', 'der'})
# Methods whose network exposes its extractor as ``m.TaskAgnosticExtractor``.
_CL_TASK_AGNOSTIC = frozenset({'memo'})


def _make_collect_hook(feat_collecter: List):
    """Return a forward hook that appends every module output to *feat_collecter*.

    Outputs are appended as-is (not detached, copied, or moved).
    """
    def _hook(_module, _inputs, output):
        feat_collecter.append(output)

    return _hook


def register_collecter(
    m: torch.nn.Module, layer: str, feat_collecter: List
) -> Tuple["torch.utils.hooks.RemovableHandle", List]:
    """Attach a collecting forward hook to submodule *layer* of *m*.

    Args:
        m: Model that owns the submodule.
        layer: Dotted submodule path understood by ``m.get_submodule``.
        feat_collecter: List that each forward output of the submodule is
            appended to.

    Returns:
        ``(handle, feat_collecter)``; call ``handle.remove()`` to detach the
        hook. The list returned is the same object that was passed in.

    Raises:
        AttributeError: if *layer* does not name a submodule of *m*
            (propagated from ``get_submodule``).
    """
    hook = _make_collect_hook(feat_collecter)
    handler = m.get_submodule(layer).register_forward_hook(hook)
    return handler, feat_collecter


def _cl_backbone(m: torch.nn.Module, cl_methods: str) -> torch.nn.Module:
    """Return the sub-network of *m* that holds *layer* for the given CL method."""
    if cl_methods in _CL_SINGLE_BACKBONE:
        return m.convnet
    if cl_methods in _CL_MULTI_BACKBONE:
        return m.convnets[0]
    if cl_methods in _CL_TASK_AGNOSTIC:
        return m.TaskAgnosticExtractor
    supported = sorted(_CL_SINGLE_BACKBONE | _CL_MULTI_BACKBONE | _CL_TASK_AGNOSTIC)
    # Previously an unknown method fell through every branch and crashed with
    # an UnboundLocalError; fail loudly and descriptively instead.
    raise ValueError(
        f"Unsupported cl_methods {cl_methods!r}; expected one of {supported}"
    )


def register_collecter_cl(
    m: torch.nn.Module, layer: str, feat_collecter: List, cl_methods: str
) -> Tuple["torch.utils.hooks.RemovableHandle", List]:
    """Like ``register_collecter`` but for networks wrapped by a CL method.

    The hook is placed on ``<backbone>.get_submodule(layer)``. ``<backbone>``
    is chosen by *cl_methods*:

    - ``m.convnet`` for icarl / finetune / wa / replay / podnet / bic;
    - ``m.convnets[0]`` for foster / der;
    - ``m.TaskAgnosticExtractor`` for memo.

    Returns:
        ``(handle, feat_collecter)``, as ``register_collecter``.

    Raises:
        ValueError: if *cl_methods* is not one of the methods listed above.
    """
    hook = _make_collect_hook(feat_collecter)
    backbone = _cl_backbone(m, cl_methods)
    handler = backbone.get_submodule(layer).register_forward_hook(hook)
    return handler, feat_collecter


@contextmanager
def feat_col(
    m: Union[torch.nn.Module, List[torch.nn.Module]],
    layer: Union[str, List[str]],
) -> Iterator[List]:
    """Collect forward outputs of selected layers while the context is active.

    *m* and *layer* are paired element-wise. A single module and a single
    layer name are each wrapped in a one-element list. Outputs from all hooked
    layers are appended, in forward-call order, to one shared list. That list
    is yielded to the caller.

    On exit, whether the body completes normally or raises, every hook
    registered here is removed. The yielded list is then cleared.

    Raises:
        ValueError: if the numbers of modules and layer names differ.
        AttributeError: if a layer name is not a submodule of its module.
    """
    modules = [m] if isinstance(m, torch.nn.Module) else list(m)
    layers = [layer] if isinstance(layer, str) else list(layer)
    if len(modules) != len(layers):
        raise ValueError(
            f"feat_col needs one layer per module, got {len(modules)} "
            f"module(s) and {len(layers)} layer name(s)"
        )

    handlers = []
    feat_collecter = []
    try:
        # Registration is inside the try so hooks already attached are
        # removed if a later get_submodule() fails.
        for _m, _layer in zip(modules, layers):
            handler, feat_collecter = register_collecter(_m, _layer, feat_collecter)
            handlers.append(handler)
        yield feat_collecter
    finally:
        # Always detach hooks, even when the with-body raised; otherwise they
        # would keep firing (and accumulating outputs) on later forward passes.
        for handler in handlers:
            handler.remove()
        feat_collecter.clear()