Normalization Layers via the Method of Fluxions
BatchNorm, LayerNorm, RMSNorm: What They Actually Do
Scott Bisset, Silicon Goddess
OpenTransformers Ltd
January 2026
Abstract
Normalization layers (BatchNorm, LayerNorm, RMSNorm) are presented in textbooks as "normalize then scale and shift" with formulas involving means and variances. This obscures their true purpose and makes backward pass derivation seem magical. We reformulate normalization using fluxions, revealing: (1) normalization as signal conditioning, (2) the backward pass as sensitivity redistribution, and (3) why different norms suit different architectures. The fluxion view also explains the computational structure that enables fused kernels.
1. Why Normalize?
1.1 The Problem
Deep networks suffer from internal covariate shift:
- Each layer's input distribution changes during training
- Later layers constantly adapt to moving targets
- Training becomes unstable
1.2 The Solution Intuition
Force each layer's inputs to have consistent statistics. "Standardize the signal before processing it."
2. The General Normalization Framework
2.1 Forward Pass
All normalization layers follow:
1. Compute statistics (μ, σ) over some dimension
2. Normalize: x̂ = (x - μ) / σ
3. Scale and shift: y = γ·x̂ + β
What differs is WHICH dimensions we compute statistics over.
2.2 Fluxion View
Let x be the input signal.
μ = mean(x) # Signal center
σ = std(x) # Signal spread
x̂ = (x - μ) / σ # Centered and scaled to unit variance
y = γ·x̂ + β # Learnable rescaling
γ (gamma) = learned amplitude
β (beta) = learned offset
Without γ and β, normalization would constrain representational power. With them, the network can learn to undo normalization if needed.
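To make the three steps concrete, here is a minimal NumPy sketch (the name `normalize` and the `axis` argument are illustrative, not a library API; `axis` selects which dimension the statistics are computed over):

```python
import numpy as np

def normalize(x, gamma, beta, axis, eps=1e-5):
    """Generic normalization: statistics over `axis`, then scale and shift."""
    mu = x.mean(axis=axis, keepdims=True)     # signal center
    var = x.var(axis=axis, keepdims=True)     # signal spread (squared)
    x_hat = (x - mu) / np.sqrt(var + eps)     # standardized signal
    return gamma * x_hat + beta               # learnable rescaling
```

With `axis=0` on a [B, D] input this behaves like BatchNorm's forward pass; with `axis=-1` it behaves like LayerNorm's.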
3. BatchNorm: Normalize Across Batch
3.1 The Idea
For each feature/channel, compute mean and variance ACROSS the batch.
Input: X of shape [B, D] (B samples, D features)
For each feature d:
μ_d = (1/B) Σᵢ X[i,d] # Mean of feature d across batch
σ_d = sqrt((1/B) Σᵢ (X[i,d] - μ_d)²) # Std of feature d
X̂[:,d] = (X[:,d] - μ_d) / σ_d
Y[:,d] = γ_d · X̂[:,d] + β_d
3.2 Fluxion Forward Pass
μ = mean(X, dim=batch) # Shape: [D]
σ = std(X, dim=batch) # Shape: [D]
X̂ = (X - μ) / σ # Shape: [B, D]
Y = γ·X̂ + β # Shape: [B, D]
3.3 The Backward Pass (Fluxion Derivation)
Given L̇ʸ (upstream gradient), find L̇ˣ, L̇ᵞ, L̇ᵝ.
Easy ones first:
L̇ᵝ = sum(L̇ʸ, dim=batch) # β gradient = sum of upstream
L̇ᵞ = sum(L̇ʸ · X̂, dim=batch) # γ gradient = upstream weighted by normalized input
L̇ˣ is tricky because μ and σ depend on ALL x values.
Let's trace the wiggle:
If x[i,d] wiggles:
1. Direct effect on X̂[i,d]: ∂X̂[i,d]/∂x[i,d] = 1/σ
2. Indirect effect via μ: changing x[i,d] shifts μ, affects ALL X̂[:,d]
3. Indirect effect via σ: changing x[i,d] changes σ, affects ALL X̂[:,d]
Full derivative (everything in terms of σ; sums run over the batch):
L̇ˣ̂ = L̇ʸ · γ # Gradient through scale
L̇σ = -sum(L̇ˣ̂ · (X-μ) / σ², dim=batch) # How a σ wiggle affects the loss
L̇μ = -sum(L̇ˣ̂ / σ, dim=batch) # The extra μ→σ term vanishes: sum(X-μ, dim=batch) = 0
L̇ˣ = L̇ˣ̂/σ + L̇σ·(X-μ)/(B·σ) + L̇μ/B
3.4 Simplified Form
After algebra, the BatchNorm backward becomes:
L̇ˣ = (1/σ) · (L̇ˣ̂ - mean(L̇ˣ̂) - X̂·mean(L̇ˣ̂·X̂))
Interpretation:
- Start with scaled upstream gradient
- Subtract its mean (center the gradient)
- Subtract correlation with normalized input (decorrelate)
"BatchNorm backward REDISTRIBUTES gradient to maintain zero-mean, unit-variance gradient flow."
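This simplified form can be sanity-checked against finite differences. A hedged NumPy sketch (`bn_forward` and `bn_backward` are illustrative names, not a library API):

```python
import numpy as np

def bn_forward(x, gamma, eps=1e-5):
    """BatchNorm forward over a [B, D] input; returns what backward reuses."""
    mu = x.mean(axis=0)
    sigma = np.sqrt(x.var(axis=0) + eps)
    x_hat = (x - mu) / sigma
    return gamma * x_hat, x_hat, sigma

def bn_backward(dy, x_hat, sigma, gamma):
    """Simplified BatchNorm input gradient: center, then decorrelate."""
    dxh = dy * gamma                                # gradient through scale
    return (dxh
            - dxh.mean(axis=0)                      # subtract the mean
            - x_hat * (dxh * x_hat).mean(axis=0)    # remove x_hat correlation
            ) / sigma
```

Comparing `bn_backward` against a numerical gradient of a scalar loss verifies that the simplified form is exact, epsilon included.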
3.5 Inference Mode
At inference, we don't have a batch. Use running averages from training:
μ_running = momentum·μ_running + (1-momentum)·μ_batch
σ²_running = momentum·σ²_running + (1-momentum)·σ²_batch
Then normalize with the running statistics. (In practice frameworks track the running variance, not the standard deviation.)
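A sketch of the update (assuming `momentum = 0.9`; note that frameworks differ on whether `momentum` weights the old or the new statistics):

```python
def update_running_stats(run_mean, run_var, batch_mean, batch_var, momentum=0.9):
    """Exponential moving average of batch statistics, used at inference."""
    new_mean = momentum * run_mean + (1 - momentum) * batch_mean
    new_var = momentum * run_var + (1 - momentum) * batch_var
    return new_mean, new_var
```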
3.6 Problems with BatchNorm
- Batch size dependence: Small batches → noisy statistics
- Awkward for sequence models: variable-length sequences make per-position batch statistics ill-defined
- Inference/training mismatch: Running stats ≠ batch stats
4. LayerNorm: Normalize Across Features
4.1 The Idea
For each sample, compute mean and variance ACROSS features.
Input: X of shape [B, D]
For each sample i:
μᵢ = (1/D) Σ_d X[i,d] # Mean across features
σᵢ = sqrt((1/D) Σ_d (X[i,d] - μᵢ)²)
X̂[i,:] = (X[i,:] - μᵢ) / σᵢ
Y[i,:] = γ · X̂[i,:] + β
4.2 Fluxion Forward Pass
μ = mean(X, dim=features) # Shape: [B]
σ = std(X, dim=features) # Shape: [B]
X̂ = (X - μ) / σ # Shape: [B, D]
Y = γ·X̂ + β # Shape: [B, D]
4.3 Key Difference from BatchNorm
BatchNorm: statistics across BATCH (each feature normalized independently)
LayerNorm: statistics across FEATURES (each sample normalized independently)
4.4 Why LayerNorm for Transformers?
- No batch dependence: Each token normalized independently
- Works with any batch size: Including batch=1 at inference
- Sequence-friendly: Position i doesn't need position j's statistics
4.5 Backward Pass
Same structure as BatchNorm, but sum over features instead of batch:
L̇ˣ̂ = L̇ʸ · γ
L̇ˣ = (1/σ) · (L̇ˣ̂ - mean(L̇ˣ̂, dim=features) - X̂·mean(L̇ˣ̂·X̂, dim=features))
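The same finite-difference check works per sample. A NumPy sketch with illustrative names (`ln_forward`, `ln_backward`), assuming eps inside the square root:

```python
import numpy as np

def ln_forward(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last axis of a [B, D] input."""
    mu = x.mean(axis=-1, keepdims=True)
    sigma = np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    x_hat = (x - mu) / sigma
    return gamma * x_hat + beta, x_hat, sigma

def ln_backward(dy, x_hat, sigma, gamma):
    """Input gradient: same structure as BatchNorm, but per sample."""
    dxh = dy * gamma
    return (dxh
            - dxh.mean(axis=-1, keepdims=True)
            - x_hat * (dxh * x_hat).mean(axis=-1, keepdims=True)
            ) / sigma
```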
5. RMSNorm: Skip the Mean
5.1 The Simplification
LayerNorm computes mean AND variance. RMSNorm: "What if we skip the mean centering?"
RMS(x) = sqrt(mean(x²))
X̂ = X / RMS(X)
Y = γ · X̂
No β parameter (no shift), no mean subtraction.
5.2 Fluxion Forward Pass
rms = sqrt(mean(X², dim=features)) # Shape: [B]
X̂ = X / rms # Shape: [B, D]
Y = γ · X̂ # Shape: [B, D]
5.3 Why It Works
Empirically, the mean-centering in LayerNorm contributes little. RMSNorm achieves similar performance with:
- Fewer operations
- Simpler backward pass
- Better numerical stability
5.4 Backward Pass
Much simpler without mean:
L̇ˣ̂ = L̇ʸ · γ
L̇ʳᵐˢ = -sum(L̇ˣ̂ · X / rms², dim=features)
L̇ˣ = L̇ˣ̂/rms + L̇ʳᵐˢ · X/(D·rms)
Simplified:
L̇ˣ = (1/rms) · (L̇ˣ̂ - X̂·mean(L̇ˣ̂·X̂))
One fewer term than LayerNorm!
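Both passes fit in a few lines. A hedged NumPy sketch (`rms_forward`/`rms_backward` are illustrative names; `eps = 1e-6` is an assumption):

```python
import numpy as np

def rms_forward(x, gamma, eps=1e-6):
    """RMSNorm over the last axis: no mean subtraction, no beta."""
    rms = np.sqrt((x * x).mean(axis=-1, keepdims=True) + eps)
    x_hat = x / rms
    return gamma * x_hat, x_hat, rms

def rms_backward(dy, x_hat, rms, gamma):
    """One fewer term than LayerNorm: no centering of the gradient."""
    dxh = dy * gamma
    return (dxh - x_hat * (dxh * x_hat).mean(axis=-1, keepdims=True)) / rms
```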
5.5 Usage
RMSNorm is used in:
- LLaMA
- Mistral
- Most modern LLMs
It's becoming the default for transformers.
6. Comparison Table
| Property | BatchNorm | LayerNorm | RMSNorm |
|---|---|---|---|
| Stats over | Batch | Features | Features |
| Learnable | γ, β | γ, β | γ only |
| Mean centering | Yes | Yes | No |
| Batch dependent | Yes | No | No |
| Inference mode | Running stats | Same as training | Same as training |
| Use case | CNNs | Transformers | Modern LLMs |
| Operations | Most | Medium | Fewest |
7. Pre-Norm vs Post-Norm
7.1 Post-Norm (Original Transformer)
X → Attention → Add(X) → LayerNorm → FFN → Add → LayerNorm → Output
Normalize AFTER residual connection.
7.2 Pre-Norm (Modern Default)
X → LayerNorm → Attention → Add(X) → LayerNorm → FFN → Add → Output
Normalize BEFORE attention/FFN.
7.3 Fluxion Analysis
Post-Norm gradient flow:
L̇ˣ = LayerNorm_backward(L̇ᵒᵘᵗ) # Gradient must flow through norm
Pre-Norm gradient flow:
L̇ˣ = L̇ᵒᵘᵗ + LayerNorm_backward(Attention_backward(L̇ᵒᵘᵗ))
      ↑
      Direct highway: the first term bypasses both the sublayer and the norm.
Pre-Norm has a direct gradient path that bypasses normalization. This stabilizes training for deep networks.
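The highway can be seen with a toy sketch: in a pre-norm block y = x + f(norm(x)), the identity path survives even when the sublayer contributes nothing. A NumPy illustration (the zero function stands in for a badly scaled sublayer; names are illustrative):

```python
import numpy as np

def layernorm(x, eps=1e-5):
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def prenorm_block(x, f):
    return x + f(layernorm(x))      # residual add AFTER the sublayer

def postnorm_block(x, f):
    return layernorm(x + f(x))      # normalize AFTER the residual add

# A "dead" sublayer stands in for a badly scaled f whose output is ~0.
dead = lambda z: np.zeros_like(z)
x = np.array([1.0, 2.0, 3.0])
# Pre-norm passes x through untouched (x + 0), so the upstream gradient
# reaches x directly; post-norm rescales even the identity path.
```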
8. Numerical Stability
8.1 The Variance Problem
Computing variance naively:
var = mean(x²) - mean(x)²
If mean(x²) ≈ mean(x)², subtraction causes catastrophic cancellation.
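A quick float32 demonstration (values are illustrative: a small spread riding on a large shared offset):

```python
import numpy as np

# Three values with a large shared offset, stored in float32
x = (1e4 + np.array([0.1, 0.2, 0.3])).astype(np.float32)

# Naive one-pass formula: difference of two large, nearly equal numbers
naive = np.mean(x * x) - np.mean(x) ** 2

# Two-pass formula: subtract the mean first, then square small deviations
two_pass = np.mean((x - np.mean(x)) ** 2)

# two_pass stays near the true variance (~0.00667); naive loses all
# significant digits and may even come out negative.
```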
8.2 Welford's Algorithm
Compute variance in a single pass, numerically stable:
def welford_var(x):
    """One-pass, numerically stable variance (population, ddof=0)."""
    n = 0
    mean = 0.0
    M2 = 0.0
    for xi in x:
        n += 1
        delta = xi - mean
        mean += delta / n        # update running mean
        delta2 = xi - mean
        M2 += delta * delta2     # accumulate sum of squared deviations
    return M2 / n
8.3 Fused Kernels
The fluxion view reveals that forward and backward are tightly coupled:
Forward needs: μ, σ, X̂
Backward needs: μ, σ, X̂ (the same!)
Fused kernel can:
- Compute μ, σ in one pass
- Store only X̂ (derived from X, μ, σ)
- Backward reuses X̂ directly
This is why PyTorch's native LayerNorm is much faster than a naive implementation.
9. Gradient Flow Analysis
9.1 Without Normalization
Deep network gradient flow:
L̇ˣ⁽⁰⁾ = W⁽¹⁾ᵀ · W⁽²⁾ᵀ · ... · W⁽ᴸ⁾ᵀ · L̇ʸ
If the weight norms are slightly > 1: the gradient explodes.
If the weight norms are slightly < 1: the gradient vanishes.
9.2 With Normalization
Each layer's activations are forced to unit variance. Gradient magnitudes stabilize.
L̇ˣ̂ has unit variance (approximately)
→ L̇ˣ has controlled magnitude
→ No explosion or vanishing
9.3 The Jacobian View
LayerNorm Jacobian ∂Y/∂X is NOT diagonal (because of mean/var coupling).
But it has a special structure:
J = (1/σ) · (I - (1/D)·1·1ᵀ - (1/D)·X̂·X̂ᵀ)
This projects out the mean direction and decorrelates from X̂. Eigenvalues are bounded, preventing gradient explosion.
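The structure can be verified numerically. A sketch (illustrative names): build the finite-difference Jacobian of x̂ = (X − μ)/σ for a single vector and compare it to the closed form.

```python
import numpy as np

def ln_hat(x, eps=1e-5):
    """Normalized output x_hat = (x - mu) / sigma for one vector."""
    return (x - x.mean()) / np.sqrt(x.var() + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=5)
D = x.size
sigma = np.sqrt(x.var() + 1e-5)
x_hat = ln_hat(x)

# Closed form: identity minus the mean direction minus x_hat's direction
J = (np.eye(D) - np.ones((D, D)) / D - np.outer(x_hat, x_hat) / D) / sigma

# Finite-difference Jacobian, column k = d x_hat / d x_k
h = 1e-6
J_num = np.stack(
    [(ln_hat(x + h * e) - ln_hat(x - h * e)) / (2 * h) for e in np.eye(D)],
    axis=1,
)
```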
10. Implementation Details
10.1 Memory Layout Matters
LayerNorm over last dimension (features):
# Contiguous memory access pattern
for b in range(B):
for d in range(D): # Sequential access
...
BatchNorm over batch dimension:
# Strided memory access pattern
for d in range(D):
for b in range(B): # Jumping through memory
...
LayerNorm is more cache-friendly for typical tensor layouts [B, ..., D].
10.2 Epsilon Placement
# Wrong (divides by zero when var = 0):
x_hat = (x - mean) / sqrt(var)
# Right (always safe; the standard convention puts eps inside the sqrt):
x_hat = (x - mean) / sqrt(var + eps)
# Also right (fused):
x_hat = (x - mean) * rsqrt(var + eps)
The rsqrt (reciprocal square root) is a single GPU instruction.
11. Summary
11.1 Fluxion View of Normalization
Forward:
μ, σ computed from X
X̂ = (X - μ)/σ # Standardize
Y = γ·X̂ + β # Scale and shift
Backward:
L̇ᵞ = sum(L̇ʸ · X̂) # Scale gradient
L̇ᵝ = sum(L̇ʸ) # Shift gradient
L̇ˣ = redistributed gradient (centered, decorrelated)
11.2 Key Insight
Normalization doesn't just scale activations—it COUPLES gradient flow across the normalized dimension.
Each input's gradient depends on ALL other inputs in the normalization group.
This coupling:
- Stabilizes gradient magnitudes
- Prevents single features from dominating
- Enables deeper networks
References
- Ioffe & Szegedy (2015). "Batch Normalization: Accelerating Deep Network Training."
- Ba, Kiros & Hinton (2016). "Layer Normalization."
- Zhang & Sennrich (2019). "Root Mean Square Layer Normalization."
- Xiong et al. (2020). "On Layer Normalization in the Transformer Architecture."
Correspondence: scott@opentransformers.online