| from collections import OrderedDict | |
| import torch.nn as nn | |
def layer_removal(
    model: nn.Module,
    layers_to_remove: OrderedDict
) -> None:
    """Remove sub-modules (or entries of a container sub-module) from *model* in place.

    Args:
        model: Root module whose attributes are walked and modified.
        layers_to_remove: Mapping of dotted attribute path -> index. Each key
            (e.g. ``"features.block"``) is resolved attribute by attribute
            starting at ``model``. When the value is ``None`` the attribute
            named by the last path component is deleted from its parent.
            Otherwise the value is used as a subscript and that item is
            deleted from the object the path resolves to.

    Raises:
        AttributeError: If any component of a dotted path does not exist.
        Exception: Whatever the target's ``__delitem__`` raises for an
            invalid index (propagated unchanged).
    """
    for layer_name, layer_idx in layers_to_remove.items():
        path = layer_name.split(".")
        leaf = path[-1]

        # Walk down to the parent that owns the final attribute.
        parent = model
        for attr in path[:-1]:
            parent = getattr(parent, attr)

        if layer_idx is None:
            # `del getattr(obj, name)` is not valid Python (a call result
            # cannot be a deletion target); delattr() removes the attribute
            # from its parent.
            delattr(parent, leaf)
        else:
            # Delete a single entry from the resolved object. Entries are
            # processed in mapping order, so an index refers to the
            # container's state at the time this entry is reached.
            del getattr(parent, leaf)[layer_idx]