File size: 2,948 Bytes
27871e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Decoder Block for SLM.
Pre-norm architecture with residual connections.
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple

from .config import SLMConfig
from .normalization import RMSNorm
from .attention import MultiHeadAttention
from .ffn import FeedForward
from .kv_cache import KVCache


class DecoderBlock(nn.Module):
    """One transformer decoder layer in the pre-normalization layout.

    Both sub-layers follow the same pattern:

        x = x + sublayer(norm(x))

    i.e. RMSNorm is applied *before* the attention / feed-forward
    transform, and the untouched input is added back afterwards.

    Why pre-norm:
    - More stable gradients in FP16 training
    - Better quantization behavior
    - Easier ONNX export (no layer-crossing dependencies)
    """

    def __init__(self, config: SLMConfig, layer_idx: int):
        """Build the layer's sub-modules.

        Args:
            config: Model configuration
            layer_idx: Index of this layer
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Norm applied before self-attention.
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Causal self-attention; layer_idx lets it address its own cache slot.
        self.self_attn = MultiHeadAttention(config, layer_idx)
        # Norm applied before the feed-forward network.
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Position-wise feed-forward network.
        self.mlp = FeedForward(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Run one decoder layer.

        Args:
            hidden_states: Input tensor [batch, seq, hidden_size]
            position_ids: Position indices [batch, seq]
            attention_mask: Causal attention mask
            kv_cache: Optional KV cache
            use_cache: Whether to use/update cache

        Returns:
            Tuple of (output, kv_cache)
        """
        # --- Sub-layer 1: self-attention with pre-norm and residual ---
        attn_input = self.input_layernorm(hidden_states)
        attn_output, kv_cache = self.self_attn(
            hidden_states=attn_input,
            position_ids=position_ids,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output

        # --- Sub-layer 2: feed-forward with pre-norm and residual ---
        ffn_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + ffn_output

        return hidden_states, kv_cache