File size: 2,948 Bytes
27871e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
"""
Decoder Block for SLM.
Pre-norm architecture with residual connections.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple
from .config import SLMConfig
from .normalization import RMSNorm
from .attention import MultiHeadAttention
from .ffn import FeedForward
from .kv_cache import KVCache
class DecoderBlock(nn.Module):
    """One pre-norm transformer decoder layer.

    Data flow (residual stream plus two sub-layers):

        x ──norm──▶ self-attention ──▶ (+) ──norm──▶ FFN ──▶ (+) ──▶ out

    Normalization is applied *before* each sub-layer and the raw input is
    added back afterwards (pre-norm). Compared with post-norm this:
      - keeps gradients stable under FP16 training,
      - behaves better under quantization,
      - exports cleanly to ONNX (no layer-crossing dependencies).
    """

    def __init__(self, config: SLMConfig, layer_idx: int):
        """Build the norm/sub-layer pairs for this decoder layer.

        Args:
            config: Model configuration.
            layer_idx: Position of this layer in the decoder stack.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Norm applied to the residual stream before self-attention.
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = MultiHeadAttention(config, layer_idx)
        # Norm applied to the residual stream before the feed-forward net.
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = FeedForward(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Run one decoder layer over a batch of hidden states.

        Args:
            hidden_states: Input activations [batch, seq, hidden_size].
            position_ids: Position indices [batch, seq].
            attention_mask: Optional causal attention mask.
            kv_cache: Optional KV cache for incremental decoding.
            use_cache: Whether to read/update the cache.

        Returns:
            Tuple of (output activations, possibly-updated kv_cache).
        """
        # Sub-layer 1: pre-norm self-attention, then add back the input.
        attn_out, kv_cache = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_ids=position_ids,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_out

        # Sub-layer 2: pre-norm feed-forward, then add back its input.
        hidden_states = hidden_states + self.mlp(
            self.post_attention_layernorm(hidden_states)
        )
        return hidden_states, kv_cache
|