"""
GPT-300M Visual Neural Network — Node & Connection Style
==========================================================
Generates a classic neural network diagram (nodes joined by connection
lines) of the GPT-300M architecture, annotated with the parameter count
contributed at each layer.
"""

import matplotlib

# Select the non-interactive Agg backend before pyplot is imported so the
# script can render to files without a display.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

# ═══════════════════════════════════════════════════════════════════════
# GPT-300M ARCHITECTURE — ACCURATE PARAMETER COUNTS
# ═══════════════════════════════════════════════════════════════════════

# Model hyper-parameters. Every parameter count below is derived from
# these values, so changing one keeps the whole breakdown consistent.
VOCAB_SIZE = 32_000
D_MODEL = 1_024
N_HEADS = 16
D_FF = 4_096
N_LAYERS = 24

# The per-head width is derived rather than hard-coded so it can never
# disagree with D_MODEL and N_HEADS.
if D_MODEL % N_HEADS != 0:
    raise ValueError(
        f"D_MODEL ({D_MODEL}) must be divisible by N_HEADS ({N_HEADS})"
    )
HEAD_DIM = D_MODEL // N_HEADS  # 64

# Parameter calculations per component:
embed_params = VOCAB_SIZE * D_MODEL  # 32,768,000

# RoPE has no learned parameters (precomputed sin/cos)
rope_params = 0

# Per transformer layer:
qkv_params = 3 * D_MODEL * D_MODEL  # 3,145,728 (Q, K, V projections)
out_proj_params = D_MODEL * D_MODEL  # 1,048,576 (output projection)
attn_total = qkv_params + out_proj_params  # 4,194,304

ffn_up_params = D_MODEL * D_FF  # 4,194,304 (up projection)
ffn_down_params = D_FF * D_MODEL  # 4,194,304 (down projection)
ffn_total = ffn_up_params + ffn_down_params  # 8,388,608

rmsnorm_params = D_MODEL * 2  # 2,048 (2 norms per layer)

layer_total = attn_total + ffn_total + rmsnorm_params  # 12,584,960
all_layers_total = layer_total * N_LAYERS  # 302,039,040

final_norm_params = D_MODEL  # 1,024

# LM Head is weight-tied with the embedding, so it adds 0 extra params
lm_head_params = 0  # (tied)

TOTAL_PARAMS = (
    embed_params + all_layers_total + final_norm_params + lm_head_params
)
# = 32,768,000 + 302,039,040 + 1,024 = 334,808,064
# With weight tying this is the exact number of unique parameters.
# ═══════════════════════════════════════════════════════════════════════
# LAYER DEFINITIONS FOR VISUALIZATION
# ═══════════════════════════════════════════════════════════════════════

# Each row: (name, nodes_to_display, actual_size, params_in_this_row, color)
# The first and last transformer blocks are expanded; the blocks between
# them are collapsed into one row. The per-row parameter counts sum to
# TOTAL_PARAMS, so the running total drawn next to the last rows matches
# the figure's headline number.
LAYERS = [
    ("Input Tokens", 10, VOCAB_SIZE, 0, "#4CAF50"),  # Green
    ("Token Embedding", 10, D_MODEL, embed_params, "#2196F3"),  # Blue
    ("RoPE Positions", 10, D_MODEL, 0, "#00BCD4"),  # Cyan
    # First transformer block, expanded
    ("Layer 1: Attention Q,K,V", 12, D_MODEL, qkv_params, "#FF9800"),  # Orange
    ("Layer 1: Attention Out", 10, D_MODEL, out_proj_params, "#FF9800"),
    ("Layer 1: FFN Up", 14, D_FF, ffn_up_params, "#8BC34A"),  # Light green
    ("Layer 1: FFN Down", 10, D_MODEL, ffn_down_params, "#8BC34A"),
    # Carries the block's two RMSNorm weight vectors, mirroring the
    # "Layer N: Output" row of the last block.
    ("Layer 1: Output", 10, D_MODEL, rmsnorm_params, "#8BC34A"),
    # Middle blocks, collapsed
    (
        f"Layer 2–{N_LAYERS - 1}: ×{N_LAYERS - 2} Blocks",
        12,
        D_MODEL,
        layer_total * (N_LAYERS - 2),
        "#9C27B0",  # Purple
    ),
    # Last transformer block
    (f"Layer {N_LAYERS}: Attention", 12, D_MODEL, attn_total, "#FF5722"),  # Deep orange
    (f"Layer {N_LAYERS}: FFN", 14, D_FF, ffn_total, "#009688"),  # Teal
    (f"Layer {N_LAYERS}: Output", 10, D_MODEL, rmsnorm_params, "#009688"),
    ("Final RMSNorm", 10, D_MODEL, final_norm_params, "#E91E63"),  # Pink
    ("LM Head (tied)", 10, VOCAB_SIZE, lm_head_params, "#F44336"),  # Red
    ("Output Probabilities", 1, VOCAB_SIZE, 0, "#F44336"),  # Red
]


def draw_neural_network(save_path="neural_network.png"):
    """Render the full GPT-300M node-and-connection diagram to an image.

    One row of nodes is drawn per entry in ``LAYERS``, joined to the
    previous row by connection lines. Each row is annotated with its name,
    its parameter count, and the running parameter total. A title, a legend
    and a parameter-summary box complete the figure.

    Args:
        save_path: Destination path passed to ``Figure.savefig``. This
            function does not create missing parent directories.
    """
    background = "#0D1117"
    fig, ax = plt.subplots(figsize=(22, 30), facecolor=background)
    ax.set_facecolor(background)

    # Everything is positioned in axes-fraction coordinates. The axes is
    # not square, so a patch with equal x/y radii would render as an
    # ellipse; scaling the vertical extent by the axes' physical
    # width/height ratio keeps the nodes round.
    axes_box = ax.get_position()
    fig_w, fig_h = fig.get_size_inches()
    node_aspect = (axes_box.width * fig_w) / (axes_box.height * fig_h)

    def add_node(x, y, radius, **style):
        """Add a visually round node whose horizontal radius is `radius`."""
        ax.add_patch(
            mpatches.Ellipse(
                (x, y),
                2 * radius,
                2 * radius * node_aspect,
                transform=ax.transAxes,
                **style,
            )
        )

    # Rows stop at y=0.16 so the legend and the summary box below them do
    # not cover the final rows.
    y_positions = np.linspace(0.92, 0.16, len(LAYERS))
    x_center = 0.5
    max_spread = 0.38
    max_connections = 200  # cap on lines between two adjacent rows

    prev_xs, prev_y = None, None
    running_params = 0

    for (name, n_display, actual_size, params, color), y in zip(LAYERS, y_positions):
        running_params += params

        # x positions for this row's nodes
        if n_display == 1:
            xs = [x_center]
        else:
            xs = np.linspace(
                x_center - max_spread, x_center + max_spread, n_display
            )

        # Connections to the previous row, thinned for readability and
        # limited to exactly `max_connections` lines.
        if prev_xs is not None:
            step_curr = max(1, len(xs) // 12)
            step_prev = max(1, len(prev_xs) // 12)
            pairs = [
                (px, cx)
                for px in prev_xs[::step_prev]
                for cx in xs[::step_curr]
            ]
            for px, cx in pairs[:max_connections]:
                ax.plot(
                    [px, cx],
                    [prev_y, y],
                    color=color,
                    alpha=0.22,
                    linewidth=0.6,
                    transform=ax.transAxes,
                    zorder=1,
                )

        # Nodes: a lone node is drawn larger, crowded rows slightly smaller.
        if n_display == 1:
            node_radius = 0.016
        elif n_display <= 12:
            node_radius = 0.01
        else:
            node_radius = 0.008
        for x in xs:
            add_node(
                x,
                y,
                node_radius,
                facecolor=color,
                edgecolor="white",
                linewidth=0.6,
                alpha=0.95,
                zorder=3,
            )

        # "+N" indicator when the real width exceeds the drawn node count
        extra = actual_size - n_display
        if n_display > 1 and extra > 0:
            ax.text(
                xs[-1] + 0.03, y, f"(+{extra:,})",
                transform=ax.transAxes,
                fontsize=7, color="#8B949E",
                ha="left", va="center",
                fontfamily="monospace",
            )

        # Layer label (left side)
        ax.text(
            0.02, y, name,
            transform=ax.transAxes,
            fontsize=9, fontweight="bold",
            color="#E6EDF3",
            ha="left", va="center",
            fontfamily="monospace",
        )

        # Parameter count (right side)
        if params > 0:
            ax.text(
                0.98, y, f"{params:,} params",
                transform=ax.transAxes,
                fontsize=8, color=color,
                ha="right", va="center",
                fontfamily="monospace", fontweight="bold",
            )

        # Running total (far right, smaller)
        if running_params > 0:
            ax.text(
                0.98, y - 0.012, f"Σ {running_params / 1e6:.1f}M",
                transform=ax.transAxes,
                fontsize=6.5, color="#8B949E",
                ha="right", va="center",
                fontfamily="monospace",
            )

        prev_xs, prev_y = xs, y

    # ── Title ──────────────────────────────────────────────────────
    ax.text(
        0.5, 0.97, "GPT-300M Neural Network",
        transform=ax.transAxes,
        fontsize=24, fontweight="bold",
        color="#E6EDF3",
        ha="center", va="center",
        fontfamily="monospace",
    )
    ax.text(
        0.5, 0.955,
        f"Total: {TOTAL_PARAMS:,} parameters • {N_LAYERS} transformer layers • "
        f"{N_HEADS} attention heads • d_model={D_MODEL}",
        transform=ax.transAxes,
        fontsize=9, color="#8B949E",
        ha="center", va="center",
        fontfamily="monospace",
    )

    # ── Parameter Summary Box ──────────────────────────────────────
    # Rows are padded with format widths so that every line is exactly as
    # wide as the border, whatever the label or number length.
    inner_width = 49

    def summary_row(label, value, share=""):
        # 1 + 21 + 13 + 2 + 9 + 3 == inner_width
        return f"│ {label:<21}{value:>13}  {share:>9}   │"

    def share_of_total(count):
        return f"({count / TOTAL_PARAMS * 100:4.1f}%)"

    norm_total = rmsnorm_params * N_LAYERS + final_norm_params
    counted_rows = [
        ("Token Embedding:", embed_params),
        (f"Attention (×{N_LAYERS}):", attn_total * N_LAYERS),
        (f"Feed-Forward (×{N_LAYERS}):", ffn_total * N_LAYERS),
        (f"RMSNorm (×{N_LAYERS}+1):", norm_total),
    ]
    summary_lines = ["┌" + " Parameter Summary ".center(inner_width, "─") + "┐"]
    summary_lines += [
        summary_row(label, f"{count:,}", share_of_total(count))
        for label, count in counted_rows
    ]
    summary_lines += [
        summary_row("LM Head (tied):", "0 (shared)"),
        "├" + "─" * inner_width + "┤",
        summary_row("TOTAL:", f"{TOTAL_PARAMS:,}", "(100%)"),
        "└" + "─" * inner_width + "┘",
    ]
    ax.text(
        0.5, 0.005, "\n".join(summary_lines),
        transform=ax.transAxes,
        fontsize=8, color="#58A6FF",
        ha="center", va="bottom",
        fontfamily="monospace",
        bbox=dict(boxstyle="round,pad=0.8", facecolor="#161B22",
                  edgecolor="#30363D", linewidth=1),
    )

    # ── Legend ─────────────────────────────────────────────────────
    legend_items = [
        ("#4CAF50", "Input / Tokenization"),
        ("#2196F3", "Embeddings"),
        ("#FF9800", "Self-Attention"),
        ("#8BC34A", "Feed-Forward (GELU)"),
        ("#9C27B0", f"Collapsed Layers (×{N_LAYERS - 2})"),
        ("#E91E63", "Normalization"),
        ("#F44336", "Output / LM Head"),
    ]
    # The legend runs downward from y=0.11 so every entry stays inside the
    # axes (patches outside it are clipped) and clear of the last row label.
    legend_x = 0.02
    legend_top = 0.11
    for j, (swatch, label) in enumerate(legend_items):
        legend_y = legend_top - j * 0.015
        add_node(
            legend_x,
            legend_y,
            0.004,
            facecolor=swatch,
            edgecolor="white",
            linewidth=0.3,
            zorder=5,
        )
        ax.text(
            legend_x + 0.012, legend_y, label,
            transform=ax.transAxes,
            fontsize=7, color="#C9D1D9",
            va="center", fontfamily="monospace",
        )

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")

    # Save and close this specific figure rather than whichever figure
    # pyplot currently considers "current".
    fig.savefig(save_path, dpi=200, bbox_inches="tight",
                facecolor=background, edgecolor="none")
    print(f"Saved: {save_path}")
    plt.close(fig)
# ═══════════════════════════════════════════════════════════════════════
# ALSO: A cleaner "zoomed in" single-layer view
# ═══════════════════════════════════════════════════════════════════════


def draw_single_layer_detail(save_path="layer_detail.png"):
    """Draw a detailed view of one transformer layer with node connections.

    Each sub-step of the layer (Q/K/V projections, attention heads, output
    projection, norms, feed-forward up/down) is drawn as a row of nodes
    joined to the previous row, labelled with its name and, where it has
    weights, its parameter count.

    Args:
        save_path: Destination path passed to ``Figure.savefig``. This
            function does not create missing parent directories.
    """
    background = "#0D1117"
    fig, ax = plt.subplots(figsize=(20, 14), facecolor=background)
    ax.set_facecolor(background)

    # Nodes are placed in axes-fraction coordinates on a non-square axes.
    # Scaling the vertical extent by the axes' physical width/height ratio
    # keeps them round instead of stretched.
    axes_box = ax.get_position()
    fig_w, fig_h = fig.get_size_inches()
    node_aspect = (axes_box.width * fig_w) / (axes_box.height * fig_h)

    # One transformer layer breakdown:
    #   Input → Q, K, V projections → Attention Heads → Output Proj
    #   → Residual + Norm → FFN Up (GELU) → FFN Down → Residual + Norm → Output
    # Rows: (label, nodes_to_display, actual_size, params, color). Labels
    # are built from the model constants so they cannot go stale.
    d_label = f"(d={D_MODEL:,})"
    sub_layers = [
        (f"Input\n{d_label}", 8, D_MODEL, 0, "#2196F3"),
        (f"Query\n{d_label}", 8, D_MODEL, D_MODEL**2, "#FF6B6B"),
        (f"Key\n{d_label}", 8, D_MODEL, D_MODEL**2, "#4ECDC4"),
        (f"Value\n{d_label}", 8, D_MODEL, D_MODEL**2, "#45B7D1"),
        (f"Attention Heads\n({N_HEADS}×{HEAD_DIM})", N_HEADS, D_MODEL, 0, "#FF9800"),
        (f"Attn Output\n{d_label}", 8, D_MODEL, D_MODEL**2, "#FF9800"),
        ("⊕ Residual + Norm", 8, D_MODEL, D_MODEL, "#E91E63"),
        (f"FFN Up (GELU)\n(d={D_FF:,})", 14, D_FF, D_MODEL * D_FF, "#8BC34A"),
        (f"FFN Down\n{d_label}", 8, D_MODEL, D_FF * D_MODEL, "#8BC34A"),
        ("⊕ Residual + Norm", 8, D_MODEL, D_MODEL, "#E91E63"),
        (f"Layer Output\n{d_label}", 8, D_MODEL, 0, "#2196F3"),
    ]

    y_positions = np.linspace(0.9, 0.08, len(sub_layers))
    x_center = 0.5
    max_spread = 0.32

    prev_xs, prev_y = None, None
    for (name, n_nodes, actual, params, color), y in zip(sub_layers, y_positions):
        xs = np.linspace(x_center - max_spread, x_center + max_spread, n_nodes)

        # Connections to the previous row (thinned for wide rows)
        if prev_xs is not None:
            step_c = max(1, len(xs) // 10)
            step_p = max(1, len(prev_xs) // 10)
            for px in prev_xs[::step_p]:
                for cx in xs[::step_c]:
                    ax.plot([px, cx], [prev_y, y],
                            color=color, alpha=0.2, linewidth=0.7,
                            transform=ax.transAxes, zorder=1)

        # Nodes
        radius = 0.011 if n_nodes <= 10 else 0.009
        for x in xs:
            ax.add_patch(
                mpatches.Ellipse(
                    (x, y), 2 * radius, 2 * radius * node_aspect,
                    facecolor=color, edgecolor="white",
                    linewidth=0.6, alpha=0.95,
                    transform=ax.transAxes, zorder=3,
                )
            )

        # Overflow indicator when the real width exceeds the drawn nodes
        if actual > n_nodes:
            ax.text(xs[-1] + 0.025, y, f"(+{actual - n_nodes:,})",
                    transform=ax.transAxes, fontsize=7,
                    color="#8B949E", ha="left", va="center",
                    fontfamily="monospace")

        # Label
        ax.text(0.03, y, name, transform=ax.transAxes,
                fontsize=9, fontweight="bold", color="#E6EDF3",
                ha="left", va="center", fontfamily="monospace")

        # Params
        if params > 0:
            ax.text(0.97, y, f"{params:,}", transform=ax.transAxes,
                    fontsize=8, color=color, ha="right", va="center",
                    fontfamily="monospace", fontweight="bold")

        prev_xs, prev_y = xs, y

    # Title
    ax.text(0.5, 0.96, "Single Transformer Layer — Detailed View",
            transform=ax.transAxes, fontsize=18, fontweight="bold",
            color="#E6EDF3", ha="center", fontfamily="monospace")
    ax.text(0.5, 0.935,
            f"Parameters per layer: {layer_total:,} • ×{N_LAYERS} layers = {all_layers_total:,} total",
            transform=ax.transAxes, fontsize=9, color="#8B949E",
            ha="center", fontfamily="monospace")

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis("off")

    # Save and close this specific figure rather than pyplot's implicit
    # "current" figure.
    fig.savefig(save_path, dpi=200, bbox_inches="tight",
                facecolor=background, edgecolor="none")
    print(f"Saved: {save_path}")
    plt.close(fig)


if __name__ == "__main__":
    import os

    # Both diagrams are written under ./viz, so make sure it exists first.
    os.makedirs("viz", exist_ok=True)

    # Print the parameter breakdown as a two-column table. Labels are padded
    # to a fixed width so the right-aligned numbers line up.
    label_width = 22
    value_width = 13
    banner = "=" * 50
    breakdown = [
        ("Token Embedding:", embed_params),
        ("Per-layer Attention:", attn_total),
        ("Per-layer FFN:", ffn_total),
        ("Per-layer Norm:", rmsnorm_params),
        ("Per-layer Total:", layer_total),
        (f"All {N_LAYERS} layers:", all_layers_total),
        ("Final Norm:", final_norm_params),
    ]

    print(banner)
    print(" GPT-300M Parameter Verification")
    print(banner)
    for label, count in breakdown:
        print(f" {label:<{label_width}}{count:>{value_width},}")
    print(f" {'LM Head (tied):':<{label_width}}{'0 (shared)':>{value_width}}")
    print(" " + "─" * (label_width + value_width))
    print(f" {'TOTAL:':<{label_width}}{TOTAL_PARAMS:>{value_width},}")
    print(f" ≈ {TOTAL_PARAMS / 1e6:.1f}M parameters")
    print(banner)

    print("\nGenerating full network diagram...")
    draw_neural_network("viz/neural_network_full.png")

    print("Generating single-layer detail...")
    draw_single_layer_detail("viz/neural_network_layer.png")

    print("\nDone!")