Debug-XAI / backend /models /base.py
rongyuan
Update 1st version of UI.
89280a9
raw
history blame contribute delete
882 Bytes
from abc import ABC, abstractmethod
import torch.nn as nn
class LayerDecomposer(ABC):
    """
    Interface for splitting one Transformer layer into two stages.

    Part 1 is the Attention half and Part 2 is the MLP half. Subclasses
    supply the model-specific logic for each stage; this base class only
    declares the contract and holds no state of its own.
    """

    @abstractmethod
    def get_mid_activation_module(self, layer_module):
        """
        Locate the sub-module whose *input* is the 'resid_mid' activation.

        In most architectures this is the Post-Attention LayerNorm.

        Args:
            layer_module: The Transformer layer to inspect.

        Returns:
            The module whose input corresponds to 'resid_mid'.
        """

    @abstractmethod
    def forward_part1(self, layer_module, hidden_states, position_embeddings=None, attention_mask=None):
        """
        Run the attention half of the layer: Norm -> Attn -> Residual Add.

        Args:
            layer_module: The Transformer layer being decomposed.
            hidden_states: Input to the layer.
            position_embeddings: Optional; defaults to None.
            attention_mask: Optional; defaults to None.

        Returns:
            resid_mid
        """

    @abstractmethod
    def forward_part2(self, layer_module, hidden_states):
        """
        Run the MLP half of the layer: Norm -> MLP -> Residual Add.

        Args:
            layer_module: The Transformer layer being decomposed.
            hidden_states: Input to this stage.
                NOTE(review): presumably the resid_mid produced by
                forward_part1 -- confirm against callers.

        Returns:
            resid_post
        """