File size: 6,386 Bytes
fb67af8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""Root Mean Square Layer Normalization (RMSNorm) implementation.

Critical implementation details:
1. Use multiplication with rsqrt, NOT division
2. No mean subtraction (unlike LayerNorm)
3. Compute in FP32 for numerical stability even when using BF16/FP16
"""

import torch
import torch.nn as nn
from typing import Optional


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    RMSNorm is a simplification of LayerNorm that removes the mean subtraction
    and only performs re-scaling via the root mean square:

        y = weight * x / sqrt(mean(x^2) + eps)

    Based on the paper: 'Root Mean Square Layer Normalization'
    https://arxiv.org/abs/1910.07467
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        Args:
            hidden_size: Size of the hidden (last) dimension to normalize over.
            eps: Small constant for numerical stability
                (1e-6 is typical for BF16, 1e-5 for FP16).
        """
        super().__init__()
        self.hidden_size = hidden_size
        # float() accepts numeric and string inputs alike, so a single
        # unconditional conversion replaces the previous
        # `isinstance(eps, str)` special case while still guarding against
        # eps arriving as a string (e.g. from a config file).
        self.eps = float(eps)

        # Learnable per-channel scale (gamma), initialized to the identity.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply RMSNorm over the last dimension of ``x``.

        NOTE: the scaling must be a MULTIPLICATION by ``torch.rsqrt``;
        dividing by rsqrt would multiply by sqrt(variance) instead of
        dividing by it.

        Args:
            x: Input tensor of shape [..., hidden_size].

        Returns:
            Normalized tensor with the same shape and dtype as the input.
        """
        # Remember the caller's dtype so mixed-precision inputs round-trip.
        input_dtype = x.dtype

        # Compute in float32 for numerical stability under BF16/FP16.
        x_fp32 = x.float()

        # mean(x^2) over the last dim; keepdim=True so it broadcasts over x.
        variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)

        # x * rsqrt(var + eps) == x / sqrt(var + eps).  `self.eps` is a plain
        # Python float and broadcasts without wrapping in a tensor.
        x_normalized = x_fp32 * torch.rsqrt(variance + self.eps)

        # Apply the learned scale and cast back to the original dtype.
        return self.weight * x_normalized.to(input_dtype)

    def extra_repr(self) -> str:
        return f'hidden_size={self.hidden_size}, eps={self.eps}'


class RMSNormOptimized(nn.Module):
    """Optimized RMSNorm with optional fused operations.

    This version includes options for better performance:
    1. In-place scaling to reduce peak memory (``memory_efficient=True``)
    2. Optional learnable per-channel scale (``elementwise_affine``)
    3. Optional residual-add fusion before normalization
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        elementwise_affine: bool = True,
        memory_efficient: bool = False,
    ):
        """
        Args:
            hidden_size: Size of the hidden (last) dimension to normalize over.
            eps: Small constant for numerical stability.
            elementwise_affine: If True, learn a per-channel scale (gamma).
            memory_efficient: If True, apply the normalization scale in place
                to reduce peak memory in ``forward``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        # float() accepts numeric and string inputs alike, replacing the
        # previous `isinstance(eps, str)` special case.
        self.eps = float(eps)
        self.elementwise_affine = elementwise_affine
        self.memory_efficient = memory_efficient

        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
        else:
            # Register None so state_dict / repr stay consistent with
            # torch.nn.LayerNorm's convention.
            self.register_parameter('weight', None)

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply RMSNorm with an optional fused residual add.

        Args:
            x: Input tensor of shape [..., hidden_size].
            residual: Optional tensor added to ``x`` BEFORE normalization
                (pre-norm architecture).

        Returns:
            The normalized tensor; or, when ``residual`` is provided, a tuple
            ``(normalized, residual)`` where ``residual`` is the pre-norm sum
            ``x + residual`` saved for the next skip connection.
        """
        if residual is not None:
            x = x + residual
            residual = x  # save the pre-norm sum for the skip connection

        # Remember the caller's dtype so mixed-precision inputs round-trip.
        input_dtype = x.dtype

        if self.memory_efficient:
            # BUG FIX: the previous code computed the variance with
            # x.pow_(2), which squared x IN PLACE before the normalization
            # multiply, so the output was x^2 * rsqrt(var) instead of
            # x * rsqrt(var).  Compute the variance out-of-place and apply
            # only the final scaling in place.
            #
            # When the input is already FP32, .float() is an aliasing no-op,
            # so clone to avoid mutating the caller's tensor (and the saved
            # `residual`, which aliases x above).
            x = x.float() if input_dtype != torch.float32 else x.clone()
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            x.mul_(torch.rsqrt(variance + self.eps))
        else:
            # Standard out-of-place computation in FP32 for stability.
            x_fp32 = x.float()
            variance = x_fp32.pow(2).mean(dim=-1, keepdim=True)
            x = x_fp32 * torch.rsqrt(variance + self.eps)

        # Apply the learned scale (if any) and cast back to the input dtype.
        if self.elementwise_affine:
            x = self.weight * x

        x = x.to(input_dtype)

        if residual is not None:
            return x, residual
        return x


def rmsnorm_func(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Functional RMSNorm: ``weight * x / sqrt(mean(x^2) + eps)``.

    Stateless version for use with torch.compile or custom kernels.

    Args:
        x: Input tensor of shape [..., hidden]; normalized over the last dim.
        weight: Per-channel scale of shape [hidden].
        eps: Numerical-stability constant (numeric, or a numeric string).

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    input_dtype = x.dtype
    # Compute in FP32 for stability under mixed precision.
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    # float() handles numeric and string eps alike — no isinstance check needed.
    x = x * torch.rsqrt(variance + float(eps))
    # Apply the scale, then cast back to the caller's dtype.
    return (weight * x).to(input_dtype)


# Comparison with LayerNorm for reference
def compare_normalization():
    """Compare RMSNorm with LayerNorm to illustrate the differences.

    Builds a random activation tensor, runs both normalizations, and prints
    the shapes plus the mean absolute difference between the two outputs.

    Returns:
        Tuple of ``(layer_norm_output, rms_norm_output)``.
    """
    # NOTE: the redundant function-local `import torch.nn as nn` was removed;
    # `nn` is already imported at module level.
    batch_size, seq_len, hidden = 2, 10, 768
    x = torch.randn(batch_size, seq_len, hidden)

    # LayerNorm: subtracts the mean, then divides by the standard deviation.
    layer_norm = nn.LayerNorm(hidden)
    ln_out = layer_norm(x)

    # RMSNorm: divides by the root mean square only — no mean subtraction.
    rms_norm = RMSNorm(hidden)
    rms_out = rms_norm(x)

    print(f"Input shape: {x.shape}")
    print(f"LayerNorm output shape: {ln_out.shape}")
    print(f"RMSNorm output shape: {rms_out.shape}")
    print(f"Mean difference: {(ln_out - rms_out).abs().mean().item():.6f}")

    # NOTE(review): RMSNorm is generally faster thanks to the simpler
    # computation; the exact speedup (previously stated as 15-20%) is
    # hardware- and model-dependent and is not measured here.
    return ln_out, rms_out