|
|
""" |
|
|
Micro-World Visualization: Understanding Residual Connections

This script creates intuitive visualizations explaining:

1. Signal flow through layers (forward pass)
2. Gradient flow through layers (backward pass)
3. The "gradient highway" effect of residual connections
4. Layer-by-layer transformation visualization

It additionally draws a chain-rule illustration and a learning-progress
comparison. Metrics are read from results_fair.json (relative to the current
working directory), and every figure is written to plots_micro/.
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.patches as mpatches |
|
|
from matplotlib.patches import FancyArrowPatch, FancyBboxPatch |
|
|
import json |
|
|
import os |
|
|
|
|
|
|
|
|
# Seed the torch and NumPy RNGs so randomly initialised models are the same
# on every run.
torch.manual_seed(42)
np.random.seed(42)

# Precomputed experiment metrics (path relative to the working directory).
# Keys read by the plot functions: 'plain_mlp' / 'res_mlp', each holding
# 'activation_stds', 'gradient_norms' and 'loss_history'.
# JSON text is UTF-8 by specification (RFC 8259), so decode it explicitly as
# such instead of relying on the platform's locale encoding.
with open('results_fair.json', 'r', encoding='utf-8') as f:
    results = json.load(f)

# Every plot function saves its PNG into this directory.
os.makedirs('plots_micro', exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _draw_signal_panel(ax, signal, cmap, edgecolor, title, title_color,
                       note, note_y, note_color, y_max):
    """Draw one per-layer activation-std bar chart onto ``ax``.

    ``signal[0]`` is the input std (drawn as layer 0) and ``signal[k]`` the
    std after hidden layer k.
    """
    n_hidden = len(signal) - 1
    ax.set_title(title, fontsize=14, fontweight='bold', color=title_color)

    # Bars darken gradually with depth.
    colors = cmap(np.linspace(0.3, 0.9, len(signal)))
    ax.bar(range(len(signal)), signal, color=colors, edgecolor=edgecolor, linewidth=1.5)

    ax.set_xlabel(f'Layer (0=Input, 1-{n_hidden}=Hidden)', fontsize=12)
    ax.set_ylabel('Signal Strength (Activation Std)', fontsize=12)
    ax.set_ylim(0, y_max)

    # Caption sits three quarters of the way along the layer axis
    # (x=15 for a 20-layer network).
    ax.annotate(note, xy=(0.75 * n_hidden, note_y), fontsize=12, color=note_color,
                ha='center', fontweight='bold')
    ax.axhline(y=0.1, color='gray', linestyle='--', alpha=0.5, label='Healthy threshold')
    # Without a legend the 'Healthy threshold' label would never be shown.
    ax.legend()


def plot_signal_flow():
    """Visualize how signal magnitude changes through layers.

    Draws side-by-side bar charts of the per-layer activation std stored in
    ``results['plain_mlp']['activation_stds']`` and
    ``results['res_mlp']['activation_stds']``, each prefixed with the input
    std as layer 0. Saves ``plots_micro/1_signal_flow.png``.
    """
    _, axes = plt.subplots(1, 2, figsize=(14, 8))

    plain_stds = results['plain_mlp']['activation_stds']
    res_stds = results['res_mlp']['activation_stds']

    # NOTE(review): presumably the std of the experiment's input distribution
    # (0.577 ~= 1/sqrt(3), the std of U(-1, 1)) -- confirm against the
    # script that produced results_fair.json.
    input_std = 0.577
    plain_signal = [input_std] + list(plain_stds)
    res_signal = [input_std] + list(res_stds)

    # Shared y-range keeps the panels comparable. The 0.7 ceiling is kept
    # unless a bar would be clipped by it.
    peak = max(plain_signal + res_signal)
    y_max = 0.7 if peak <= 0.7 else 1.1 * peak

    _draw_signal_panel(
        axes[0], plain_signal, plt.cm.Reds, 'darkred',
        'PlainMLP: Signal DIES\n(No Residual Connection)', '#c0392b',
        'Signal\ncollapses!', 0.02, 'darkred', y_max,
    )
    _draw_signal_panel(
        axes[1], res_signal, plt.cm.Blues, 'darkblue',
        'ResMLP: Signal PRESERVED\n(With Residual Connection)', '#2980b9',
        'Signal stays\nhealthy!', 0.25, 'darkblue', y_max,
    )

    plt.tight_layout()
    plt.savefig('plots_micro/1_signal_flow.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("[Plot 1] Signal flow visualization saved")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _draw_gradient_panel(ax, grads, cmap, edgecolor, title, title_color,
                         first_label, first_text_y, y_top):
    """Draw one per-layer gradient-norm bar chart (log scale) onto ``ax``.

    ``grads[k]`` is drawn at x = k + 1, so ``grads[0]`` is labelled layer 1
    and ``grads[-1]`` the last layer.
    """
    floor = 1e-20  # bottom of the log axis; also stands in for zero gradients
    n = len(grads)
    ax.set_title(title, fontsize=14, fontweight='bold', color=title_color)

    # Colour runs dark -> light from the first layer to the last.
    colors = cmap(np.linspace(0.9, 0.3, n))
    ax.bar(range(1, n + 1), grads, color=colors, edgecolor=edgecolor, linewidth=1)
    ax.set_yscale('log')
    ax.set_xlabel(f'Layer (1=First, {n}=Last)', fontsize=12)
    ax.set_ylabel('Gradient Magnitude (log scale)', fontsize=12)
    ax.set_ylim(floor, y_top)

    # Annotation anchors are clamped to the axis floor: a gradient of exactly
    # zero has no position on a log axis.
    ax.annotate(f'Layer {n}:\n{grads[-1]:.1e}', xy=(n, max(grads[-1], floor)),
                xytext=(n - 3, 1e-4), fontsize=10, color=edgecolor,
                arrowprops=dict(arrowstyle='->', color=edgecolor))
    ax.annotate(f'Layer 1:\n{grads[0]:.1e}\n{first_label}', xy=(1, max(grads[0], floor)),
                xytext=(4, first_text_y), fontsize=10, color=edgecolor, fontweight='bold',
                arrowprops=dict(arrowstyle='->', color=edgecolor))


def plot_gradient_flow():
    """Visualize gradient magnitude through layers.

    Plots ``results[...]['gradient_norms']`` for both models on log axes.
    The number of layers is taken from the data. Saves
    ``plots_micro/2_gradient_flow.png``.
    """
    _, axes = plt.subplots(1, 2, figsize=(14, 8))

    plain_grads = results['plain_mlp']['gradient_norms']
    res_grads = results['res_mlp']['gradient_norms']

    # Keep the original 1e-1 ceiling unless some bar would be clipped by it.
    peak = max(list(plain_grads) + list(res_grads))
    y_top = 1e-1 if peak <= 1e-1 else 10 * peak

    _draw_gradient_panel(
        axes[0], plain_grads, plt.cm.Reds, 'darkred',
        'PlainMLP: Gradients VANISH\n(Backward Pass)', '#c0392b',
        '(DEAD!)', 1e-15, y_top,
    )
    _draw_gradient_panel(
        axes[1], res_grads, plt.cm.Blues, 'darkblue',
        'ResMLP: Gradients FLOW\n(Backward Pass)', '#2980b9',
        '(Healthy!)', 1e-4, y_top,
    )

    plt.tight_layout()
    plt.savefig('plots_micro/2_gradient_flow.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("[Plot 2] Gradient flow visualization saved")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_highway_concept():
    """Visual diagram showing the gradient highway concept.

    Both panels draw a chain of six layer boxes with the loss on the right.
    The gradient travels backward, from the loss toward L1.

    * Top (PlainMLP): the gradient must cross every box. Arrows get thinner
      and fainter the further they are from the loss, picturing the
      shrinking chain-rule product.
    * Bottom (ResMLP): the same chain plus a green identity "highway" that
      every layer taps into, carrying the gradient straight back to L1.

    Saves ``plots_micro/3_highway_concept.png``.
    """
    num_boxes = 6
    spacing = 1.8  # distance between the left edges of neighbouring boxes
    box_w = 1.2    # box width; the arrow gap between boxes is spacing - box_w

    _, axes = plt.subplots(2, 1, figsize=(14, 10))

    # --- Top panel: PlainMLP -------------------------------------------------
    ax = axes[0]
    ax.set_xlim(0, 12)
    ax.set_ylim(0, 3)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('PlainMLP: Gradient Must Pass Through EVERY Layer\n(Like a winding mountain road)',
                 fontsize=14, fontweight='bold', color='#c0392b', pad=20)

    for i in range(num_boxes):
        x = 1 + i * spacing
        ax.add_patch(FancyBboxPatch((x, 1), box_w, 1, boxstyle="round,pad=0.05",
                                    facecolor='#e74c3c', edgecolor='darkred', linewidth=2))
        ax.text(x + box_w / 2, 1.5, f'L{i+1}', ha='center', va='center', fontsize=11,
                color='white', fontweight='bold')

        if i < num_boxes - 1:
            # Backprop runs right-to-left (loss -> L1). The gap next to the
            # loss is thickest; every gap further back halves in width and fades.
            hops_from_loss = (num_boxes - 2) - i
            thickness = 3 * (0.5 ** hops_from_loss)
            alpha = max(0.2, 1 - hops_from_loss * 0.18)
            # Arrow runs from xytext (next box) to xy (this box), i.e. leftward.
            ax.annotate('', xy=(x + box_w, 1.5), xytext=(x + spacing, 1.5),
                        arrowprops=dict(arrowstyle='->', color='darkred',
                                        lw=thickness, alpha=alpha))

    ax.text(0.3, 1.5, 'Gradient\n←', fontsize=10, ha='center', va='center', color='darkred')
    ax.text(11.5, 1.5, '← Loss', fontsize=10, ha='center', va='center', color='darkred')
    # Placed under the early layers, where the arrows have become thinnest.
    ax.text(1, 0.5, 'Gradient shrinks\nat each layer!', fontsize=11,
            color='darkred', style='italic')

    # --- Bottom panel: ResMLP ------------------------------------------------
    ax = axes[1]
    ax.set_xlim(0, 12)
    ax.set_ylim(0, 3.5)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('ResMLP: Gradient Has a Direct HIGHWAY\n(Skip connections = express lane)',
                 fontsize=14, fontweight='bold', color='#2980b9', pad=20)

    # Identity path above the boxes; the arrowhead at the left shows the
    # gradient arriving at L1 directly from the loss end.
    ax.plot([1, 11], [2.8, 2.8], color='#27ae60', linewidth=6, alpha=0.8)
    ax.annotate('', xy=(1, 2.8), xytext=(1.5, 2.8),
                arrowprops=dict(arrowstyle='->', color='#27ae60', lw=3))
    ax.text(6, 3.2, '✓ GRADIENT HIGHWAY (Identity Path)', ha='center', fontsize=12,
            color='#27ae60', fontweight='bold')

    for i in range(num_boxes):
        x = 1 + i * spacing
        ax.add_patch(FancyBboxPatch((x, 1), box_w, 1, boxstyle="round,pad=0.05",
                                    facecolor='#3498db', edgecolor='darkblue', linewidth=2))
        ax.text(x + box_w / 2, 1.5, f'L{i+1}', ha='center', va='center', fontsize=11,
                color='white', fontweight='bold')

        if i < num_boxes - 1:
            # Same leftward (backward) direction as the top panel, constant width.
            ax.annotate('', xy=(x + box_w, 1.5), xytext=(x + spacing, 1.5),
                        arrowprops=dict(arrowstyle='->', color='darkblue', lw=2))

        # Vertical tap connecting each layer to the highway.
        ax.plot([x + box_w / 2, x + box_w / 2], [2, 2.8], color='#27ae60', linewidth=2, alpha=0.5)

    ax.text(0.3, 1.5, 'Gradient\n←', fontsize=10, ha='center', va='center', color='darkblue')
    ax.text(11.5, 1.5, '← Loss', fontsize=10, ha='center', va='center', color='darkblue')
    ax.text(8, 0.3, 'Gradient flows on highway\neven if layers block it!',
            fontsize=11, color='#27ae60', style='italic')

    plt.tight_layout()
    plt.savefig('plots_micro/3_highway_concept.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("[Plot 3] Highway concept visualization saved")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_chain_rule():
    """Visualize the chain rule multiplication effect.

    Left panel: cumulative gradient scale after k backward steps for two toy
    models.

    * A plain stack, where every layer multiplies the gradient by the
      constant ``plain_layer_grad``.
    * A residual stack, where layer k multiplies it by
      ``1 + res_layer_contrib * 0.9**k``.

    Right panel: the corresponding chain-rule formulas and a key-insight box.
    Saves ``plots_micro/4_chain_rule.png``.
    """
    _, axes = plt.subplots(1, 2, figsize=(14, 7))

    num_layers = 20

    # Plain toy model: a constant per-layer factor < 1, compounded.
    plain_layer_grad = 0.7
    plain_cumulative = [1.0]
    for _ in range(num_layers):
        plain_cumulative.append(plain_cumulative[-1] * plain_layer_grad)

    # Residual toy model: each factor is 1 + a small term that decays with
    # depth, so the product stays close to 1.
    res_layer_contrib = 0.05
    res_cumulative = [1.0]
    for i in range(num_layers):
        res_cumulative.append(res_cumulative[-1] * (1.0 + res_layer_contrib * (0.9 ** i)))

    layers = range(num_layers + 1)

    # --- Left panel: cumulative scale ---------------------------------------
    ax = axes[0]
    g = plain_layer_grad
    ax.semilogy(layers, plain_cumulative, 'o-', color='#e74c3c', linewidth=2,
                markersize=8, label=f'PlainMLP: {g} × {g} × {g} × ...')
    ax.semilogy(layers, res_cumulative, 's-', color='#3498db', linewidth=2,
                markersize=8, label='ResMLP: (1+ε) × (1+ε) × ...')

    ax.set_xlabel('Layers Traversed (backward from loss)', fontsize=12)
    ax.set_ylabel('Cumulative Gradient Scale (log)', fontsize=12)
    ax.set_title('Chain Rule: Why Gradients Vanish\n(Multiplication Effect)', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(1e-8, 10)

    ax.annotate(f'After {num_layers} layers:\n{plain_cumulative[-1]:.1e}',
                xy=(num_layers, plain_cumulative[-1]), xytext=(num_layers - 5, 1e-6),
                fontsize=10, color='#c0392b',
                arrowprops=dict(arrowstyle='->', color='#c0392b'))
    ax.annotate(f'After {num_layers} layers:\n{res_cumulative[-1]:.2f}',
                xy=(num_layers, res_cumulative[-1]), xytext=(num_layers - 5, 3),
                fontsize=10, color='#2980b9',
                arrowprops=dict(arrowstyle='->', color='#2980b9'))

    # --- Right panel: the math ----------------------------------------------
    ax = axes[1]
    ax.axis('off')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)

    ax.text(5, 9, 'The Math Behind It', fontsize=16, fontweight='bold',
            ha='center', va='center')

    # With x_1 the network input and x_{i+1} the output of layer i, a stack of
    # num_layers layers ends at x_{num_layers+1}. The chain rule from that
    # output back to x_1 therefore has exactly num_layers factors.
    out_idx = num_layers + 1
    plain_formula = (r'$\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial x_{%d}} '
                     r'\times \prod_{i=1}^{%d} \frac{\partial x_{i+1}}{\partial x_i}$'
                     % (out_idx, num_layers))
    res_formula = (r'$\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial x_{%d}} '
                   r'\times \prod_{i=1}^{%d} (1 + \frac{\partial f_i}{\partial x_i})$'
                   % (out_idx, num_layers))

    ax.text(5, 7.5, 'PlainMLP Gradient:', fontsize=13, fontweight='bold',
            ha='center', color='#c0392b')
    ax.text(5, 6.5, plain_formula, fontsize=14, ha='center', color='#c0392b')
    ax.text(5, 5.5, '= (small) × (small) × ... × (small) = TINY!',
            fontsize=11, ha='center', color='#c0392b', style='italic')

    ax.text(5, 4, 'ResMLP Gradient:', fontsize=13, fontweight='bold',
            ha='center', color='#2980b9')
    ax.text(5, 3, res_formula, fontsize=14, ha='center', color='#2980b9')
    ax.text(5, 2, '= (1+ε) × (1+ε) × ... = PRESERVED!',
            fontsize=11, ha='center', color='#2980b9', style='italic')

    box = FancyBboxPatch((1, 0.3), 8, 1.2, boxstyle="round,pad=0.1",
                         facecolor='#f9e79f', edgecolor='#f39c12', linewidth=2)
    ax.add_patch(box)
    ax.text(5, 0.9, '💡 Key Insight: The "+x" in residual adds a "1" to each gradient term,\n'
            'preventing the product from shrinking to zero!',
            fontsize=11, ha='center', va='center', fontweight='bold')

    plt.tight_layout()
    plt.savefig('plots_micro/4_chain_rule.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("[Plot 4] Chain rule visualization saved")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_layer_transformation():
    """Show what happens to a single input vector through layers.

    Builds two untrained 20-layer, 64-wide MLPs with the same initialisation
    recipe:

    * a plain one computing ``relu(W x + b)`` per layer;
    * a residual one computing ``x + relu(W x + b)``.

    It pushes one random input through both and records the vector after
    every layer, then draws a 2x2 figure:

    * the L2 norm per layer;
    * the path of the first two coordinates;
    * heatmaps of the first 32 coordinates for each model.

    The weights come from the global torch RNG, so results depend on the
    seed set at module level. Saves ``plots_micro/5_layer_transformation.png``.
    """

    class ToyMLP(nn.Module):
        """Stack of Linear+ReLU layers, optionally wrapped in identity skips."""

        def __init__(self, dim, num_layers, residual):
            super().__init__()
            self.residual = residual
            self.layers = nn.ModuleList()
            for _ in range(num_layers):
                layer = nn.Linear(dim, dim)
                nn.init.kaiming_normal_(layer.weight)
                # Scale by 1/sqrt(num_layers): deeper stacks start with smaller weights.
                layer.weight.data *= 1.0 / np.sqrt(num_layers)
                nn.init.zeros_(layer.bias)
                self.layers.append(layer)
            self.activation = nn.ReLU()

        @torch.no_grad()  # values are only plotted; no autograd graph needed
        def forward_with_intermediates(self, x):
            """Return ``[x, h_1, ..., h_L]``: the input followed by every layer output."""
            intermediates = [x.clone()]
            for layer in self.layers:
                h = self.activation(layer(x))
                x = x + h if self.residual else h
                intermediates.append(x.clone())
            return intermediates

    dim = 64
    num_layers = 20
    # Construction order matters: both models draw their initial weights from
    # the global torch RNG, and the input below is drawn after them.
    plain = ToyMLP(dim, num_layers, residual=False)
    res = ToyMLP(dim, num_layers, residual=True)

    x = torch.randn(1, dim) * 0.5

    # Each intermediate is (1, dim); stacking gives (num_layers + 1, dim).
    plain_mat = torch.cat(plain.forward_with_intermediates(x)).numpy()
    res_mat = torch.cat(res.forward_with_intermediates(x)).numpy()

    _, axes = plt.subplots(2, 2, figsize=(14, 12))

    # --- Top-left: L2 norm per layer ----------------------------------------
    ax = axes[0, 0]
    layers = range(plain_mat.shape[0])
    ax.plot(layers, np.linalg.norm(plain_mat, axis=1), 'o-', color='#e74c3c',
            linewidth=2, markersize=6, label='PlainMLP')
    ax.plot(layers, np.linalg.norm(res_mat, axis=1), 's-', color='#3498db',
            linewidth=2, markersize=6, label='ResMLP')
    ax.set_xlabel('Layer (0=Input)', fontsize=12)
    ax.set_ylabel('Vector Magnitude (L2 norm)', fontsize=12)
    ax.set_title('Signal Magnitude Through Network', fontsize=13, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # --- Top-right: path of the first two coordinates -----------------------
    ax = axes[0, 1]
    for mat, color, style, label in ((plain_mat, '#e74c3c', 'o-', 'PlainMLP path'),
                                     (res_mat, '#3498db', 's-', 'ResMLP path')):
        xs, ys = mat[:, 0], mat[:, 1]
        ax.plot(xs, ys, style, color=color, linewidth=1.5, markersize=4,
                alpha=0.7, label=label)
        ax.scatter(xs[0], ys[0], s=100, color=color, marker='*', zorder=5)    # start
        ax.scatter(xs[-1], ys[-1], s=100, color=color, marker='X', zorder=5)  # end
    ax.set_xlabel('Dimension 1', fontsize=12)
    ax.set_ylabel('Dimension 2', fontsize=12)
    ax.set_title('2D Projection of Vector Path\n(★=start, ✕=end)', fontsize=13, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='gray', linestyle='-', alpha=0.3)
    ax.axvline(x=0, color='gray', linestyle='-', alpha=0.3)

    # --- Bottom row: heatmaps of the first 32 coordinates -------------------
    for ax, mat, cmap, title, title_color in (
        (axes[1, 0], plain_mat, 'Reds', 'PlainMLP: Activations Die Out', '#c0392b'),
        (axes[1, 1], res_mat, 'Blues', 'ResMLP: Activations Stay Alive', '#2980b9'),
    ):
        im = ax.imshow(mat[:, :32].T, aspect='auto', cmap=cmap, interpolation='nearest')
        ax.set_xlabel('Layer', fontsize=12)
        ax.set_ylabel('Dimension (first 32)', fontsize=12)
        ax.set_title(title, fontsize=13, fontweight='bold', color=title_color)
        plt.colorbar(im, ax=ax, label='Activation Value')

    plt.tight_layout()
    plt.savefig('plots_micro/5_layer_transformation.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("[Plot 5] Layer transformation visualization saved")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _loss_reduction_pct(initial, final):
    """Return the percentage drop from ``initial`` to ``final`` loss.

    The result is negative when the loss went up. Returns 0.0 when
    ``initial`` is 0, since there is no meaningful baseline.
    """
    if initial == 0:
        return 0.0
    return (1 - final / initial) * 100


def _plot_loss_curves(ax, plain_losses, res_losses):
    """Top-left panel: loss history of both models on a log axis."""
    # Each curve gets its own x range so histories of different length plot fine.
    ax.plot(range(len(plain_losses)), plain_losses, color='#e74c3c', linewidth=2, label='PlainMLP')
    ax.plot(range(len(res_losses)), res_losses, color='#3498db', linewidth=2, label='ResMLP')
    ax.set_xlabel('Training Steps', fontsize=12)
    ax.set_ylabel('MSE Loss', fontsize=12)
    ax.set_title('Learning Progress', fontsize=13, fontweight='bold')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Shade the first and last 10% of training (0-50 and 450-500 for a
    # 500-step run). The labels use an axes-fraction y coordinate so they
    # stay inside the plot whatever the loss scale is.
    total = max(len(plain_losses), len(res_losses))
    window = max(1, total // 10)
    label_tf = ax.get_xaxis_transform()  # x in data units, y in axes fraction
    ax.axvspan(0, window, alpha=0.1, color='gray')
    ax.text(window / 2, 0.95, 'Early\nTraining', transform=label_tf,
            ha='center', va='top', fontsize=9, color='gray')
    ax.axvspan(total - window, total, alpha=0.1, color='green')
    ax.text(total - window / 2, 0.95, 'Final', transform=label_tf,
            ha='center', va='top', fontsize=9, color='gray')


def _plot_loss_reduction(ax, plain_reduction, res_reduction):
    """Top-right panel: percentage loss reduction per model, with verdict labels."""
    values = [plain_reduction, res_reduction]
    colors = ['#e74c3c', '#3498db']
    verdicts = ['FAILED\nTO LEARN', 'LEARNED\nSUCCESSFULLY']

    ax.bar(['PlainMLP', 'ResMLP'], values, color=colors, edgecolor='black', linewidth=2)
    ax.set_ylabel('Loss Reduction (%)', fontsize=12)
    ax.set_title('How Much Did Each Model Learn?', fontsize=13, fontweight='bold')
    # Keep the 0-110 range, but extend below zero when a model's loss increased.
    lowest = min(values)
    ax.set_ylim(0 if lowest >= 0 else lowest - 15, 110)

    for idx, (value, color, verdict) in enumerate(zip(values, colors, verdicts)):
        if value >= 0:
            ax.text(idx, value + 3, f'{value:.1f}%', ha='center', fontsize=14, fontweight='bold')
        else:
            ax.text(idx, value - 3, f'{value:.1f}%', ha='center', va='top',
                    fontsize=14, fontweight='bold')

        # Layout heuristic: a bar of at least 25 percentage points can hold
        # the white label. A shorter bar cannot, and white text on the white
        # background would vanish, so draw the label above it in the bar colour.
        if value >= 25:
            ax.text(idx, value / 2, verdict, ha='center', va='center',
                    fontsize=11, color='white', fontweight='bold')
        else:
            ax.text(idx, max(value, 0) + 14, verdict, ha='center', va='bottom',
                    fontsize=11, color=color, fontweight='bold')


def _plot_layer_gradients(ax, plain_grads, res_grads):
    """Bottom-left panel: side-by-side per-layer gradient norms (log scale)."""
    width = 0.35
    ax.bar(np.arange(1, len(plain_grads) + 1) - width / 2, plain_grads, width,
           label='PlainMLP', color='#e74c3c', alpha=0.8)
    ax.bar(np.arange(1, len(res_grads) + 1) + width / 2, res_grads, width,
           label='ResMLP', color='#3498db', alpha=0.8)
    ax.set_xlabel('Layer', fontsize=12)
    ax.set_ylabel('Gradient Magnitude', fontsize=12)
    ax.set_title('Final Gradient Distribution by Layer', fontsize=13, fontweight='bold')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')


def _plot_learning_summary(ax, plain_final, res_final, plain_first_grad, res_first_grad):
    """Bottom-right panel: text summary boxes for both models and the takeaway."""
    ax.axis('off')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)

    ax.text(5, 9.5, '📊 Summary: Why Residuals Work', fontsize=16, fontweight='bold', ha='center')

    box1 = FancyBboxPatch((0.5, 5), 4, 3.5, boxstyle="round,pad=0.1",
                          facecolor='#fadbd8', edgecolor='#c0392b', linewidth=2)
    ax.add_patch(box1)
    ax.text(2.5, 8, 'PlainMLP ❌', fontsize=13, fontweight='bold', ha='center', color='#c0392b')
    ax.text(2.5, 7, f'• Loss: {plain_final:.3f}', fontsize=11, ha='center')
    ax.text(2.5, 6.3, f'• Gradient L1: {plain_first_grad:.1e}', fontsize=11, ha='center')
    ax.text(2.5, 5.6, '• Status: UNTRAINABLE', fontsize=11, ha='center', color='#c0392b')

    box2 = FancyBboxPatch((5.5, 5), 4, 3.5, boxstyle="round,pad=0.1",
                          facecolor='#d4e6f1', edgecolor='#2980b9', linewidth=2)
    ax.add_patch(box2)
    ax.text(7.5, 8, 'ResMLP ✓', fontsize=13, fontweight='bold', ha='center', color='#2980b9')
    ax.text(7.5, 7, f'• Loss: {res_final:.3f}', fontsize=11, ha='center')
    ax.text(7.5, 6.3, f'• Gradient L1: {res_first_grad:.1e}', fontsize=11, ha='center')
    ax.text(7.5, 5.6, '• Status: TRAINED', fontsize=11, ha='center', color='#2980b9')

    box3 = FancyBboxPatch((1, 0.5), 8, 3.5, boxstyle="round,pad=0.1",
                          facecolor='#fef9e7', edgecolor='#f39c12', linewidth=2)
    ax.add_patch(box3)
    ax.text(5, 3.5, '💡 The Residual Connection:', fontsize=13, fontweight='bold', ha='center')
    ax.text(5, 2.6, '1. Creates a "gradient highway" for backpropagation', fontsize=11, ha='center')
    ax.text(5, 1.9, '2. Preserves signal magnitude through forward pass', fontsize=11, ha='center')
    ax.text(5, 1.2, '3. Allows training of very deep networks', fontsize=11, ha='center')


def plot_learning_comparison():
    """Show what each model learned (or didn't learn).

    Builds a 2x2 figure from ``results``:

    * loss curves;
    * percentage loss reduction;
    * per-layer gradient norms;
    * a text summary.

    Saves ``plots_micro/6_learning_comparison.png``.
    """
    _, axes = plt.subplots(2, 2, figsize=(14, 12))

    plain_losses = results['plain_mlp']['loss_history']
    res_losses = results['res_mlp']['loss_history']
    plain_grads = results['plain_mlp']['gradient_norms']
    res_grads = results['res_mlp']['gradient_norms']

    plain_final = plain_losses[-1]
    res_final = res_losses[-1]

    _plot_loss_curves(axes[0, 0], plain_losses, res_losses)
    _plot_loss_reduction(axes[0, 1],
                         _loss_reduction_pct(plain_losses[0], plain_final),
                         _loss_reduction_pct(res_losses[0], res_final))
    _plot_layer_gradients(axes[1, 0], plain_grads, res_grads)
    # grads[0] is the first layer's gradient norm (labelled "Gradient L1").
    _plot_learning_summary(axes[1, 1], plain_final, res_final, plain_grads[0], res_grads[0])

    plt.tight_layout()
    plt.savefig('plots_micro/6_learning_comparison.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("[Plot 6] Learning comparison visualization saved")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: render every figure in order, framed by banners.
    banner = "=" * 60
    print(banner)
    print("Creating Micro-World Visualizations")
    print(banner)

    plot_steps = (
        plot_signal_flow,
        plot_gradient_flow,
        plot_highway_concept,
        plot_chain_rule,
        plot_layer_transformation,
        plot_learning_comparison,
    )
    for make_plot in plot_steps:
        make_plot()

    print("\n" + banner)
    print("All visualizations saved to plots_micro/")
    print(banner)
|
|
|