""" Decoder Block for SLM. Pre-norm architecture with residual connections. """ import torch import torch.nn as nn from typing import Optional, Tuple from .config import SLMConfig from .normalization import RMSNorm from .attention import MultiHeadAttention from .ffn import FeedForward from .kv_cache import KVCache class DecoderBlock(nn.Module): """Single decoder block with pre-norm architecture. Structure (Pre-Norm): ``` x ├─ RMSNorm ├─ Multi-Head Attention ├─ Residual Add ├─ RMSNorm ├─ Feed-Forward Network └─ Residual Add ``` Why pre-norm: - More stable gradients in FP16 training - Better quantization behavior - Easier ONNX export (no layer-crossing dependencies) """ def __init__(self, config: SLMConfig, layer_idx: int): """Initialize decoder block. Args: config: Model configuration layer_idx: Index of this layer """ super().__init__() self.config = config self.layer_idx = layer_idx # Pre-attention norm self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Self-attention self.self_attn = MultiHeadAttention(config, layer_idx) # Pre-FFN norm self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Feed-forward network self.mlp = FeedForward(config) def forward( self, hidden_states: torch.Tensor, position_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, kv_cache: Optional[KVCache] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[KVCache]]: """Forward pass through decoder block. Args: hidden_states: Input tensor [batch, seq, hidden_size] position_ids: Position indices [batch, seq] attention_mask: Causal attention mask kv_cache: Optional KV cache use_cache: Whether to use/update cache Returns: Tuple of (output, kv_cache) """ # Store residual residual = hidden_states # Pre-norm -> Attention hidden_states = self.input_layernorm(hidden_states) hidden_states, kv_cache = self.self_attn( hidden_states=hidden_states, position_ids=position_ids, attention_mask=attention_mask, kv_cache=kv_cache, use_cache=use_cache, ) # Residual connection hidden_states = residual + hidden_states # Store residual residual = hidden_states # Pre-norm -> FFN hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) # Residual connection hidden_states = residual + hidden_states return hidden_states, kv_cache