| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
|
|
| from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES |
|
|
|
|
def _get_identifiable_transformer_blocks_in_module(module: torch.nn.Module):
    """Find ``ModuleList`` submodules whose name matches a transformer-block identifier.

    Walks ``module.named_modules()`` and keeps every submodule that is a
    ``torch.nn.ModuleList`` and whose qualified name ends with one of the
    entries in ``_ALL_TRANSFORMER_BLOCK_IDENTIFIERS``.

    Args:
        module: The root module to search.

    Returns:
        A list of ``(name, submodule)`` tuples, in ``named_modules()`` order.
    """
    return [
        (qualified_name, child)
        for qualified_name, child in module.named_modules()
        if isinstance(child, torch.nn.ModuleList)
        and any(qualified_name.endswith(suffix) for suffix in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS)
    ]
|
|
|
|
def _get_identifiable_attention_layers_in_module(module: torch.nn.Module):
    """Find submodules of ``module`` that are instances of ``_ATTENTION_CLASSES``.

    Args:
        module: The root module to search.

    Returns:
        A list of ``(name, submodule)`` tuples, in ``named_modules()`` order,
        one for each submodule that passes the ``isinstance`` check.
    """
    return [
        (qualified_name, child)
        for qualified_name, child in module.named_modules()
        if isinstance(child, _ATTENTION_CLASSES)
    ]
|
|
|
|
def _get_identifiable_feedforward_layers_in_module(module: torch.nn.Module):
    """Find submodules of ``module`` that are instances of ``_FEEDFORWARD_CLASSES``.

    Args:
        module: The root module to search.

    Returns:
        A list of ``(name, submodule)`` tuples, in ``named_modules()`` order,
        one for each submodule that passes the ``isinstance`` check.
    """
    return [
        (qualified_name, child)
        for qualified_name, child in module.named_modules()
        if isinstance(child, _FEEDFORWARD_CLASSES)
    ]
|
|