# Source: i_like_purple / visual_nn_nodes.py
# Uploaded by dasdasddds ("Upload 16 files", commit 93783dd, verified)
"""
GPT-300M Visual Neural Network β€” Node & Connection Style
==========================================================
Generates a classic neural network diagram (like the user's reference)
with nodes and connection lines, accurately showing the GPT-300M architecture
with correct parameter calculations at each layer.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
# ═══════════════════════════════════════════════════════════════════════
# GPT-300M ARCHITECTURE — ACCURATE PARAMETER COUNTS
# ═══════════════════════════════════════════════════════════════════════
# Model hyper-parameters. The counts below include weight matrices and
# RMSNorm weight vectors only; no bias terms are counted.
VOCAB_SIZE = 32_000
D_MODEL = 1_024
N_HEADS = 16
D_FF = 4_096
N_LAYERS = 24

# The per-head width is D_MODEL split evenly across heads. Derive it instead
# of hard-coding it so the two values can never drift apart.
if D_MODEL % N_HEADS != 0:
    raise ValueError(
        f"D_MODEL ({D_MODEL}) must be divisible by N_HEADS ({N_HEADS})"
    )
HEAD_DIM = D_MODEL // N_HEADS  # 64

# Parameter calculations per component:
embed_params = VOCAB_SIZE * D_MODEL  # 32,768,000
# RoPE has no learned parameters (precomputed sin/cos)
rope_params = 0
# Per transformer layer:
qkv_params = 3 * D_MODEL * D_MODEL  # 3,145,728 (Q, K, V projections)
out_proj_params = D_MODEL * D_MODEL  # 1,048,576 (output projection)
attn_total = qkv_params + out_proj_params  # 4,194,304
ffn_up_params = D_MODEL * D_FF  # 4,194,304 (up projection)
ffn_down_params = D_FF * D_MODEL  # 4,194,304 (down projection)
ffn_total = ffn_up_params + ffn_down_params  # 8,388,608
rmsnorm_params = D_MODEL * 2  # 2,048 (two RMSNorm weight vectors per layer)
layer_total = attn_total + ffn_total + rmsnorm_params  # 12,584,960
all_layers_total = layer_total * N_LAYERS  # 302,039,040
final_norm_params = D_MODEL  # 1,024
# The LM head is weight-tied with the embedding, so it adds no extra params.
lm_head_params = 0
TOTAL_PARAMS = embed_params + all_layers_total + final_norm_params + lm_head_params
# = 32,768,000 + 302,039,040 + 1,024 + 0 = 334,808,064 unique parameters

# ═══════════════════════════════════════════════════════════════════════
# LAYER DEFINITIONS FOR VISUALIZATION
# ═══════════════════════════════════════════════════════════════════════
# Row format: (name, nodes_to_display, actual_size, params_in_row, color)
#
# Each parameter is attributed to exactly one row, so the params column sums
# to TOTAL_PARAMS. In a fully drawn transformer layer, the "Output" row
# carries that layer's two RMSNorm weight vectors.
LAYERS = [
    ("Input Tokens", 10, VOCAB_SIZE, 0, "#4CAF50"),  # Green
    ("Token Embedding", 10, D_MODEL, embed_params, "#2196F3"),  # Blue
    ("RoPE Positions", 10, D_MODEL, rope_params, "#00BCD4"),  # Cyan
    # First transformer layer, drawn sub-layer by sub-layer
    ("Layer 1: Attention Q,K,V", 12, D_MODEL, qkv_params, "#FF9800"),  # Orange
    ("Layer 1: Attention Out", 10, D_MODEL, out_proj_params, "#FF9800"),
    ("Layer 1: FFN Up", 14, D_FF, ffn_up_params, "#8BC34A"),  # Light green
    ("Layer 1: FFN Down", 10, D_MODEL, ffn_down_params, "#8BC34A"),
    # Layer 1's RMSNorm weights (mirrors the last layer's "Output" row)
    ("Layer 1: Output", 10, D_MODEL, rmsnorm_params, "#8BC34A"),
    # All middle layers collapsed into a single row
    (
        f"Layer 2–{N_LAYERS - 1}: ×{N_LAYERS - 2} Blocks",
        12,
        D_MODEL,
        layer_total * (N_LAYERS - 2),
        "#9C27B0",  # Purple
    ),
    # Last transformer layer, drawn at block granularity
    (f"Layer {N_LAYERS}: Attention", 12, D_MODEL, attn_total, "#FF5722"),  # Deep orange
    (f"Layer {N_LAYERS}: FFN", 14, D_FF, ffn_total, "#009688"),  # Teal
    (f"Layer {N_LAYERS}: Output", 10, D_MODEL, rmsnorm_params, "#009688"),
    ("Final RMSNorm", 10, D_MODEL, final_norm_params, "#E91E63"),  # Pink
    ("LM Head (tied)", 10, VOCAB_SIZE, lm_head_params, "#F44336"),  # Red
    ("Output Probabilities", 1, VOCAB_SIZE, 0, "#F44336"),  # Red
]
def _parameter_summary_text():
    """Build the boxed parameter breakdown shown at the bottom of the full diagram.

    Every row is padded to the same column widths, so the box-drawing borders
    line up when the text is rendered in a monospace font.

    Returns:
        A multi-line string framed with box-drawing characters.
    """
    attn_all = attn_total * N_LAYERS
    ffn_all = ffn_total * N_LAYERS
    norm_all = rmsnorm_params * N_LAYERS + final_norm_params

    def share(count):
        # Percentage of the whole model, e.g. "( 9.8%)".
        return f"({count / TOTAL_PARAMS * 100:4.1f}%)"

    rows = [
        ("Token Embedding:", f"{embed_params:,}", share(embed_params)),
        (f"Attention (×{N_LAYERS}):", f"{attn_all:,}", share(attn_all)),
        (f"Feed-Forward (×{N_LAYERS}):", f"{ffn_all:,}", share(ffn_all)),
        (f"RMSNorm (×{N_LAYERS}+1):", f"{norm_all:,}", share(norm_all)),
        ("LM Head (tied):", "0 (shared)", ""),
    ]
    total_row = ("TOTAL:", f"{TOTAL_PARAMS:,}", "(100%)")
    every_row = rows + [total_row]
    label_w = max(len(label) for label, _, _ in every_row)
    value_w = max(13, *(len(value) for _, value, _ in every_row))
    share_w = max(len(pct) for _, _, pct in every_row)

    def boxed(label, value, pct):
        return f"│ {label:<{label_w}} {value:>{value_w}} {pct:<{share_w}} │"

    total_line = boxed(*total_row)
    inner = len(total_line) - 2  # width between the two vertical borders
    return "\n".join(
        [
            "┌" + " Parameter Summary ".center(inner, "─") + "┐",
            *(boxed(*row) for row in rows),
            "├" + "─" * inner + "┤",
            total_line,
            "└" + "─" * inner + "┘",
        ]
    )


def _draw_legend(ax, top, y_scale):
    """Draw the colour key as a vertical column of swatches at the left edge.

    Args:
        ax: Axes to draw on; positions are axes-fraction coordinates.
        top: Axes-fraction y of the first entry; later entries step downward.
        y_scale: Factor applied to swatch heights so they render round on a
            non-square axes.
    """
    legend_items = [
        ("#4CAF50", "Input / Tokenization"),
        ("#2196F3", "Embeddings"),
        ("#FF9800", "Self-Attention"),
        ("#8BC34A", "Feed-Forward (GELU)"),
        ("#9C27B0", f"Collapsed Layers (×{N_LAYERS - 2})"),
        ("#E91E63", "Normalization"),
        ("#F44336", "Output / LM Head"),
    ]
    lx, swatch_r, step = 0.02, 0.004, 0.015
    for j, (color, label) in enumerate(legend_items):
        ly = top - j * step
        ax.add_patch(
            mpatches.Ellipse(
                (lx, ly),
                width=2 * swatch_r,
                height=2 * swatch_r * y_scale,
                facecolor=color,
                edgecolor="white",
                linewidth=0.3,
                transform=ax.transAxes,
                zorder=5,
            )
        )
        ax.text(
            lx + 0.012, ly, label,
            transform=ax.transAxes,
            fontsize=7, color="#C9D1D9", va="center",
            fontfamily="monospace",
        )


def draw_neural_network(save_path="neural_network.png"):
    """Render the full GPT-300M stack as a node-and-edge diagram.

    Each entry of ``LAYERS`` becomes one horizontal row of nodes. Adjacent rows
    are joined by a subsampled, capped set of connection lines. The left margin
    shows the row name; the right margin shows the row's parameter count and
    a running total. A title, a parameter-summary box and a colour legend are
    added, then the figure is saved to *save_path* as a 200-dpi PNG.

    Args:
        save_path: Destination file path for the rendered image.
    """
    fig, ax = plt.subplots(figsize=(22, 30), facecolor="#0D1117")
    try:
        ax.set_facecolor("#0D1117")
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis("off")

        # Everything is placed in axes-fraction coordinates, but the axes are
        # not square, so a circle of radius r would render as a tall ellipse.
        # Scale node heights by the axes' physical width/height ratio so the
        # nodes come out round.
        pos = ax.get_position()
        fig_w, fig_h = fig.get_size_inches()
        y_scale = (pos.width * fig_w) / (pos.height * fig_h)

        # Rows end at y=0.17. That keeps the bottom band free for the
        # parameter summary box (centre) and the legend (left), so neither
        # overlaps the last rows.
        y_positions = np.linspace(0.92, 0.17, len(LAYERS))
        x_center = 0.5
        max_spread = 0.38
        max_connections = 200  # hard cap per pair of rows, for readability

        prev_row = None  # (node x positions, y) of the row drawn last
        running_params = 0
        for (name, n_display, actual_size, params, color), y in zip(LAYERS, y_positions):
            running_params += params

            if n_display == 1:
                xs = [x_center]
            else:
                xs = np.linspace(x_center - max_spread, x_center + max_spread, n_display)

            # Connections to the previous row: subsample each row to about 12
            # nodes, then cap the total number of lines.
            if prev_row is not None:
                prev_xs, prev_y = prev_row
                step_prev = max(1, len(prev_xs) // 12)
                step_curr = max(1, len(xs) // 12)
                pairs = [(px, cx) for px in prev_xs[::step_prev] for cx in xs[::step_curr]]
                for px, cx in pairs[:max_connections]:
                    ax.plot(
                        [px, cx], [prev_y, y],
                        color=color, alpha=0.22, linewidth=0.6,
                        transform=ax.transAxes, zorder=1,
                    )
            prev_row = (xs, y)

            # Nodes: a single output node is drawn larger; wide rows smaller.
            if n_display == 1:
                node_radius = 0.016
            elif n_display <= 12:
                node_radius = 0.01
            else:
                node_radius = 0.008
            for x in xs:
                ax.add_patch(
                    mpatches.Ellipse(
                        (x, y),
                        width=2 * node_radius,
                        height=2 * node_radius * y_scale,
                        facecolor=color, edgecolor="white",
                        linewidth=0.6, alpha=0.95,
                        transform=ax.transAxes, zorder=3,
                    )
                )

            # "(+N)" marker for units not drawn (skipped for a single node).
            if n_display > 1 and actual_size > n_display:
                ax.text(
                    xs[-1] + 0.03, y,
                    f"(+{actual_size - n_display:,})",
                    transform=ax.transAxes,
                    fontsize=7, color="#8B949E",
                    ha="left", va="center",
                    fontfamily="monospace",
                )

            # Row label (left side)
            ax.text(
                0.02, y, name,
                transform=ax.transAxes,
                fontsize=9, fontweight="bold",
                color="#E6EDF3",
                ha="left", va="center",
                fontfamily="monospace",
            )

            # Parameter count of this row (right side)
            if params > 0:
                ax.text(
                    0.98, y, f"{params:,} params",
                    transform=ax.transAxes,
                    fontsize=8, color=color,
                    ha="right", va="center",
                    fontfamily="monospace", fontweight="bold",
                )

            # Running total just below it
            if running_params > 0:
                ax.text(
                    0.98, y - 0.012,
                    f"Σ {running_params / 1e6:.1f}M",
                    transform=ax.transAxes,
                    fontsize=6.5, color="#8B949E",
                    ha="right", va="center",
                    fontfamily="monospace",
                )

        # ── Title ──────────────────────────────────────────────────────
        ax.text(
            0.5, 0.97,
            "GPT-300M Neural Network",
            transform=ax.transAxes,
            fontsize=24, fontweight="bold",
            color="#E6EDF3", ha="center", va="center",
            fontfamily="monospace",
        )
        ax.text(
            0.5, 0.955,
            f"Total: {TOTAL_PARAMS:,} parameters • {N_LAYERS} transformer layers • "
            f"{N_HEADS} attention heads • d_model={D_MODEL}",
            transform=ax.transAxes,
            fontsize=9, color="#8B949E", ha="center", va="center",
            fontfamily="monospace",
        )

        # ── Parameter Summary Box ──────────────────────────────────────
        ax.text(
            0.5, 0.005,
            _parameter_summary_text(),
            transform=ax.transAxes,
            fontsize=8, color="#58A6FF",
            ha="center", va="bottom",
            fontfamily="monospace",
            bbox=dict(boxstyle="round,pad=0.8", facecolor="#161B22",
                      edgecolor="#30363D", linewidth=1),
        )

        # ── Legend ─────────────────────────────────────────────────────
        _draw_legend(ax, top=0.13, y_scale=y_scale)

        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor="#0D1117", edgecolor="none")
    finally:
        # Always release this figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"Saved: {save_path}")
# ═══════════════════════════════════════════════════════════════════════
# ALSO: A cleaner "zoomed in" single-layer view
# ═══════════════════════════════════════════════════════════════════════
def draw_single_layer_detail(save_path="layer_detail.png"):
    """Render a zoomed-in node-and-edge view of one transformer layer.

    Rows run top to bottom:
    - input;
    - the Q, K and V projections;
    - the attention heads;
    - the attention output projection;
    - residual + norm;
    - FFN up, FFN down;
    - residual + norm;
    - layer output.

    Q, K and V are stacked as consecutive rows for compactness. Each one is a
    D_MODEL x D_MODEL projection, and all three are drawn the same way. Each
    row shows its parameter count on the right; the subtitle gives the
    per-layer and all-layer totals. The figure is saved to *save_path* as a
    200-dpi PNG.

    Args:
        save_path: Destination file path for the rendered image.
    """
    fig, ax = plt.subplots(figsize=(20, 14), facecolor="#0D1117")
    try:
        ax.set_facecolor("#0D1117")
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis("off")

        # Node positions are axes fractions on a non-square axes. Scale node
        # heights by the physical width/height ratio so circles stay round.
        pos = ax.get_position()
        fig_w, fig_h = fig.get_size_inches()
        y_scale = (pos.width * fig_w) / (pos.height * fig_h)

        # Labels are derived from the model constants so they cannot drift.
        # Row format: (label, nodes_to_draw, actual_width, params, color)
        d = f"{D_MODEL:,}"
        sub_layers = [
            (f"Input\n(d={d})", 8, D_MODEL, 0, "#2196F3"),
            (f"Query\n(d={d})", 8, D_MODEL, D_MODEL**2, "#FF6B6B"),
            (f"Key\n(d={d})", 8, D_MODEL, D_MODEL**2, "#4ECDC4"),
            (f"Value\n(d={d})", 8, D_MODEL, D_MODEL**2, "#45B7D1"),
            (f"Attention Heads\n({N_HEADS}×{HEAD_DIM})", N_HEADS, D_MODEL, 0, "#FF9800"),
            (f"Attn Output\n(d={d})", 8, D_MODEL, D_MODEL**2, "#FF9800"),
            ("⊕ Residual + Norm", 8, D_MODEL, D_MODEL, "#E91E63"),
            (f"FFN Up (GELU)\n(d={D_FF:,})", 14, D_FF, D_MODEL * D_FF, "#8BC34A"),
            (f"FFN Down\n(d={d})", 8, D_MODEL, D_FF * D_MODEL, "#8BC34A"),
            ("⊕ Residual + Norm", 8, D_MODEL, D_MODEL, "#E91E63"),
            (f"Layer Output\n(d={d})", 8, D_MODEL, 0, "#2196F3"),
        ]

        y_positions = np.linspace(0.9, 0.08, len(sub_layers))
        x_center = 0.5
        max_spread = 0.32

        prev_row = None  # (node x positions, y) of the row drawn last
        for (name, n_nodes, actual, params, color), y in zip(sub_layers, y_positions):
            xs = np.linspace(x_center - max_spread, x_center + max_spread, n_nodes)

            # Connections: subsample each row to about 10 nodes.
            if prev_row is not None:
                prev_xs, prev_y = prev_row
                step_p = max(1, len(prev_xs) // 10)
                step_c = max(1, len(xs) // 10)
                for px in prev_xs[::step_p]:
                    for cx in xs[::step_c]:
                        ax.plot([px, cx], [prev_y, y],
                                color=color, alpha=0.2, linewidth=0.7,
                                transform=ax.transAxes, zorder=1)
            prev_row = (xs, y)

            # Nodes (smaller when the row is wide)
            r = 0.011 if n_nodes <= 10 else 0.009
            for x in xs:
                ax.add_patch(
                    mpatches.Ellipse(
                        (x, y), width=2 * r, height=2 * r * y_scale,
                        facecolor=color, edgecolor="white",
                        linewidth=0.6, alpha=0.95,
                        transform=ax.transAxes, zorder=3,
                    )
                )

            # Marker for units not drawn
            if actual > n_nodes:
                ax.text(xs[-1] + 0.025, y, f"(+{actual - n_nodes:,})",
                        transform=ax.transAxes, fontsize=7, color="#8B949E",
                        ha="left", va="center", fontfamily="monospace")

            # Row label
            ax.text(0.03, y, name, transform=ax.transAxes,
                    fontsize=9, fontweight="bold", color="#E6EDF3",
                    ha="left", va="center", fontfamily="monospace")

            # Row parameter count
            if params > 0:
                ax.text(0.97, y, f"{params:,}", transform=ax.transAxes,
                        fontsize=8, color=color, ha="right", va="center",
                        fontfamily="monospace", fontweight="bold")

        # Title
        ax.text(0.5, 0.96, "Single Transformer Layer — Detailed View",
                transform=ax.transAxes, fontsize=18, fontweight="bold",
                color="#E6EDF3", ha="center", fontfamily="monospace")
        ax.text(0.5, 0.935,
                f"Parameters per layer: {layer_total:,} • ×{N_LAYERS} layers = {all_layers_total:,} total",
                transform=ax.transAxes, fontsize=9, color="#8B949E",
                ha="center", fontfamily="monospace")

        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor="#0D1117", edgecolor="none")
    finally:
        # Always release this figure, even if drawing or saving fails.
        plt.close(fig)
    print(f"Saved: {save_path}")
if __name__ == "__main__":
    import os

    os.makedirs("viz", exist_ok=True)

    # Print the hand-derived parameter budget so it can be checked against
    # the numbers rendered into the figures. Label and value columns are
    # padded to common widths so the table lines up.
    rule = "=" * 50
    value_w = 13
    rows = [
        ("Token Embedding:", f"{embed_params:,}"),
        ("Per-layer Attention:", f"{attn_total:,}"),
        ("Per-layer FFN:", f"{ffn_total:,}"),
        ("Per-layer Norm:", f"{rmsnorm_params:,}"),
        ("Per-layer Total:", f"{layer_total:,}"),
        (f"All {N_LAYERS} layers:", f"{all_layers_total:,}"),
        ("Final Norm:", f"{final_norm_params:,}"),
        ("LM Head (tied):", "0 (shared)"),
    ]
    total_label = "TOTAL:"
    label_w = max(len(total_label), *(len(label) for label, _ in rows))

    print(rule)
    print("  GPT-300M Parameter Verification")
    print(rule)
    for label, value in rows:
        print(f"  {label:<{label_w}} {value:>{value_w}}")
    print("  " + "─" * (label_w + 1 + value_w))
    print(f"  {total_label:<{label_w}} {TOTAL_PARAMS:>{value_w},}")
    print(f"  ≈ {TOTAL_PARAMS / 1e6:.1f}M parameters")
    print(rule)

    print("\nGenerating full network diagram...")
    draw_neural_network("viz/neural_network_full.png")
    print("Generating single-layer detail...")
    draw_single_layer_detail("viz/neural_network_layer.png")
    print("\nDone!")