File size: 882 Bytes
89280a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from abc import ABC, abstractmethod
import torch.nn as nn

class LayerDecomposer(ABC):
    """
    Abstract base class for decomposing Transformer layers into
    Attention (Part 1) and MLP (Part 2) components.

    Concrete subclasses implement three hooks for one layer architecture:

    * ``get_mid_activation_module`` -- locate the module whose input is
      ``resid_mid`` (the residual stream between the two parts).
    * ``forward_part1`` -- run the attention half and return ``resid_mid``.
    * ``forward_part2`` -- run the MLP half and return ``resid_post``.

    NOTE(review): running ``forward_part1`` then ``forward_part2`` is
    presumably meant to reproduce the full layer forward pass -- confirm
    against each concrete subclass.

    The abstract stubs raise ``NotImplementedError`` instead of returning
    ``None``. ``abc`` allows subclasses to call them via ``super()``, and a
    silent ``None`` would otherwise propagate as an activation.
    """

    @abstractmethod
    def get_mid_activation_module(self, layer_module):
        """
        Returns the module whose input corresponds to 'resid_mid'.
        This is typically the Post-Attention LayerNorm.

        Args:
            layer_module: The Transformer layer to inspect.

        Returns:
            The module whose input is ``resid_mid`` for ``layer_module``.

        Raises:
            NotImplementedError: If the base implementation is reached
                (e.g. through ``super()``).
        """
        raise NotImplementedError

    @abstractmethod
    def forward_part1(self, layer_module, hidden_states, position_embeddings=None, attention_mask=None):
        """
        Executes: Norm -> Attn -> Residual Add
        Returns: resid_mid

        Args:
            layer_module: The Transformer layer whose attention half is run.
            hidden_states: The layer's input hidden states.
            position_embeddings: Optional positional information for the
                attention call; ``None`` when not needed. Format is
                architecture-specific (NOTE(review): confirm per subclass).
            attention_mask: Optional attention mask; ``None`` when not
                needed. Format is architecture-specific.

        Returns:
            ``resid_mid``: the hidden states after the attention residual add.

        Raises:
            NotImplementedError: If the base implementation is reached
                (e.g. through ``super()``).
        """
        raise NotImplementedError

    @abstractmethod
    def forward_part2(self, layer_module, hidden_states):
        """
        Executes: Norm -> MLP -> Residual Add
        Returns: resid_post

        Args:
            layer_module: The Transformer layer whose MLP half is run.
            hidden_states: The residual stream entering the MLP half,
                typically the ``resid_mid`` returned by ``forward_part1``.

        Returns:
            ``resid_post``: the hidden states after the MLP residual add.

        Raises:
            NotImplementedError: If the base implementation is reached
                (e.g. through ``super()``).
        """
        raise NotImplementedError