import torch
import torch.nn as nn

from .attn import FlexAttention
from .modules import MLP
from .modules import Attention

try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP
    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm
    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm
    triton_norm = False


class AttentionLayer(nn.Module):
    """Pre-norm transformer block: RMSNorm -> FlexAttention and RMSNorm -> MLP, each wrapped in a residual connection."""

    def __init__(self, config, mask_mod, score_mod=None) -> None:
        super(AttentionLayer, self).__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = FlexAttention(
            config=config,
            mask_mod=mask_mod,
            score_mod=score_mod,
        )
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor = None) -> torch.Tensor:
        # Attention sub-block: pre-normalize, attend, add residual.
        x = x + self.attn(self.attn_norm(x), freqs_cis=freqs_cis)
        # MLP sub-block: pre-normalize, transform, add residual.
        x = x + self.mlp(self.mlp_norm(x))
        return x
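

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). The `Config` fields
# and the causal mask_mod below are assumptions for demonstration: the real
# project supplies its own config object (at minimum a `dim` attribute, consumed
# by the RMSNorms above, plus whatever FlexAttention and MLP expect) and its own
# mask/score mods.
#
#     from dataclasses import dataclass
#
#     @dataclass
#     class Config:                            # hypothetical stand-in config
#         dim: int = 512
#         n_heads: int = 8
#
#     def causal_mask(b, h, q_idx, kv_idx):    # flex-attention-style mask_mod
#         return q_idx >= kv_idx
#
#     layer = AttentionLayer(Config(), mask_mod=causal_mask)
#     x = torch.randn(2, 128, 512)             # (batch, seq_len, dim)
#     out = layer(x)                           # residual output, same shape as x
# ------------------------------------------------------------------------------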