from typing import Optional

import torch
import torch.nn as nn

from .mlp import LlamaMLP
from .config import LlamaConfig
from .rms_norm import LlamaRMSNorm
# Attention variants available in this repo; this layer currently uses the
# tensor-product attention implementation.
from .attention import LlamaAttention
from .diff_attn import DifferentialAttention
from .tensor_prod_attn import CausalTensorProductSelfAttn


class LlamaDecoderLayer(nn.Module):
    """A single decoder block: pre-norm self-attention followed by a pre-norm MLP,
    each wrapped in a residual connection."""

    def __init__(self, config: LlamaConfig, layer_num: int):
        super().__init__()
        self.layer_num = layer_num
        # Self-attention sub-layer (causal tensor-product attention).
        self.self_attn = CausalTensorProductSelfAttn(config)
        # Position-wise feed-forward sub-layer.
        self.mlp = LlamaMLP(config)
        # RMSNorm applied before each sub-layer (pre-norm architecture).
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Attention block: norm -> self-attention -> residual add.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = residual + hidden_states

        # Feed-forward block: norm -> MLP -> residual add.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
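

# Minimal usage sketch (illustrative only): assumes LlamaConfig can be constructed
# with at least `hidden_size` and `rms_norm_eps`, and that CausalTensorProductSelfAttn
# returns a tensor with the same shape as its input.
#
#   config = LlamaConfig(hidden_size=512, rms_norm_eps=1e-6)
#   layer = LlamaDecoderLayer(config, layer_num=0)
#   x = torch.randn(2, 128, config.hidden_size)   # (batch, seq_len, hidden_size)
#   out = layer(x)                                # same shape as x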