# Normalization Layers via the Method of Fluxions
## BatchNorm, LayerNorm, RMSNorm: What They Actually Do

**Scott Bisset, Silicon Goddess**
OpenTransformers Ltd
January 2026

---

## Abstract

Normalization layers (BatchNorm, LayerNorm, RMSNorm) are presented in textbooks as "normalize then scale and shift" with formulas involving means and variances. This obscures their true purpose and makes the backward-pass derivation seem magical. We reformulate normalization using fluxions, revealing: (1) normalization as signal conditioning, (2) the backward pass as sensitivity redistribution, and (3) why different norms suit different architectures. The fluxion view also explains the computational structure that enables fused kernels.

---

## 1. Why Normalize?

### 1.1 The Problem

The original motivation for BatchNorm (Ioffe & Szegedy, 2015) is internal covariate shift:

- Each layer's input distribution changes during training
- Later layers constantly adapt to moving targets
- Training becomes unstable

### 1.2 The Solution Intuition

Force each layer's inputs to have consistent statistics.

"Standardize the signal before processing it."

---

## 2. The General Normalization Framework

### 2.1 Forward Pass

All normalization layers follow:

```
1. Compute statistics (μ, σ) over some dimension
2. Normalize: x̂ = (x - μ) / σ
3. Scale and shift: y = γ·x̂ + β
```

What differs is WHICH dimensions we compute statistics over.

### 2.2 Fluxion View

Let x be the input signal.

```
μ = mean(x)           # Signal center
σ = std(x)            # Signal spread
x̂ = (x - μ) / σ       # Centered and scaled to unit variance
y = γ·x̂ + β           # Learnable rescaling
```

- **γ** (gamma) = learned amplitude
- **β** (beta) = learned offset

Without γ and β, normalization would constrain representational power. With them, the network can learn to undo normalization if needed.

---

## 3. BatchNorm: Normalize Across Batch

### 3.1 The Idea

For each feature/channel, compute mean and variance ACROSS the batch.

```
Input: X of shape [B, D]  (B samples, D features)

For each feature d:
    μ_d = (1/B) Σᵢ X[i,d]                    # Mean of feature d across batch
    σ_d = sqrt((1/B) Σᵢ (X[i,d] - μ_d)²)     # Std of feature d
    X̂[:,d] = (X[:,d] - μ_d) / σ_d
    Y[:,d] = γ_d · X̂[:,d] + β_d
```

### 3.2 Fluxion Forward Pass

```
μ = mean(X, dim=batch)     # Shape: [D]
σ = std(X, dim=batch)      # Shape: [D]
X̂ = (X - μ) / σ            # Shape: [B, D]
Y = γ·X̂ + β                # Shape: [B, D]
```

### 3.3 The Backward Pass (Fluxion Derivation)

Given L̇ʸ (upstream gradient), find L̇ˣ, L̇ᵞ, L̇ᵝ. Here L̇ˣ denotes the sensitivity of the loss L to x (the gradient ∂L/∂x), and likewise for the other superscripts.

**Easy ones first:**

```
L̇ᵝ = sum(L̇ʸ, dim=batch)          # β gradient = sum of upstream
L̇ᵞ = sum(L̇ʸ · X̂, dim=batch)      # γ gradient = upstream weighted by normalized input
```

**L̇ˣ is tricky because μ and σ depend on ALL x values.**

Let's trace the wiggle:

```
If x[i,d] wiggles:
  1. Direct effect on X̂[i,d]: ∂X̂[i,d]/∂x[i,d] = 1/σ
  2. Indirect effect via μ: changing x[i,d] shifts μ, affects ALL X̂[:,d]
  3. Indirect effect via σ: changing x[i,d] changes σ, affects ALL X̂[:,d]
```

Full derivative, using ∂σ/∂x[i,d] = (x[i,d] - μ)/(B·σ) and ∂μ/∂x[i,d] = 1/B:

```
L̇ˣ̂ = L̇ʸ · γ                                   # Gradient through scale
L̇σ = -sum(L̇ˣ̂ · (X-μ) / σ², dim=batch)         # How σ wiggle affects loss
L̇μ = -sum(L̇ˣ̂ / σ, dim=batch) + L̇σ · (-1/(B·σ))·sum(X-μ, dim=batch)   # Second term is 0: sum(X-μ) = 0
L̇ˣ = L̇ˣ̂/σ + L̇σ·(X-μ)/(B·σ) + L̇μ/B
```

(The factor 2/B only appears if the sensitivity is taken with respect to the variance σ² rather than σ.)

### 3.4 Simplified Form

Substituting L̇σ and L̇μ and writing (X-μ)/σ = X̂, the BatchNorm backward becomes:

```
L̇ˣ = (1/σ) · (L̇ˣ̂ - mean(L̇ˣ̂, dim=batch) - X̂·mean(L̇ˣ̂·X̂, dim=batch))
```

**Interpretation:**

- Start with scaled upstream gradient
- Subtract its mean (center the gradient)
- Subtract correlation with normalized input (decorrelate)

"BatchNorm backward REDISTRIBUTES gradient so that, within each feature, L̇ˣ has zero mean and is orthogonal to X̂."

### 3.5 Inference Mode

At inference, we may not have a batch, and the output for one sample should not depend on the others. Use running averages from training:

```
μ_running = momentum·μ_running + (1-momentum)·μ_batch
σ_running = momentum·σ_running + (1-momentum)·σ_batch
```

Then normalize with running stats. Implementations commonly track the running variance rather than σ itself; the update has the same form. Note that PyTorch's `momentum` argument uses the opposite convention: there it is the weight on the new batch statistic.

### 3.6 Problems with BatchNorm

1. **Batch size dependence**: Small batches → noisy statistics
2. **Not suitable for sequence models**: With variable-length sequences, each position sees a different set of batch members, so per-position statistics are unreliable
3. **Inference/training mismatch**: Running stats ≠ batch stats

---

## 4. LayerNorm: Normalize Across Features

### 4.1 The Idea

For each sample, compute mean and variance ACROSS features.

```
Input: X of shape [B, D]

For each sample i:
    μᵢ = (1/D) Σ_d X[i,d]                    # Mean across features
    σᵢ = sqrt((1/D) Σ_d (X[i,d] - μᵢ)²)
    X̂[i,:] = (X[i,:] - μᵢ) / σᵢ
    Y[i,:] = γ · X̂[i,:] + β
```

### 4.2 Fluxion Forward Pass

```
μ = mean(X, dim=features)  # Shape: [B]
σ = std(X, dim=features)   # Shape: [B]
X̂ = (X - μ) / σ            # Shape: [B, D]
Y = γ·X̂ + β                # Shape: [B, D]
```

### 4.3 Key Difference from BatchNorm

```
BatchNorm: statistics across BATCH (each feature normalized independently)
LayerNorm: statistics across FEATURES (each sample normalized independently)
```

### 4.4 Why LayerNorm for Transformers?

1. **No batch dependence**: Each token normalized independently
2. **Works with any batch size**: Including batch=1 at inference
3. **Sequence-friendly**: Position i doesn't need position j's statistics

### 4.5 Backward Pass

Same structure as BatchNorm, but sum over features instead of batch:

```
L̇ˣ̂ = L̇ʸ · γ
L̇ˣ = (1/σ) · (L̇ˣ̂ - mean(L̇ˣ̂, dim=features) - X̂·mean(L̇ˣ̂·X̂, dim=features))
```

γ and β are still per-feature, so L̇ᵞ and L̇ᵝ are summed over the batch exactly as in Section 3.3.

---

## 5. RMSNorm: Skip the Mean

### 5.1 The Simplification

LayerNorm computes mean AND variance. RMSNorm: "What if we skip the mean centering?"

```
RMS(x) = sqrt(mean(x²))
X̂ = X / RMS(X)
Y = γ · X̂
```

No β parameter (no shift), no mean subtraction.

### 5.2 Fluxion Forward Pass

```
rms = sqrt(mean(X², dim=features))   # Shape: [B]
X̂ = X / rms                          # Shape: [B, D]
Y = γ · X̂                            # Shape: [B, D]
```

### 5.3 Why It Works

Empirically, the mean-centering in LayerNorm contributes little (Zhang & Sennrich, 2019).
RMSNorm achieves similar performance with:

- Fewer operations
- Simpler backward pass
- Better numerical stability

### 5.4 Backward Pass

Much simpler without mean:

```
L̇ˣ̂ = L̇ʸ · γ
L̇ʳᵐˢ = -sum(L̇ˣ̂ · X / rms², dim=features)
L̇ˣ = L̇ˣ̂/rms + L̇ʳᵐˢ · X/(D·rms)
```

Simplified:

```
L̇ˣ = (1/rms) · (L̇ˣ̂ - X̂·mean(L̇ˣ̂·X̂, dim=features))
```

One fewer term than LayerNorm!

### 5.5 Usage

RMSNorm is used in:

- LLaMA
- Mistral
- Many other recent LLMs

It has become a common default for transformer language models.

---

## 6. Comparison Table

| Property | BatchNorm | LayerNorm | RMSNorm |
|----------|-----------|-----------|---------|
| Stats over | Batch | Features | Features |
| Learnable | γ, β | γ, β | γ only |
| Mean centering | Yes | Yes | No |
| Batch dependent | Yes | No | No |
| Inference mode | Running stats | Same as training | Same as training |
| Use case | CNNs | Transformers | Modern LLMs |
| Operations | Most | Medium | Fewest |

---

## 7. Pre-Norm vs Post-Norm

### 7.1 Post-Norm (Original Transformer)

```
X → Attention → Add(X) → LayerNorm → FFN → Add → LayerNorm → Output
```

Normalize AFTER residual connection.

### 7.2 Pre-Norm (Modern Default)

```
X → LayerNorm → Attention → Add(X) → LayerNorm → FFN → Add → Output
```

Normalize BEFORE attention/FFN.

### 7.3 Fluxion Analysis

**Post-Norm gradient flow** (attention sub-block, out = LayerNorm(X + Attention(X))):

```
g  = LayerNorm_backward(L̇ᵒᵘᵗ)     # Every path must flow through the norm
L̇ˣ = g + Attention_backward(g)
```

**Pre-Norm gradient flow** (out = X + Attention(LayerNorm(X))):

```
L̇ˣ = L̇ᵒᵘᵗ + LayerNorm_backward(Attention_backward(L̇ᵒᵘᵗ))
      ↑
      Direct highway!
```

Pre-Norm has a direct gradient path that bypasses normalization. This stabilizes training for deep networks (Xiong et al., 2020).

---

## 8. Numerical Stability

### 8.1 The Variance Problem

Computing variance naively:

```
var = mean(x²) - mean(x)²
```

If mean(x²) ≈ mean(x)², which happens whenever |mean| is much larger than the standard deviation, the two terms agree in most of their leading digits and the subtraction causes catastrophic cancellation. The result can be wildly wrong, or even negative.

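A small sketch makes the cancellation concrete. It assumes plain Python floats (IEEE double precision); `naive_var` and `two_pass_var` are illustrative names, not part of any library, and the data is made up: a mean near 1e9 with a true population variance of 22.5.

```python
def naive_var(x):
    """Population variance as mean(x²) - mean(x)², the cancellation-prone form."""
    n = len(x)
    total = 0.0
    total_sq = 0.0
    for v in x:
        total += v
        total_sq += v * v
    mean = total / n
    return total_sq / n - mean * mean


def two_pass_var(x):
    """Population variance as mean((x - mean)²): center first, then square."""
    n = len(x)
    mean = sum(x) / n
    return sum((v - mean) ** 2 for v in x) / n


# Illustrative data: large offset, small spread. True population variance is 22.5.
data = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]

print("naive:   ", naive_var(data))     # far from 22.5 (may even be negative)
print("two-pass:", two_pass_var(data))  # 22.5
```

The squares here are around 1e18, where adjacent doubles are 128 apart, so a variance of 22.5 is smaller than the rounding error of either term in the naive formula. Centering before squaring keeps every intermediate value small.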
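### 8.2 Welford's Algorithm

Compute variance in a single pass, numerically stable:

```python
def welford_var(x):
    """Population variance of an iterable in one pass (Welford's algorithm)."""
    n = 0
    mean = 0.0
    M2 = 0.0  # Running sum of squared deviations from the current mean
    for xi in x:
        n += 1
        delta = xi - mean    # Deviation from the old mean
        mean += delta / n
        delta2 = xi - mean   # Deviation from the updated mean
        M2 += delta * delta2
    if n == 0:
        raise ValueError("welford_var requires at least one value")
    return M2 / n            # Use n - 1 here for the sample variance
```

### 8.3 Fused Kernels

The fluxion view reveals that forward and backward are tightly coupled:

- Forward computes: μ, σ, X̂
- Backward needs: σ, X̂ (already computed in forward!)

Fused kernel can:

1. Compute μ, σ in one pass
2. Save X̂ and 1/σ (or save X, μ, 1/σ and recompute X̂ cheaply)
3. Backward reuses X̂ directly

This is a large part of why a fused LayerNorm, such as PyTorch's native one, is much faster than a naive implementation built from separate reductions and elementwise ops, each of which re-reads the tensor from memory.

---

## 9. Gradient Flow Analysis

### 9.1 Without Normalization

Deep network gradient flow (ignoring activation derivatives):

```
L̇ˣ⁽⁰⁾ = W⁽¹⁾ᵀ · W⁽²⁾ᵀ · ... · W⁽ᴸ⁾ᵀ · L̇ʸ
```

- If the weight matrices scale vectors by slightly > 1 on average: gradient explodes with depth
- If they scale by slightly < 1: gradient vanishes

### 9.2 With Normalization

Each layer's normalized activations are forced to unit variance, whatever the scale of the weights that produced them. Gradient magnitudes stabilize.

```
X̂ has unit variance (by construction)
→ L̇ˣ is L̇ˣ̂ projected and rescaled by 1/σ, so its magnitude is controlled
→ Explosion and vanishing become far less likely
```

### 9.3 The Jacobian View

LayerNorm Jacobian ∂Y/∂X is NOT diagonal (because of mean/var coupling).

But it has a special structure:

```
J = ∂X̂/∂X = (1/σ) · (I - (1/D)·1·1ᵀ - (1/D)·X̂·X̂ᵀ)
∂Y/∂X = diag(γ) · J
```

This projects out the mean direction 1 and the X̂ direction. Because 1 and X̂ are orthogonal and each has squared norm D, the eigenvalues of J are 0 (for those two directions) and 1/σ (for the remaining D-2). J therefore never amplifies a gradient by more than 1/σ.

---

## 10. Implementation Details

### 10.1 Memory Layout Matters

LayerNorm over last dimension (features):

```python
# Contiguous memory access pattern
for b in range(B):
    for d in range(D):  # Sequential access
        ...
```

BatchNorm over batch dimension:

```python
# Strided memory access pattern
for d in range(D):
    for b in range(B):  # Jumping through memory
        ...
```

LayerNorm is more cache-friendly for typical row-major tensor layouts [B, ..., D], where the feature dimension is contiguous in memory.
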
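The two loop orders above can be made concrete by listing the flat memory offsets they touch. This is a minimal sketch assuming a row-major (C-contiguous) `[B, D]` array, where element `(b, d)` lives at offset `b*D + d`; `access_order` and `steps` are hypothetical helpers written for illustration only.

```python
def access_order(B, D, reduce_over):
    """Flat offsets visited, in order, for a row-major [B, D] array.

    reduce_over="features" mimics the LayerNorm loop (inner loop over d);
    reduce_over="batch" mimics the BatchNorm loop (inner loop over b).
    """
    if reduce_over == "features":
        return [b * D + d for b in range(B) for d in range(D)]
    if reduce_over == "batch":
        return [b * D + d for d in range(D) for b in range(B)]
    raise ValueError("reduce_over must be 'features' or 'batch'")


def steps(offsets):
    """Distance between consecutive accesses; 1 means sequential."""
    return [nxt - cur for cur, nxt in zip(offsets, offsets[1:])]


B, D = 2, 3
print(access_order(B, D, "features"))         # [0, 1, 2, 3, 4, 5]
print(steps(access_order(B, D, "features")))  # [1, 1, 1, 1, 1]
print(access_order(B, D, "batch"))            # [0, 3, 1, 4, 2, 5]
print(steps(access_order(B, D, "batch")))     # [3, -2, 3, -2, 3]
```

With realistic sizes the batch-order loop jumps D elements per access, so consecutive reads can land on different cache lines. This describes a simple CPU loop; GPU kernels typically parallelize across features, so the trade-off there can look different.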
### 10.2 Epsilon Placement

```python
# Standard (eps inside the square root; safe even if var=0):
x_hat = (x - mean) / sqrt(var + eps)
# Same computation, fused:
x_hat = (x - mean) * rsqrt(var + eps)
# Avoid (forward is finite, but the derivative of sqrt(var) is infinite at var=0):
x_hat = (x - mean) / (sqrt(var) + eps)
```

With eps > 0, both placements avoid division by zero in the forward pass. Keeping eps inside the square root also keeps the backward pass finite when var = 0, and it is the form that fuses into a single `rsqrt`. The `rsqrt` (reciprocal square root) maps to a fast hardware instruction on GPUs.

---

## 11. Summary

### 11.1 Fluxion View of Normalization

**Forward:**

```
μ, σ computed from X
X̂ = (X - μ)/σ          # Standardize
Y = γ·X̂ + β            # Scale and shift
```

**Backward:**

```
L̇ᵞ = sum(L̇ʸ · X̂)       # Scale gradient
L̇ᵝ = sum(L̇ʸ)           # Shift gradient
L̇ˣ = redistributed gradient (centered, decorrelated)
```

### 11.2 Key Insight

Normalization doesn't just scale activations: it COUPLES gradient flow across the normalized dimension.

Each input's gradient depends on ALL other inputs in the normalization group. This coupling:

- Stabilizes gradient magnitudes
- Prevents single features from dominating
- Enables deeper networks

---

## References

1. Ioffe & Szegedy (2015). "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift."
2. Ba, Kiros & Hinton (2016). "Layer Normalization."
3. Zhang & Sennrich (2019). "Root Mean Square Layer Normalization."
4. Xiong et al. (2020). "On Layer Normalization in the Transformer Architecture."

---

*Correspondence: scott@opentransformers.online*