| from abc import ABC, abstractmethod | |
| import torch.nn as nn | |
class LayerDecomposer(ABC):
    """
    Abstract base class for decomposing Transformer layers into
    Attention (Part 1) and MLP (Part 2) components.

    A concrete subclass adapts one specific layer implementation. Part 1
    (Norm -> Attn -> Residual Add) produces the intermediate residual
    stream ``resid_mid``. Part 2 (Norm -> MLP -> Residual Add) maps
    ``resid_mid`` to ``resid_post``.

    All three methods are declared abstract. A subclass that forgets one
    therefore fails at instantiation with ``TypeError``. Without this, the
    missing method would silently return ``None``.
    """

    @abstractmethod
    def get_mid_activation_module(self, layer_module: nn.Module) -> nn.Module:
        """
        Returns the module whose input corresponds to 'resid_mid'.
        This is typically the Post-Attention LayerNorm.

        Args:
            layer_module: The Transformer layer to inspect.

        Returns:
            The submodule of ``layer_module`` whose input is ``resid_mid``.
        """

    @abstractmethod
    def forward_part1(
        self,
        layer_module: nn.Module,
        hidden_states,
        position_embeddings=None,
        attention_mask=None,
    ):
        """
        Executes: Norm -> Attn -> Residual Add
        Returns: resid_mid

        Args:
            layer_module: The Transformer layer whose attention half is run.
            hidden_states: The layer's input residual stream. The exact
                shape/dtype depends on the model; this interface does not
                fix it.
            position_embeddings: Optional positional information for the
                attention call. Its format is model-specific.
            attention_mask: Optional attention mask. Its format is
                model-specific.

        Returns:
            resid_mid: ``hidden_states`` plus the attention output.
        """

    @abstractmethod
    def forward_part2(self, layer_module: nn.Module, hidden_states):
        """
        Executes: Norm -> MLP -> Residual Add
        Returns: resid_post

        Args:
            layer_module: The Transformer layer whose MLP half is run.
            hidden_states: The intermediate residual stream (``resid_mid``).

        Returns:
            resid_post: ``hidden_states`` plus the MLP output.
        """