import torch
import torch.nn as nn

from .attn import FlexAttention
from .modules import MLP
from .modules import Attention

# Optionally pull in fused Triton kernels from liger_kernel; the flags below
# record whether the fast paths are available so callers can fall back cleanly.
try:
    from liger_kernel.transformers.swiglu import LigerSwiGLUMLP as TritonMLP

    triton_mlp = True
except ImportError as e:
    print(
        f"Unable to import Triton-based MLP: {e}. Falling back to vanilla SwiGLU MLP instead."
    )
    triton_mlp = False

try:
    from liger_kernel.transformers.rms_norm import LigerRMSNorm as TritonNorm

    triton_norm = True
except ImportError as e:
    print(
        f"Unable to import Triton-based RMSNorm: {e}. Falling back to PyTorch implementation."
    )
    from torch.nn import RMSNorm

    triton_norm = False


class AttentionLayer(nn.Module):
    """Pre-norm transformer block: RMSNorm -> FlexAttention and RMSNorm -> MLP,
    each wrapped in a residual connection."""

    def __init__(self, config, mask_mod, score_mod=None) -> None:
        super().__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = FlexAttention(
            config=config,
            mask_mod=mask_mod,
            score_mod=score_mod,
        )
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor | None = None) -> torch.Tensor:
        # Residual attention sub-block, then residual MLP sub-block.
        x = x + self.attn(self.attn_norm(x), freqs_cis=freqs_cis)
        x = x + self.mlp(self.mlp_norm(x))
        return x
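

# --- Illustrative sketch: the mask_mod / score_mod contract ------------------
# This smoke test is not part of the module's API. It is a minimal,
# self-contained sketch of the callables AttentionLayer expects, assuming the
# local FlexAttention wrapper forwards them to
# torch.nn.attention.flex_attention (PyTorch 2.5+). The tensor shapes, the
# soft-capping score_mod, and the `device=` handling are illustrative
# assumptions, not this repo's configuration.
if __name__ == "__main__":
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    def causal_mask(b, h, q_idx, kv_idx):
        # mask_mod: return True where a query position may attend to a key position.
        return q_idx >= kv_idx

    def soft_cap(score, b, h, q_idx, kv_idx):
        # score_mod: elementwise rewrite of the attention logits (tanh soft-capping here).
        return 20.0 * torch.tanh(score / 20.0)

    B, H, S, D = 2, 8, 128, 32
    q = torch.randn(B, H, S, D)
    k = torch.randn(B, H, S, D)
    v = torch.randn(B, H, S, D)

    # Build the sparse block mask from mask_mod, then run flex_attention with
    # the score_mod applied to every unmasked logit.
    block_mask = create_block_mask(causal_mask, B=B, H=H, Q_LEN=S, KV_LEN=S, device=q.device)
    out = flex_attention(q, k, v, score_mod=soft_cap, block_mask=block_mask)
    print(out.shape)  # torch.Size([2, 8, 128, 32])

    # The same causal_mask would be handed to this module's layer as
    # `AttentionLayer(config, mask_mod=causal_mask)`, with config supplying
    # `dim` and whatever fields FlexAttention/MLP expect.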