|
|
""" |
|
|
Decoder Block for SLM. |
|
|
Pre-norm architecture with residual connections. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Optional, Tuple |
|
|
|
|
|
from .config import SLMConfig |
|
|
from .normalization import RMSNorm |
|
|
from .attention import MultiHeadAttention |
|
|
from .ffn import FeedForward |
|
|
from .kv_cache import KVCache |
|
|
|
|
|
|
|
|
class DecoderBlock(nn.Module):
    """One pre-norm transformer decoder layer.

    Each sub-layer (self-attention, feed-forward) is wrapped as
    ``x + sublayer(norm(x))``:

    ```
    x ──► RMSNorm ──► Attention ──► (+x) ──► RMSNorm ──► FFN ──► (+) ──► out
    ```

    Normalizing *before* each sub-layer (pre-norm) rather than after:
    - keeps gradients stable in FP16 training,
    - behaves better under quantization,
    - exports to ONNX without layer-crossing dependencies.
    """

    def __init__(self, config: SLMConfig, layer_idx: int):
        """Build the layer's sub-modules.

        Args:
            config: Model configuration
            layer_idx: Index of this layer
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Norm applied before self-attention (pre-norm).
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = MultiHeadAttention(config, layer_idx)
        # Norm applied before the feed-forward network.
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = FeedForward(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Run one decoder layer.

        Args:
            hidden_states: Input tensor [batch, seq, hidden_size]
            position_ids: Position indices [batch, seq]
            attention_mask: Causal attention mask
            kv_cache: Optional KV cache
            use_cache: Whether to use/update cache

        Returns:
            Tuple of (output, kv_cache)
        """
        # --- Self-attention sub-layer: x + Attn(Norm(x)) ---
        attn_output, kv_cache = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_ids=position_ids,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output

        # --- Feed-forward sub-layer: x + FFN(Norm(x)) ---
        ffn_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + ffn_output

        return hidden_states, kv_cache
|
|
|