# Normalization Layers via the Method of Fluxions
## BatchNorm, LayerNorm, RMSNorm: What They Actually Do

**Scott Bisset, Silicon Goddess**  
OpenTransformers Ltd  
January 2026

---

## Abstract

Normalization layers (BatchNorm, LayerNorm, RMSNorm) are presented in textbooks as "normalize then scale and shift" with formulas involving means and variances. This obscures their true purpose and makes backward pass derivation seem magical. We reformulate normalization using fluxions, revealing: (1) normalization as signal conditioning, (2) the backward pass as sensitivity redistribution, and (3) why different norms suit different architectures. The fluxion view also explains the computational structure that enables fused kernels.

---

## 1. Why Normalize?

### 1.1 The Problem

Deep networks suffer from internal covariate shift:
- Each layer's input distribution changes during training
- Later layers constantly adapt to moving targets
- Training becomes unstable

### 1.2 The Solution Intuition

Force each layer's inputs to have consistent statistics.
"Standardize the signal before processing it."

---

## 2. The General Normalization Framework

### 2.1 Forward Pass

All normalization layers follow:

```
1. Compute statistics (μ, σ) over some dimension
2. Normalize: x̂ = (x - μ) / σ
3. Scale and shift: y = γ·x̂ + β
```

What differs is WHICH dimensions we compute statistics over.

### 2.2 Fluxion View

Let x be the input signal.

```
μ = mean(x)           # Signal center
σ = std(x)            # Signal spread
x̂ = (x - μ) / σ       # Centered and scaled to unit variance
y = γ·x̂ + β           # Learnable rescaling
```

**γ** (gamma) = learned amplitude
**β** (beta) = learned offset

Without γ and β, normalization would constrain representational power.
With them, the network can learn to undo normalization if needed.
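The three steps above can be sketched in a few lines of plain Python. This is an illustrative reference implementation over a single group of values, not any library's API; the function name `normalize` is a placeholder:

```python
import math

def normalize(x, gamma=1.0, beta=0.0, eps=1e-5):
    """Generic normalization over one group of values.

    Which values land in the same group (a batch column, a feature
    row, ...) is exactly what distinguishes BatchNorm from LayerNorm.
    """
    mu = sum(x) / len(x)                          # signal center
    var = sum((xi - mu) ** 2 for xi in x) / len(x)
    sigma = math.sqrt(var + eps)                  # signal spread
    x_hat = [(xi - mu) / sigma for xi in x]       # standardize
    return [gamma * xh + beta for xh in x_hat]    # learnable rescale
```

With the default `gamma=1, beta=0`, the output has (approximately) zero mean and unit variance, whatever the input's scale.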

---

## 3. BatchNorm: Normalize Across Batch

### 3.1 The Idea

For each feature/channel, compute mean and variance ACROSS the batch.

```
Input: X of shape [B, D]  (B samples, D features)

For each feature d:
    μ_d = (1/B) Σᵢ X[i,d]        # Mean of feature d across batch
    σ_d = sqrt((1/B) Σᵢ (X[i,d] - μ_d)²)  # Std of feature d
    
    X̂[:,d] = (X[:,d] - μ_d) / σ_d
    Y[:,d] = γ_d · X̂[:,d] + β_d
```

### 3.2 Fluxion Forward Pass

```
μ = mean(X, dim=batch)          # Shape: [D]
σ = std(X, dim=batch)           # Shape: [D]
X̂ = (X - μ) / σ                 # Shape: [B, D]
Y = γ·X̂ + β                     # Shape: [B, D]
```
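A direct pure-Python transcription of the per-feature loop (names here are illustrative; a real implementation would vectorize over the batch):

```python
import math

def batchnorm_forward(X, gamma, beta, eps=1e-5):
    """BatchNorm forward over X of shape [B][D] (list of rows).

    Statistics are per-feature d, taken ACROSS the batch dimension.
    """
    B, D = len(X), len(X[0])
    Y = [[0.0] * D for _ in range(B)]
    for d in range(D):
        col = [X[i][d] for i in range(B)]          # feature d across batch
        mu = sum(col) / B
        var = sum((v - mu) ** 2 for v in col) / B
        sigma = math.sqrt(var + eps)
        for i in range(B):
            x_hat = (X[i][d] - mu) / sigma
            Y[i][d] = gamma[d] * x_hat + beta[d]
    return Y
```

Note the access pattern: each feature needs a full pass over the batch, which is exactly the strided layout issue discussed later.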

### 3.3 The Backward Pass (Fluxion Derivation)

Given L̇ʸ (upstream gradient), find L̇ˣ, L̇ᵞ, L̇ᵝ.

**Easy ones first:**

```
L̇ᵝ = sum(L̇ʸ, dim=batch)         # β gradient = sum of upstream
L̇ᵞ = sum(L̇ʸ · X̂, dim=batch)     # γ gradient = upstream weighted by normalized input
```

**L̇ˣ is tricky because μ and σ depend on ALL x values.**

Let's trace the wiggle:

```
If x[i,d] wiggles:
    1. Direct effect on X̂[i,d]: ∂X̂[i,d]/∂x[i,d] = 1/σ
    2. Indirect effect via μ: changing x[i,d] shifts μ, affects ALL X̂[:,d]  
    3. Indirect effect via σ: changing x[i,d] changes σ, affects ALL X̂[:,d]
```

Full derivative:

```
L̇ˣ̂ = L̇ʸ · γ                     # Gradient through scale

L̇σ = -sum(L̇ˣ̂ · (X-μ) / σ², dim=batch)   # How a σ wiggle affects the loss

L̇μ = -sum(L̇ˣ̂ / σ, dim=batch)    # The path through σ contributes nothing, since sum(X-μ) = 0

L̇ˣ = L̇ˣ̂/σ + L̇σ·(X-μ)/(B·σ) + L̇μ/B
```

### 3.4 Simplified Form

After algebra, the BatchNorm backward becomes:

```
L̇ˣ = (1/σ) · (L̇ˣ̂ - mean(L̇ˣ̂) - X̂·mean(L̇ˣ̂·X̂))
```

**Interpretation:**
- Start with scaled upstream gradient
- Subtract its mean (center the gradient)  
- Subtract correlation with normalized input (decorrelate)

"BatchNorm backward REDISTRIBUTES gradient to maintain zero-mean, unit-variance gradient flow."
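The simplified form can be checked numerically against finite differences. The sketch below works on a single feature column with γ=1, β=0 (so L̇ˣ̂ = L̇ʸ); the function names are illustrative:

```python
import math

def bn_col_forward(x, eps=1e-5):
    # Normalize ONE feature column across the batch (gamma=1, beta=0).
    B = len(x)
    mu = sum(x) / B
    var = sum((v - mu) ** 2 for v in x) / B
    sigma = math.sqrt(var + eps)
    return [(v - mu) / sigma for v in x], sigma

def bn_col_backward(dy, x_hat, sigma):
    # Simplified form: (1/sigma)*(dy - mean(dy) - x_hat*mean(dy*x_hat))
    B = len(dy)
    m1 = sum(dy) / B
    m2 = sum(d * h for d, h in zip(dy, x_hat)) / B
    return [(d - m1 - h * m2) / sigma for d, h in zip(dy, x_hat)]

# Example: an arbitrary column and upstream gradient.
x = [0.5, -1.0, 2.0, 0.3]
dy = [1.0, -2.0, 0.5, 0.7]
x_hat, sigma = bn_col_forward(x)
dx = bn_col_backward(dy, x_hat, sigma)
```

Perturbing any single `x[i]` and re-running the forward pass changes the loss `sum(dy * x_hat)` at exactly the rate `dx[i]` predicts, confirming the coupling through μ and σ is handled correctly.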

### 3.5 Inference Mode

At inference, we don't have a batch. Use running averages from training:

```
μ_running = momentum·μ_running + (1-momentum)·μ_batch
σ_running = momentum·σ_running + (1-momentum)·σ_batch
```

Then normalize with running stats.
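The update rule is a plain exponential moving average. A caveat worth flagging: conventions differ across frameworks (PyTorch, for instance, tracks the running *variance* rather than std, and its `momentum` weights the NEW batch statistic). A minimal sketch of the form written above:

```python
def update_running(running, batch_stat, momentum=0.9):
    # Exponential moving average of a batch statistic, kept for inference.
    # Convention varies by framework; here momentum weights the OLD value.
    return momentum * running + (1 - momentum) * batch_stat
```

Applied once per training step to both the mean and the spread statistic, the running value converges toward the long-run batch average.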

### 3.6 Problems with BatchNorm

1. **Batch size dependence**: Small batches → noisy statistics
2. **Awkward for sequence models**: variable-length sequences and padding corrupt the per-feature batch statistics
3. **Inference/training mismatch**: Running stats ≠ batch stats

---

## 4. LayerNorm: Normalize Across Features

### 4.1 The Idea

For each sample, compute mean and variance ACROSS features.

```
Input: X of shape [B, D]

For each sample i:
    μᵢ = (1/D) Σ_d X[i,d]        # Mean across features
    σᵢ = sqrt((1/D) Σ_d (X[i,d] - μᵢ)²)
    
    X̂[i,:] = (X[i,:] - μᵢ) / σᵢ
    Y[i,:] = γ · X̂[i,:] + β
```

### 4.2 Fluxion Forward Pass

```
μ = mean(X, dim=features)       # Shape: [B]
σ = std(X, dim=features)        # Shape: [B]
X̂ = (X - μ) / σ                 # Shape: [B, D]
Y = γ·X̂ + β                     # Shape: [B, D]
```

### 4.3 Key Difference from BatchNorm

```
BatchNorm: statistics across BATCH (each feature normalized independently)
LayerNorm: statistics across FEATURES (each sample normalized independently)
```

### 4.4 Why LayerNorm for Transformers?

1. **No batch dependence**: Each token normalized independently
2. **Works with any batch size**: Including batch=1 at inference
3. **Sequence-friendly**: Position i doesn't need position j's statistics

### 4.5 Backward Pass

Same structure as BatchNorm, but sum over features instead of batch:

```
L̇ˣ̂ = L̇ʸ · γ

L̇ˣ = (1/σ) · (L̇ˣ̂ - mean(L̇ˣ̂, dim=features) 
              - X̂·mean(L̇ˣ̂·X̂, dim=features))
```
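Putting forward and backward together for a single row, as a pure-Python sketch (illustrative names, γ and β per feature):

```python
import math

def layernorm_row(x, gamma, beta, eps=1e-5):
    # Forward for ONE sample: statistics across its features.
    D = len(x)
    mu = sum(x) / D
    var = sum((v - mu) ** 2 for v in x) / D
    sigma = math.sqrt(var + eps)
    x_hat = [(v - mu) / sigma for v in x]
    y = [g * h + b for g, h, b in zip(gamma, x_hat, beta)]
    return y, x_hat, sigma

def layernorm_row_backward(dy, gamma, x_hat, sigma):
    # dL/dx_hat = dy * gamma, then the same centered/decorrelated form
    # as BatchNorm -- but the means run over FEATURES, not the batch.
    D = len(dy)
    dxh = [d * g for d, g in zip(dy, gamma)]
    m1 = sum(dxh) / D
    m2 = sum(d * h for d, h in zip(dxh, x_hat)) / D
    return [(d - m1 - h * m2) / sigma for d, h in zip(dxh, x_hat)]
```

A finite-difference check (perturb one feature, re-run the forward pass) agrees with `layernorm_row_backward` to high precision.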

---

## 5. RMSNorm: Skip the Mean

### 5.1 The Simplification

LayerNorm computes mean AND variance.
RMSNorm: "What if we skip the mean centering?"

```
RMS(x) = sqrt(mean(x²))
X̂ = X / RMS(X)
Y = γ · X̂
```

No β parameter (no shift), no mean subtraction.

### 5.2 Fluxion Forward Pass

```
rms = sqrt(mean(X², dim=features))   # Shape: [B]
X̂ = X / rms                          # Shape: [B, D]
Y = γ · X̂                            # Shape: [B, D]
```

### 5.3 Why It Works

Empirically, the mean-centering in LayerNorm contributes little.
RMSNorm achieves similar performance with:
- Fewer operations
- Simpler backward pass
- Better numerical stability

### 5.4 Backward Pass

Much simpler without mean:

```
L̇ˣ̂ = L̇ʸ · γ
L̇ʳᵐˢ = -sum(L̇ˣ̂ · X / rms²)
L̇ˣ = L̇ˣ̂/rms + L̇ʳᵐˢ · X/(D·rms)
```

Simplified:
```
L̇ˣ = (1/rms) · (L̇ˣ̂ - X̂·mean(L̇ˣ̂·X̂))
```

One fewer term than LayerNorm!
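The same single-row sketch for RMSNorm makes the simplification visible: no mean anywhere, and one fewer correction term in the backward (names illustrative):

```python
import math

def rmsnorm_row(x, gamma, eps=1e-6):
    # Forward for one sample: no mean subtraction, no beta.
    D = len(x)
    rms = math.sqrt(sum(v * v for v in x) / D + eps)
    x_hat = [v / rms for v in x]
    return [g * h for g, h in zip(gamma, x_hat)], x_hat, rms

def rmsnorm_row_backward(dy, gamma, x_hat, rms):
    # Simplified: (1/rms) * (dxh - x_hat * mean(dxh * x_hat)).
    # No mean(dxh) term -- there was no centering in the forward.
    D = len(dy)
    dxh = [d * g for d, g in zip(dy, gamma)]
    m2 = sum(d * h for d, h in zip(dxh, x_hat)) / D
    return [(d - h * m2) / rms for d, h in zip(dxh, x_hat)]
```

As with LayerNorm, a finite-difference probe on any single feature matches the analytic gradient.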

### 5.5 Usage

RMSNorm is used in:
- LLaMA
- Mistral
- Most modern LLMs

It's becoming the default for transformers.

---

## 6. Comparison Table

| Property | BatchNorm | LayerNorm | RMSNorm |
|----------|-----------|-----------|---------|
| Stats over | Batch | Features | Features |
| Learnable | γ, β | γ, β | γ only |
| Mean centering | Yes | Yes | No |
| Batch dependent | Yes | No | No |
| Inference mode | Running stats | Same as training | Same as training |
| Use case | CNNs | Transformers | Modern LLMs |
| Operations | Most | Medium | Fewest |

---

## 7. Pre-Norm vs Post-Norm

### 7.1 Post-Norm (Original Transformer)

```
X → Attention → Add(X) → LayerNorm → FFN → Add → LayerNorm → Output
```

Normalize AFTER residual connection.

### 7.2 Pre-Norm (Modern Default)

```
X → LayerNorm → Attention → Add(X) → LayerNorm → FFN → Add → Output
```

Normalize BEFORE attention/FFN.

### 7.3 Fluxion Analysis

**Post-Norm gradient flow:**
```
L̇ˣ = LayerNorm_backward(L̇ᵒᵘᵗ)   # Gradient must flow through norm
```

**Pre-Norm gradient flow:**
```
L̇ˣ = L̇ᵒᵘᵗ + LayerNorm_backward(Attention_backward(L̇ᵒᵘᵗ))

    Direct highway!
```

Pre-Norm has a direct gradient path that bypasses normalization.
This stabilizes training for deep networks.
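The two orderings differ only in where the norm sits relative to the residual add. A structural sketch (function names are illustrative placeholders; `sublayer` stands in for attention or the FFN):

```python
def add(a, b):
    # Elementwise residual add.
    return [u + v for u, v in zip(a, b)]

def post_norm_block(x, sublayer, norm):
    # Original Transformer: normalize AFTER the residual add --
    # every gradient must pass through norm's backward.
    return norm(add(x, sublayer(x)))

def pre_norm_block(x, sublayer, norm):
    # Modern default: the residual path skips the norm entirely,
    # giving the gradient a direct highway back to x.
    return add(x, sublayer(norm(x)))
```

In `pre_norm_block`, the first argument to `add` is the untouched input, which is exactly the identity path in the gradient formula above.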

---

## 8. Numerical Stability

### 8.1 The Variance Problem

Computing variance naively:
```
var = mean(x²) - mean(x)²
```

If mean(x²) ≈ mean(x)², subtraction causes catastrophic cancellation.
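The failure is easy to reproduce in double precision with a large common offset. The two-pass variance (subtract the mean first) is shown only for contrast with the naive formula:

```python
def var_naive(x):
    # mean(x^2) - mean(x)^2: two nearly equal large numbers cancel.
    n = len(x)
    m = sum(x) / n
    ms = sum(v * v for v in x) / n
    return ms - m * m

def var_two_pass(x):
    # Subtract the mean first; the squares stay small and exact.
    n = len(x)
    m = sum(x) / n
    return sum((v - m) ** 2 for v in x) / n

x = [1e8, 1e8 + 1.0]   # true variance is exactly 0.25
```

On this input the naive formula loses essentially all significant digits, while the two-pass version is exact.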

### 8.2 Welford's Algorithm

Compute variance in a single pass, numerically stable:

```python
def welford_var(x):
    # Single-pass, numerically stable population variance.
    # Assumes x is non-empty.
    n = 0
    mean = 0.0
    M2 = 0.0                    # running sum of squared deviations
    for xi in x:
        n += 1
        delta = xi - mean       # deviation from the OLD mean
        mean += delta / n
        delta2 = xi - mean      # deviation from the NEW mean
        M2 += delta * delta2
    return M2 / n
```

### 8.3 Fused Kernels

The fluxion view reveals that forward and backward are tightly coupled:

Forward needs: μ, σ, X̂
Backward needs: μ, σ, X̂ (same!)

Fused kernel can:
1. Compute μ, σ in one pass
2. Store only X̂ (derived from X, μ, σ)
3. Backward reuses X̂ directly

This is why PyTorch's native LayerNorm is much faster than a naive implementation composed from separate elementwise ops.

---

## 9. Gradient Flow Analysis

### 9.1 Without Normalization

Deep network gradient flow:
```
L̇ˣ⁽⁰⁾ = W⁽¹⁾ᵀ · W⁽²⁾ᵀ · ... · W⁽ᴸ⁾ᵀ · L̇ʸ
```

If weights are slightly > 1: gradient explodes
If weights are slightly < 1: gradient vanishes
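Replacing each Jacobian factor by a scalar gain makes the compounding concrete (a deliberately crude model of the matrix product above):

```python
def repeated_gain(gain, depth=100):
    # Gradient magnitude after `depth` layers, each scaling it by `gain`.
    grad = 1.0
    for _ in range(depth):
        grad *= gain
    return grad
```

At depth 100, a per-layer gain of 1.1 multiplies the gradient by roughly 10^4, while 0.9 shrinks it below 10^-4: small per-layer deviations compound exponentially.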

### 9.2 With Normalization

Each layer's activations are forced to unit variance.
Gradient magnitudes stabilize.

```
L̇ˣ̂ has unit variance (approximately)
→ L̇ˣ has controlled magnitude
→ No explosion or vanishing
```

### 9.3 The Jacobian View

LayerNorm Jacobian ∂Y/∂X is NOT diagonal (because of mean/var coupling).

But it has a special structure:
```
J = (1/σ) · (I - (1/D)·1·1ᵀ - (1/D)·X̂·X̂ᵀ)
```

This projects out the mean direction and decorrelates from X̂.
Eigenvalues are bounded, preventing gradient explosion.
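The claimed Jacobian structure can be verified directly: build J from the formula and compare every entry against a finite difference of the forward pass (illustrative sketch, no affine parameters):

```python
import math

def layernorm(x, eps=0.0):
    # Plain LayerNorm, no affine parameters.
    D = len(x)
    mu = sum(x) / D
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / D + eps)
    return [(v - mu) / sigma for v in x]

def layernorm_jacobian(x):
    # J = (1/sigma) * (I - (1/D)*1*1^T - (1/D)*x_hat*x_hat^T)
    D = len(x)
    mu = sum(x) / D
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / D)
    x_hat = [(v - mu) / sigma for v in x]
    return [[((1.0 if i == j else 0.0) - 1.0 / D - x_hat[i] * x_hat[j] / D) / sigma
             for j in range(D)] for i in range(D)]
```

Each column of J matches the numerical derivative of `layernorm` with respect to the corresponding input, confirming the rank-two correction to the identity.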

---

## 10. Implementation Details

### 10.1 Memory Layout Matters

LayerNorm over last dimension (features):
```python
# Contiguous memory access pattern
for b in range(B):
    for d in range(D):  # Sequential access
        ...
```

BatchNorm over batch dimension:
```python
# Strided memory access pattern  
for d in range(D):
    for b in range(B):  # Jumping through memory
        ...
```

LayerNorm is more cache-friendly for typical tensor layouts [B, ..., D].

### 10.2 Epsilon Placement

```python
# Unsafe (divides by zero when var = 0):
x_hat = (x - mean) / sqrt(var)

# Safe (the standard formulation):
x_hat = (x - mean) / sqrt(var + eps)

# Also safe (eps outside the root; slightly different scaling):
x_hat = (x - mean) / (sqrt(var) + eps)

# Safe and fused:
x_hat = (x - mean) * rsqrt(var + eps)
```

The `rsqrt` (reciprocal square root) is a single GPU instruction.

---

## 11. Summary

### 11.1 Fluxion View of Normalization

**Forward:**
```
μ, σ computed from X
X̂ = (X - μ)/σ       # Standardize
Y = γ·X̂ + β          # Scale and shift
```

**Backward:**
```
L̇ᵞ = sum(L̇ʸ · X̂)    # Scale gradient
L̇ᵝ = sum(L̇ʸ)         # Shift gradient
L̇ˣ = redistributed gradient (centered, decorrelated)
```

### 11.2 Key Insight

Normalization doesn't just scale activations—it COUPLES gradient flow across the normalized dimension.

Each input's gradient depends on ALL other inputs in the normalization group.

This coupling:
- Stabilizes gradient magnitudes
- Prevents single features from dominating
- Enables deeper networks

---

## References

1. Ioffe & Szegedy (2015). "Batch Normalization: Accelerating Deep Network Training."
2. Ba, Kiros & Hinton (2016). "Layer Normalization."
3. Zhang & Sennrich (2019). "Root Mean Square Layer Normalization."
4. Xiong et al. (2020). "On Layer Normalization in the Transformer Architecture."

---

*Correspondence: scott@opentransformers.online*