# PebbleLM-117M — src/model/decoder.py
"""
Decoder Block for SLM.
Pre-norm architecture with residual connections.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple
from .config import SLMConfig
from .normalization import RMSNorm
from .attention import MultiHeadAttention
from .ffn import FeedForward
from .kv_cache import KVCache
class DecoderBlock(nn.Module):
    """One pre-norm transformer decoder layer.

    Data flow (pre-norm):

        x ─► RMSNorm ─► self-attention ─► (+ x)
          ─► RMSNorm ─► feed-forward   ─► (+ residual)

    Pre-norm is used because it gives more stable gradients under FP16
    training, behaves better under quantization, and exports cleanly to
    ONNX (no layer-crossing dependencies).
    """

    def __init__(self, config: SLMConfig, layer_idx: int):
        """Construct the sub-modules of this decoder layer.

        Args:
            config: Model configuration.
            layer_idx: Index of this layer in the stack (forwarded to the
                attention module, e.g. so it can address its cache slot).
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        eps = config.rms_norm_eps
        # Normalization applied before the self-attention sub-layer.
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.self_attn = MultiHeadAttention(config, layer_idx)
        # Normalization applied before the feed-forward sub-layer.
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.mlp = FeedForward(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Run one decoder layer.

        Args:
            hidden_states: Activations of shape [batch, seq, hidden_size].
            position_ids: Position indices of shape [batch, seq].
            attention_mask: Optional causal attention mask.
            kv_cache: Optional KV cache, read/updated by attention.
            use_cache: Whether attention should use/update the cache.

        Returns:
            Tuple of (output hidden states, possibly-updated kv_cache).
        """
        # Attention sub-layer: normalize, attend, then add the residual.
        attn_out, kv_cache = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_ids=position_ids,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_out

        # FFN sub-layer: normalize, transform, then add the residual.
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + ffn_out, kv_cache