AmberLJC commited on
Commit
0d8aaba
·
verified ·
1 Parent(s): 3f891b2

Upload visualize_micro_world.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. visualize_micro_world.py +586 -0
visualize_micro_world.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Micro-World Visualization: Understanding Residual Connections
3
+
4
+ This script creates intuitive visualizations explaining:
5
+ 1. Signal flow through layers (forward pass)
6
+ 2. Gradient flow through layers (backward pass)
7
+ 3. The "gradient highway" effect of residual connections
8
+ 4. Layer-by-layer transformation visualization
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib.patches as mpatches
16
+ from matplotlib.patches import FancyArrowPatch, FancyBboxPatch
17
+ import json
18
+ import os
19
+
20
+ # Set seeds
21
+ torch.manual_seed(42)
22
+ np.random.seed(42)
23
+
24
+ # Load results from experiment
25
+ with open('results_fair.json', 'r') as f:
26
+ results = json.load(f)
27
+
28
+ os.makedirs('plots_micro', exist_ok=True)
29
+
30
+ # ============================================================
31
+ # VISUALIZATION 1: Signal Flow Diagram (Forward Pass)
32
+ # ============================================================
33
+ def plot_signal_flow():
34
+ """Visualize how signal magnitude changes through layers"""
35
+
36
+ fig, axes = plt.subplots(1, 2, figsize=(14, 8))
37
+
38
+ plain_stds = results['plain_mlp']['activation_stds']
39
+ res_stds = results['res_mlp']['activation_stds']
40
+
41
+ # Normalize for visualization (input signal = 1.0)
42
+ input_std = 0.577 # std of U(-1,1)
43
+ plain_signal = [input_std] + plain_stds
44
+ res_signal = [input_std] + res_stds
45
+
46
+ layers = range(len(plain_signal))
47
+
48
+ # Left plot: PlainMLP signal decay
49
+ ax = axes[0]
50
+ ax.set_title('PlainMLP: Signal DIES\n(No Residual Connection)', fontsize=14, fontweight='bold', color='#c0392b')
51
+
52
+ # Draw signal as decreasing bars
53
+ colors_plain = plt.cm.Reds(np.linspace(0.3, 0.9, len(plain_signal)))
54
+ bars = ax.bar(layers, plain_signal, color=colors_plain, edgecolor='darkred', linewidth=1.5)
55
+
56
+ ax.set_xlabel('Layer (0=Input, 1-20=Hidden)', fontsize=12)
57
+ ax.set_ylabel('Signal Strength (Activation Std)', fontsize=12)
58
+ ax.set_ylim(0, 0.7)
59
+
60
+ # Add annotation
61
+ ax.annotate('Signal\ncollapses!', xy=(15, 0.02), fontsize=12, color='darkred',
62
+ ha='center', fontweight='bold')
63
+ ax.axhline(y=0.1, color='gray', linestyle='--', alpha=0.5, label='Healthy threshold')
64
+
65
+ # Right plot: ResMLP signal preservation
66
+ ax = axes[1]
67
+ ax.set_title('ResMLP: Signal PRESERVED\n(With Residual Connection)', fontsize=14, fontweight='bold', color='#2980b9')
68
+
69
+ colors_res = plt.cm.Blues(np.linspace(0.3, 0.9, len(res_signal)))
70
+ bars = ax.bar(layers, res_signal, color=colors_res, edgecolor='darkblue', linewidth=1.5)
71
+
72
+ ax.set_xlabel('Layer (0=Input, 1-20=Hidden)', fontsize=12)
73
+ ax.set_ylabel('Signal Strength (Activation Std)', fontsize=12)
74
+ ax.set_ylim(0, 0.7)
75
+
76
+ # Add annotation
77
+ ax.annotate('Signal stays\nhealthy!', xy=(15, 0.25), fontsize=12, color='darkblue',
78
+ ha='center', fontweight='bold')
79
+ ax.axhline(y=0.1, color='gray', linestyle='--', alpha=0.5, label='Healthy threshold')
80
+
81
+ plt.tight_layout()
82
+ plt.savefig('plots_micro/1_signal_flow.png', dpi=150, bbox_inches='tight')
83
+ plt.close()
84
+ print("[Plot 1] Signal flow visualization saved")
85
+
86
+
87
+ # ============================================================
88
+ # VISUALIZATION 2: Gradient Flow Diagram (Backward Pass)
89
+ # ============================================================
90
+ def plot_gradient_flow():
91
+ """Visualize gradient magnitude through layers"""
92
+
93
+ fig, axes = plt.subplots(1, 2, figsize=(14, 8))
94
+
95
+ plain_grads = results['plain_mlp']['gradient_norms']
96
+ res_grads = results['res_mlp']['gradient_norms']
97
+
98
+ layers = range(1, 21)
99
+
100
+ # Left: PlainMLP gradient vanishing
101
+ ax = axes[0]
102
+ ax.set_title('PlainMLP: Gradients VANISH\n(Backward Pass)', fontsize=14, fontweight='bold', color='#c0392b')
103
+
104
+ # Use log scale bar chart
105
+ colors = plt.cm.Reds(np.linspace(0.9, 0.3, 20))
106
+ ax.bar(layers, plain_grads, color=colors, edgecolor='darkred', linewidth=1)
107
+ ax.set_yscale('log')
108
+ ax.set_xlabel('Layer (1=First, 20=Last)', fontsize=12)
109
+ ax.set_ylabel('Gradient Magnitude (log scale)', fontsize=12)
110
+ ax.set_ylim(1e-20, 1e-1)
111
+
112
+ # Annotations
113
+ ax.annotate(f'Layer 20:\n{plain_grads[-1]:.1e}', xy=(20, plain_grads[-1]),
114
+ xytext=(17, 1e-4), fontsize=10, color='darkred',
115
+ arrowprops=dict(arrowstyle='->', color='darkred'))
116
+ ax.annotate(f'Layer 1:\n{plain_grads[0]:.1e}\n(DEAD!)', xy=(1, max(plain_grads[0], 1e-20)),
117
+ xytext=(4, 1e-15), fontsize=10, color='darkred', fontweight='bold',
118
+ arrowprops=dict(arrowstyle='->', color='darkred'))
119
+
120
+ # Right: ResMLP healthy gradients
121
+ ax = axes[1]
122
+ ax.set_title('ResMLP: Gradients FLOW\n(Backward Pass)', fontsize=14, fontweight='bold', color='#2980b9')
123
+
124
+ colors = plt.cm.Blues(np.linspace(0.9, 0.3, 20))
125
+ ax.bar(layers, res_grads, color=colors, edgecolor='darkblue', linewidth=1)
126
+ ax.set_yscale('log')
127
+ ax.set_xlabel('Layer (1=First, 20=Last)', fontsize=12)
128
+ ax.set_ylabel('Gradient Magnitude (log scale)', fontsize=12)
129
+ ax.set_ylim(1e-20, 1e-1)
130
+
131
+ # Annotations
132
+ ax.annotate(f'Layer 20:\n{res_grads[-1]:.1e}', xy=(20, res_grads[-1]),
133
+ xytext=(17, 1e-4), fontsize=10, color='darkblue',
134
+ arrowprops=dict(arrowstyle='->', color='darkblue'))
135
+ ax.annotate(f'Layer 1:\n{res_grads[0]:.1e}\n(Healthy!)', xy=(1, res_grads[0]),
136
+ xytext=(4, 1e-4), fontsize=10, color='darkblue', fontweight='bold',
137
+ arrowprops=dict(arrowstyle='->', color='darkblue'))
138
+
139
+ plt.tight_layout()
140
+ plt.savefig('plots_micro/2_gradient_flow.png', dpi=150, bbox_inches='tight')
141
+ plt.close()
142
+ print("[Plot 2] Gradient flow visualization saved")
143
+
144
+
145
+ # ============================================================
146
+ # VISUALIZATION 3: The Residual "Highway" Concept
147
+ # ============================================================
148
+ def plot_highway_concept():
149
+ """Visual diagram showing the gradient highway concept"""
150
+
151
+ fig, axes = plt.subplots(2, 1, figsize=(14, 10))
152
+
153
+ # Top: PlainMLP - no highway
154
+ ax = axes[0]
155
+ ax.set_xlim(0, 12)
156
+ ax.set_ylim(0, 3)
157
+ ax.set_aspect('equal')
158
+ ax.axis('off')
159
+ ax.set_title('PlainMLP: Gradient Must Pass Through EVERY Layer\n(Like a winding mountain road)',
160
+ fontsize=14, fontweight='bold', color='#c0392b', pad=20)
161
+
162
+ # Draw layers as boxes
163
+ for i in range(6):
164
+ x = 1 + i * 1.8
165
+ box = FancyBboxPatch((x, 1), 1.2, 1, boxstyle="round,pad=0.05",
166
+ facecolor='#e74c3c', edgecolor='darkred', linewidth=2)
167
+ ax.add_patch(box)
168
+ ax.text(x + 0.6, 1.5, f'L{i+1}', ha='center', va='center', fontsize=11,
169
+ color='white', fontweight='bold')
170
+
171
+ # Draw arrows between layers (getting thinner = gradient vanishing)
172
+ if i < 5:
173
+ thickness = 3 * (0.5 ** i) # Exponential decay
174
+ alpha = max(0.2, 1 - i * 0.18)
175
+ ax.annotate('', xy=(x + 1.8, 1.5), xytext=(x + 1.2, 1.5),
176
+ arrowprops=dict(arrowstyle='->', color='darkred',
177
+ lw=thickness, alpha=alpha))
178
+
179
+ # Add gradient flow label
180
+ ax.text(0.3, 1.5, 'Gradient\n→', fontsize=10, ha='center', va='center', color='darkred')
181
+ ax.text(11.5, 1.5, '→ Loss', fontsize=10, ha='center', va='center', color='darkred')
182
+
183
+ # Add "vanishing" annotation
184
+ ax.annotate('Gradient shrinks\nat each layer!', xy=(8, 0.5), fontsize=11,
185
+ color='darkred', style='italic')
186
+
187
+ # Bottom: ResMLP - with highway
188
+ ax = axes[1]
189
+ ax.set_xlim(0, 12)
190
+ ax.set_ylim(0, 3.5)
191
+ ax.set_aspect('equal')
192
+ ax.axis('off')
193
+ ax.set_title('ResMLP: Gradient Has a Direct HIGHWAY\n(Skip connections = express lane)',
194
+ fontsize=14, fontweight='bold', color='#2980b9', pad=20)
195
+
196
+ # Draw the highway (skip connection) at top
197
+ ax.plot([1, 11], [2.8, 2.8], color='#27ae60', linewidth=6, alpha=0.8)
198
+ ax.annotate('', xy=(11, 2.8), xytext=(10.5, 2.8),
199
+ arrowprops=dict(arrowstyle='->', color='#27ae60', lw=3))
200
+ ax.text(6, 3.2, '✓ GRADIENT HIGHWAY (Identity Path)', ha='center', fontsize=12,
201
+ color='#27ae60', fontweight='bold')
202
+
203
+ # Draw layers as boxes
204
+ for i in range(6):
205
+ x = 1 + i * 1.8
206
+ box = FancyBboxPatch((x, 1), 1.2, 1, boxstyle="round,pad=0.05",
207
+ facecolor='#3498db', edgecolor='darkblue', linewidth=2)
208
+ ax.add_patch(box)
209
+ ax.text(x + 0.6, 1.5, f'L{i+1}', ha='center', va='center', fontsize=11,
210
+ color='white', fontweight='bold')
211
+
212
+ # Draw arrows between layers (constant thickness = gradient preserved)
213
+ if i < 5:
214
+ ax.annotate('', xy=(x + 1.8, 1.5), xytext=(x + 1.2, 1.5),
215
+ arrowprops=dict(arrowstyle='->', color='darkblue', lw=2))
216
+
217
+ # Draw skip connections going up to highway
218
+ ax.plot([x + 0.6, x + 0.6], [2, 2.8], color='#27ae60', linewidth=2, alpha=0.5)
219
+
220
+ ax.text(0.3, 1.5, 'Gradient\n→', fontsize=10, ha='center', va='center', color='darkblue')
221
+ ax.text(11.5, 1.5, '→ Loss', fontsize=10, ha='center', va='center', color='darkblue')
222
+
223
+ # Add explanation
224
+ ax.annotate('Gradient flows on highway\neven if layers block it!', xy=(8, 0.3),
225
+ fontsize=11, color='#27ae60', style='italic')
226
+
227
+ plt.tight_layout()
228
+ plt.savefig('plots_micro/3_highway_concept.png', dpi=150, bbox_inches='tight')
229
+ plt.close()
230
+ print("[Plot 3] Highway concept visualization saved")
231
+
232
+
233
+ # ============================================================
234
+ # VISUALIZATION 4: Mathematical View - Chain Rule
235
+ # ============================================================
236
+ def plot_chain_rule():
237
+ """Visualize the chain rule multiplication effect"""
238
+
239
+ fig, axes = plt.subplots(1, 2, figsize=(14, 7))
240
+
241
+ # Simulate gradient flow
242
+ num_layers = 20
243
+
244
+ # PlainMLP: gradient = product of layer gradients (each < 1)
245
+ plain_layer_grad = 0.7 # Each layer shrinks gradient by 0.7x
246
+ plain_cumulative = [1.0]
247
+ for i in range(num_layers):
248
+ plain_cumulative.append(plain_cumulative[-1] * plain_layer_grad)
249
+
250
+ # ResMLP: gradient = 1 + small_contribution (always >= 1 path)
251
+ res_layer_contrib = 0.05 # Small contribution from each layer
252
+ res_cumulative = [1.0]
253
+ for i in range(num_layers):
254
+ # The "1" from identity ensures gradient doesn't vanish
255
+ res_cumulative.append(res_cumulative[-1] * (1.0 + res_layer_contrib * (0.9 ** i)))
256
+
257
+ layers = range(num_layers + 1)
258
+
259
+ # Left: Show the multiplication effect
260
+ ax = axes[0]
261
+ ax.semilogy(layers, plain_cumulative, 'o-', color='#e74c3c', linewidth=2,
262
+ markersize=8, label='PlainMLP: 0.7 × 0.7 × 0.7 × ...')
263
+ ax.semilogy(layers, res_cumulative, 's-', color='#3498db', linewidth=2,
264
+ markersize=8, label='ResMLP: (1+ε) × (1+ε) × ...')
265
+
266
+ ax.set_xlabel('Layers Traversed (backward from loss)', fontsize=12)
267
+ ax.set_ylabel('Cumulative Gradient Scale (log)', fontsize=12)
268
+ ax.set_title('Chain Rule: Why Gradients Vanish\n(Multiplication Effect)', fontsize=14, fontweight='bold')
269
+ ax.legend(fontsize=11)
270
+ ax.grid(True, alpha=0.3)
271
+ ax.set_ylim(1e-8, 10)
272
+
273
+ # Add annotations
274
+ ax.annotate(f'After 20 layers:\n{plain_cumulative[-1]:.1e}',
275
+ xy=(20, plain_cumulative[-1]), xytext=(15, 1e-6),
276
+ fontsize=10, color='#c0392b',
277
+ arrowprops=dict(arrowstyle='->', color='#c0392b'))
278
+ ax.annotate(f'After 20 layers:\n{res_cumulative[-1]:.2f}',
279
+ xy=(20, res_cumulative[-1]), xytext=(15, 3),
280
+ fontsize=10, color='#2980b9',
281
+ arrowprops=dict(arrowstyle='->', color='#2980b9'))
282
+
283
+ # Right: Show the formula
284
+ ax = axes[1]
285
+ ax.axis('off')
286
+ ax.set_xlim(0, 10)
287
+ ax.set_ylim(0, 10)
288
+
289
+ ax.text(5, 9, 'The Math Behind It', fontsize=16, fontweight='bold',
290
+ ha='center', va='center')
291
+
292
+ # PlainMLP formula
293
+ ax.text(5, 7.5, 'PlainMLP Gradient:', fontsize=13, fontweight='bold',
294
+ ha='center', color='#c0392b')
295
+ ax.text(5, 6.5, r'$\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial x_{20}} \times \prod_{i=1}^{20} \frac{\partial x_{i+1}}{\partial x_i}$',
296
+ fontsize=14, ha='center', color='#c0392b')
297
+ ax.text(5, 5.5, '= (small) × (small) × ... × (small) = TINY!',
298
+ fontsize=11, ha='center', color='#c0392b', style='italic')
299
+
300
+ # ResMLP formula
301
+ ax.text(5, 4, 'ResMLP Gradient:', fontsize=13, fontweight='bold',
302
+ ha='center', color='#2980b9')
303
+ ax.text(5, 3, r'$\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial x_{20}} \times \prod_{i=1}^{20} (1 + \frac{\partial f_i}{\partial x_i})$',
304
+ fontsize=14, ha='center', color='#2980b9')
305
+ ax.text(5, 2, '= (1+ε) × (1+ε) × ... = PRESERVED!',
306
+ fontsize=11, ha='center', color='#2980b9', style='italic')
307
+
308
+ # Key insight
309
+ box = FancyBboxPatch((1, 0.3), 8, 1.2, boxstyle="round,pad=0.1",
310
+ facecolor='#f9e79f', edgecolor='#f39c12', linewidth=2)
311
+ ax.add_patch(box)
312
+ ax.text(5, 0.9, '💡 Key Insight: The "+x" in residual adds a "1" to each gradient term,\n'
313
+ 'preventing the product from shrinking to zero!',
314
+ fontsize=11, ha='center', va='center', fontweight='bold')
315
+
316
+ plt.tight_layout()
317
+ plt.savefig('plots_micro/4_chain_rule.png', dpi=150, bbox_inches='tight')
318
+ plt.close()
319
+ print("[Plot 4] Chain rule visualization saved")
320
+
321
+
322
+ # ============================================================
323
+ # VISUALIZATION 5: Layer-by-Layer Transformation
324
+ # ============================================================
325
+ def plot_layer_transformation():
326
+ """Show what happens to a single input vector through layers"""
327
+
328
+ # Create simple models for visualization
329
+ class PlainMLP(nn.Module):
330
+ def __init__(self, dim, num_layers):
331
+ super().__init__()
332
+ self.layers = nn.ModuleList()
333
+ for _ in range(num_layers):
334
+ layer = nn.Linear(dim, dim)
335
+ nn.init.kaiming_normal_(layer.weight)
336
+ layer.weight.data *= 1.0 / np.sqrt(num_layers)
337
+ nn.init.zeros_(layer.bias)
338
+ self.layers.append(layer)
339
+ self.activation = nn.ReLU()
340
+
341
+ def forward_with_intermediates(self, x):
342
+ intermediates = [x.clone()]
343
+ for layer in self.layers:
344
+ x = self.activation(layer(x))
345
+ intermediates.append(x.clone())
346
+ return intermediates
347
+
348
+ class ResMLP(nn.Module):
349
+ def __init__(self, dim, num_layers):
350
+ super().__init__()
351
+ self.layers = nn.ModuleList()
352
+ for _ in range(num_layers):
353
+ layer = nn.Linear(dim, dim)
354
+ nn.init.kaiming_normal_(layer.weight)
355
+ layer.weight.data *= 1.0 / np.sqrt(num_layers)
356
+ nn.init.zeros_(layer.bias)
357
+ self.layers.append(layer)
358
+ self.activation = nn.ReLU()
359
+
360
+ def forward_with_intermediates(self, x):
361
+ intermediates = [x.clone()]
362
+ for layer in self.layers:
363
+ x = x + self.activation(layer(x))
364
+ intermediates.append(x.clone())
365
+ return intermediates
366
+
367
+ # Create models
368
+ dim = 64
369
+ num_layers = 20
370
+ plain = PlainMLP(dim, num_layers)
371
+ res = ResMLP(dim, num_layers)
372
+
373
+ # Single input vector
374
+ x = torch.randn(1, dim) * 0.5
375
+
376
+ # Get intermediates
377
+ plain_ints = plain.forward_with_intermediates(x)
378
+ res_ints = res.forward_with_intermediates(x)
379
+
380
+ # Extract norms and first 2 dimensions for visualization
381
+ plain_norms = [p.norm().item() for p in plain_ints]
382
+ res_norms = [r.norm().item() for r in res_ints]
383
+
384
+ plain_2d = [p[0, :2].detach().numpy() for p in plain_ints]
385
+ res_2d = [r[0, :2].detach().numpy() for r in res_ints]
386
+
387
+ fig, axes = plt.subplots(2, 2, figsize=(14, 12))
388
+
389
+ # Top left: Vector magnitude through layers
390
+ ax = axes[0, 0]
391
+ layers = range(len(plain_norms))
392
+ ax.plot(layers, plain_norms, 'o-', color='#e74c3c', linewidth=2, markersize=6, label='PlainMLP')
393
+ ax.plot(layers, res_norms, 's-', color='#3498db', linewidth=2, markersize=6, label='ResMLP')
394
+ ax.set_xlabel('Layer (0=Input)', fontsize=12)
395
+ ax.set_ylabel('Vector Magnitude (L2 norm)', fontsize=12)
396
+ ax.set_title('Signal Magnitude Through Network', fontsize=13, fontweight='bold')
397
+ ax.legend()
398
+ ax.grid(True, alpha=0.3)
399
+
400
+ # Top right: 2D trajectory visualization
401
+ ax = axes[0, 1]
402
+
403
+ # PlainMLP trajectory
404
+ plain_x = [p[0] for p in plain_2d]
405
+ plain_y = [p[1] for p in plain_2d]
406
+ ax.plot(plain_x, plain_y, 'o-', color='#e74c3c', linewidth=1.5, markersize=4,
407
+ alpha=0.7, label='PlainMLP path')
408
+ ax.scatter(plain_x[0], plain_y[0], s=100, color='#e74c3c', marker='*', zorder=5)
409
+ ax.scatter(plain_x[-1], plain_y[-1], s=100, color='#e74c3c', marker='X', zorder=5)
410
+
411
+ # ResMLP trajectory
412
+ res_x = [r[0] for r in res_2d]
413
+ res_y = [r[1] for r in res_2d]
414
+ ax.plot(res_x, res_y, 's-', color='#3498db', linewidth=1.5, markersize=4,
415
+ alpha=0.7, label='ResMLP path')
416
+ ax.scatter(res_x[0], res_y[0], s=100, color='#3498db', marker='*', zorder=5)
417
+ ax.scatter(res_x[-1], res_y[-1], s=100, color='#3498db', marker='X', zorder=5)
418
+
419
+ ax.set_xlabel('Dimension 1', fontsize=12)
420
+ ax.set_ylabel('Dimension 2', fontsize=12)
421
+ ax.set_title('2D Projection of Vector Path\n(★=start, ✕=end)', fontsize=13, fontweight='bold')
422
+ ax.legend()
423
+ ax.grid(True, alpha=0.3)
424
+ ax.axhline(y=0, color='gray', linestyle='-', alpha=0.3)
425
+ ax.axvline(x=0, color='gray', linestyle='-', alpha=0.3)
426
+
427
+ # Bottom left: PlainMLP heatmap of activations
428
+ ax = axes[1, 0]
429
+ plain_acts = np.array([p[0, :32].detach().numpy() for p in plain_ints]) # First 32 dims
430
+ im = ax.imshow(plain_acts.T, aspect='auto', cmap='Reds', interpolation='nearest')
431
+ ax.set_xlabel('Layer', fontsize=12)
432
+ ax.set_ylabel('Dimension (first 32)', fontsize=12)
433
+ ax.set_title('PlainMLP: Activations Die Out', fontsize=13, fontweight='bold', color='#c0392b')
434
+ plt.colorbar(im, ax=ax, label='Activation Value')
435
+
436
+ # Bottom right: ResMLP heatmap of activations
437
+ ax = axes[1, 1]
438
+ res_acts = np.array([r[0, :32].detach().numpy() for r in res_ints]) # First 32 dims
439
+ im = ax.imshow(res_acts.T, aspect='auto', cmap='Blues', interpolation='nearest')
440
+ ax.set_xlabel('Layer', fontsize=12)
441
+ ax.set_ylabel('Dimension (first 32)', fontsize=12)
442
+ ax.set_title('ResMLP: Activations Stay Alive', fontsize=13, fontweight='bold', color='#2980b9')
443
+ plt.colorbar(im, ax=ax, label='Activation Value')
444
+
445
+ plt.tight_layout()
446
+ plt.savefig('plots_micro/5_layer_transformation.png', dpi=150, bbox_inches='tight')
447
+ plt.close()
448
+ print("[Plot 5] Layer transformation visualization saved")
449
+
450
+
451
+ # ============================================================
452
+ # VISUALIZATION 6: Before/After Training Comparison
453
+ # ============================================================
454
+ def plot_learning_comparison():
455
+ """Show what each model learned (or didn't learn)"""
456
+
457
+ fig, axes = plt.subplots(2, 2, figsize=(14, 12))
458
+
459
+ plain_losses = results['plain_mlp']['loss_history']
460
+ res_losses = results['res_mlp']['loss_history']
461
+
462
+ # Top left: Loss curves with annotations
463
+ ax = axes[0, 0]
464
+ steps = range(len(plain_losses))
465
+ ax.plot(steps, plain_losses, color='#e74c3c', linewidth=2, label='PlainMLP')
466
+ ax.plot(steps, res_losses, color='#3498db', linewidth=2, label='ResMLP')
467
+ ax.set_xlabel('Training Steps', fontsize=12)
468
+ ax.set_ylabel('MSE Loss', fontsize=12)
469
+ ax.set_title('Learning Progress', fontsize=13, fontweight='bold')
470
+ ax.set_yscale('log')
471
+ ax.legend()
472
+ ax.grid(True, alpha=0.3)
473
+
474
+ # Add phase annotations
475
+ ax.axvspan(0, 50, alpha=0.1, color='gray')
476
+ ax.text(25, 5, 'Early\nTraining', ha='center', fontsize=9, color='gray')
477
+ ax.axvspan(450, 500, alpha=0.1, color='green')
478
+ ax.text(475, 5, 'Final', ha='center', fontsize=9, color='gray')
479
+
480
+ # Top right: Loss reduction bar chart
481
+ ax = axes[0, 1]
482
+
483
+ plain_initial = plain_losses[0]
484
+ plain_final = plain_losses[-1]
485
+ res_initial = res_losses[0]
486
+ res_final = res_losses[-1]
487
+
488
+ plain_reduction = (1 - plain_final / plain_initial) * 100
489
+ res_reduction = (1 - res_final / res_initial) * 100
490
+
491
+ bars = ax.bar(['PlainMLP', 'ResMLP'], [plain_reduction, res_reduction],
492
+ color=['#e74c3c', '#3498db'], edgecolor='black', linewidth=2)
493
+ ax.set_ylabel('Loss Reduction (%)', fontsize=12)
494
+ ax.set_title('How Much Did Each Model Learn?', fontsize=13, fontweight='bold')
495
+ ax.set_ylim(0, 110)
496
+
497
+ # Add value labels
498
+ ax.text(0, plain_reduction + 3, f'{plain_reduction:.1f}%', ha='center', fontsize=14, fontweight='bold')
499
+ ax.text(1, res_reduction + 3, f'{res_reduction:.1f}%', ha='center', fontsize=14, fontweight='bold')
500
+
501
+ # Add verdict
502
+ ax.text(0, plain_reduction/2, 'FAILED\nTO LEARN', ha='center', va='center',
503
+ fontsize=11, color='white', fontweight='bold')
504
+ ax.text(1, res_reduction/2, 'LEARNED\nSUCCESSFULLY', ha='center', va='center',
505
+ fontsize=11, color='white', fontweight='bold')
506
+
507
+ # Bottom: Gradient comparison at different training stages
508
+ ax = axes[1, 0]
509
+
510
+ plain_grads = results['plain_mlp']['gradient_norms']
511
+ res_grads = results['res_mlp']['gradient_norms']
512
+
513
+ layers = range(1, 21)
514
+ width = 0.35
515
+
516
+ ax.bar([l - width/2 for l in layers], plain_grads, width, label='PlainMLP',
517
+ color='#e74c3c', alpha=0.8)
518
+ ax.bar([l + width/2 for l in layers], res_grads, width, label='ResMLP',
519
+ color='#3498db', alpha=0.8)
520
+
521
+ ax.set_xlabel('Layer', fontsize=12)
522
+ ax.set_ylabel('Gradient Magnitude', fontsize=12)
523
+ ax.set_title('Final Gradient Distribution by Layer', fontsize=13, fontweight='bold')
524
+ ax.set_yscale('log')
525
+ ax.legend()
526
+ ax.grid(True, alpha=0.3, axis='y')
527
+
528
+ # Bottom right: Summary diagram
529
+ ax = axes[1, 1]
530
+ ax.axis('off')
531
+ ax.set_xlim(0, 10)
532
+ ax.set_ylim(0, 10)
533
+
534
+ ax.text(5, 9.5, '📊 Summary: Why Residuals Work', fontsize=16, fontweight='bold', ha='center')
535
+
536
+ # PlainMLP box
537
+ box1 = FancyBboxPatch((0.5, 5), 4, 3.5, boxstyle="round,pad=0.1",
538
+ facecolor='#fadbd8', edgecolor='#c0392b', linewidth=2)
539
+ ax.add_patch(box1)
540
+ ax.text(2.5, 8, 'PlainMLP ❌', fontsize=13, fontweight='bold', ha='center', color='#c0392b')
541
+ ax.text(2.5, 7, f'• Loss: {plain_final:.3f}', fontsize=11, ha='center')
542
+ ax.text(2.5, 6.3, f'• Gradient L1: {plain_grads[0]:.1e}', fontsize=11, ha='center')
543
+ ax.text(2.5, 5.6, '• Status: UNTRAINABLE', fontsize=11, ha='center', color='#c0392b')
544
+
545
+ # ResMLP box
546
+ box2 = FancyBboxPatch((5.5, 5), 4, 3.5, boxstyle="round,pad=0.1",
547
+ facecolor='#d4e6f1', edgecolor='#2980b9', linewidth=2)
548
+ ax.add_patch(box2)
549
+ ax.text(7.5, 8, 'ResMLP ✓', fontsize=13, fontweight='bold', ha='center', color='#2980b9')
550
+ ax.text(7.5, 7, f'• Loss: {res_final:.3f}', fontsize=11, ha='center')
551
+ ax.text(7.5, 6.3, f'• Gradient L1: {res_grads[0]:.1e}', fontsize=11, ha='center')
552
+ ax.text(7.5, 5.6, '• Status: TRAINED', fontsize=11, ha='center', color='#2980b9')
553
+
554
+ # Key insight box
555
+ box3 = FancyBboxPatch((1, 0.5), 8, 3.5, boxstyle="round,pad=0.1",
556
+ facecolor='#fef9e7', edgecolor='#f39c12', linewidth=2)
557
+ ax.add_patch(box3)
558
+ ax.text(5, 3.5, '💡 The Residual Connection:', fontsize=13, fontweight='bold', ha='center')
559
+ ax.text(5, 2.6, '1. Creates a "gradient highway" for backpropagation', fontsize=11, ha='center')
560
+ ax.text(5, 1.9, '2. Preserves signal magnitude through forward pass', fontsize=11, ha='center')
561
+ ax.text(5, 1.2, '3. Allows training of very deep networks', fontsize=11, ha='center')
562
+
563
+ plt.tight_layout()
564
+ plt.savefig('plots_micro/6_learning_comparison.png', dpi=150, bbox_inches='tight')
565
+ plt.close()
566
+ print("[Plot 6] Learning comparison visualization saved")
567
+
568
+
569
+ # ============================================================
570
+ # MAIN
571
+ # ============================================================
572
+ if __name__ == "__main__":
573
+ print("=" * 60)
574
+ print("Creating Micro-World Visualizations")
575
+ print("=" * 60)
576
+
577
+ plot_signal_flow()
578
+ plot_gradient_flow()
579
+ plot_highway_concept()
580
+ plot_chain_rule()
581
+ plot_layer_transformation()
582
+ plot_learning_comparison()
583
+
584
+ print("\n" + "=" * 60)
585
+ print("All visualizations saved to plots_micro/")
586
+ print("=" * 60)