"""
GPT-300M 3D Neural Network Visualization
========================================

A 3D node-and-connection neural network diagram with depth, perspective,
and accurate parameter counts.
"""

import os

import matplotlib

matplotlib.use("Agg")  # headless backend: figures are only saved to disk

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
from mpl_toolkits.mplot3d import Axes3D  # noqa: E402,F401  registers "3d" projection on older Matplotlib
from mpl_toolkits.mplot3d.art3d import Line3DCollection  # noqa: E402
import numpy as np  # noqa: E402

# ═══════════════════════════════════════════════════════════════════════
# ACCURATE GPT-300M PARAMETERS
# ═══════════════════════════════════════════════════════════════════════
VOCAB = 32_000
D = 1_024
HEADS = 16
HEAD_D = 64
D_FF = 4_096
N_LAYERS = 24

embed_p = VOCAB * D                            # 32,768,000
attn_p = 4 * D * D                             # 4,194,304 per layer (Q, K, V, out)
ffn_p = 2 * D * D_FF                           # 8,388,608 per layer (up + down)
norm_p = 2 * D                                 # 2,048 per layer (two RMSNorms)
layer_p = attn_p + ffn_p + norm_p              # 12,584,960 per layer
all_layers_p = layer_p * N_LAYERS              # 302,039,040
final_norm_p = D                               # 1,024
TOTAL = embed_p + all_layers_p + final_norm_p  # 334,808,064

# Dark-theme palette shared by every figure.
BG_COLOR = "#0a0e17"
TEXT_COLOR = "#E6EDF3"
MUTED_COLOR = "#8B949E"

# Layer definitions: (name, num_display_nodes, actual_neurons, params, color_hex).
# Every parameter is attributed to exactly one row, so the running total drawn
# beside the last row equals TOTAL. Each per-layer RMSNorm (D weights) is counted
# with the sublayer it follows, mirroring generate_3d_single_layer below.
LAYERS = [
    ("Input Tokens", 10, VOCAB, 0, "#4CAF50"),
    ("Token Embedding", 12, D, embed_p, "#2196F3"),
    ("RoPE Positions", 12, D, 0, "#00BCD4"),
    ("Layer 1: Attention QKV", 14, D, attn_p * 3 // 4, "#FF9800"),
    ("Layer 1: Attention Out", 12, D, attn_p // 4 + D, "#FF9800"),
    ("Layer 1: FFN Up (GELU)", 16, D_FF, ffn_p // 2, "#8BC34A"),
    ("Layer 1: FFN Down", 12, D, ffn_p // 2 + D, "#8BC34A"),
    (f"Layers 2–{N_LAYERS - 1} (×{N_LAYERS - 2})", 14, D,
     layer_p * (N_LAYERS - 2), "#9C27B0"),
    (f"Layer {N_LAYERS}: Attention", 14, D, attn_p, "#FF5722"),
    (f"Layer {N_LAYERS}: FFN", 16, D_FF, ffn_p, "#009688"),
    (f"Layer {N_LAYERS}: Norm + Out", 12, D, norm_p + final_norm_p, "#E91E63"),
    ("LM Head (weight-tied)", 12, VOCAB, 0, "#F44336"),
    ("Output Probabilities", 1, VOCAB, 0, "#FF1744"),
]

# Internal consistency check: the per-row attribution must cover the model exactly.
assert sum(row[3] for row in LAYERS) == TOTAL, "LAYERS params must sum to TOTAL"


def hex_to_rgb(h):
    """Convert an "#RRGGBB" (or "RRGGBB") string to an (r, g, b) tuple of floats in [0, 1]."""
    h = h.lstrip("#")
    return tuple(int(h[i:i + 2], 16) / 255.0 for i in (0, 2, 4))


def _style_dark_axes(ax, elev, azim):
    """Apply the dark theme to a 3D axes (hide panes, grid, axis lines) and set the camera."""
    ax.set_facecolor(BG_COLOR)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor(BG_COLOR)
    ax.grid(False)
    ax.set_axis_off()
    ax.view_init(elev=elev, azim=azim)


def _layout_nodes(n_nodes, y, spread_per_node, max_spread, curvature):
    """Return (xs, ys, zs) for a row of nodes at depth y, bent into a z = -curvature * x**2 arc.

    Nodes are spaced evenly over [-spread, spread] with
    spread = min(n_nodes * spread_per_node, max_spread); a single node sits at x = 0.
    """
    if n_nodes == 1:
        xs = np.array([0.0])
    else:
        spread = min(n_nodes * spread_per_node, max_spread)
        xs = np.linspace(-spread, spread, n_nodes)
    zs = -curvature * xs ** 2  # slight arc for 3D depth perception
    ys = np.full_like(xs, y)
    return xs, ys, zs


def _draw_connections(ax, prev, curr, rgba, linewidth, max_per_side=8):
    """Draw edges between two node rows, keeping every (len // max_per_side)-th node per side.

    prev and curr are (xs, ys, zs) tuples as returned by _layout_nodes; the
    subsampling keeps the diagram readable instead of drawing every pair.
    """
    prev_pts = np.column_stack(prev)
    curr_pts = np.column_stack(curr)
    prev_step = max(1, len(prev_pts) // max_per_side)
    curr_step = max(1, len(curr_pts) // max_per_side)
    segments = [
        (tuple(p), tuple(c))
        for p in prev_pts[::prev_step]
        for c in curr_pts[::curr_step]
    ]
    if segments:
        ax.add_collection3d(
            Line3DCollection(segments, colors=[rgba] * len(segments), linewidths=linewidth)
        )


def _draw_nodes(ax, xs, ys, zs, color_hex, size, glow_alpha):
    """Scatter a row of nodes, then a faint 3x-larger halo drawn beneath them as a glow."""
    ax.scatter(xs, ys, zs, c=[color_hex], s=size, alpha=0.95, edgecolors="white",
               linewidths=0.5, depthshade=True, zorder=5)
    ax.scatter(xs, ys, zs, c=[color_hex], s=size * 3, alpha=glow_alpha,
               edgecolors="none", depthshade=True, zorder=4)


def _format_params(params):
    """Format a parameter count as "X.YM params" (>= 1M) or "N,NNN params"."""
    if params >= 1_000_000:
        return f"{params / 1e6:.1f}M params"
    return f"{params:,} params"


def _param_summary():
    """Build the aligned, monospace parameter-breakdown text for the summary box."""
    attn = attn_p * N_LAYERS
    ffn = ffn_p * N_LAYERS
    norms = norm_p * N_LAYERS + final_norm_p
    rows = [
        ("Embedding:", f"{embed_p / 1e6:>7.1f}M", embed_p),
        (f"Attention ×{N_LAYERS}:", f"{attn / 1e6:>7.1f}M", attn),
        (f"FFN ×{N_LAYERS}:", f"{ffn / 1e6:>7.1f}M", ffn),
        ("Norms:", f"{norms / 1e6:>7.3f}M", norms),
    ]
    lines = ["Parameter Breakdown:"]
    lines += [f"  {label:<16}{value}  ({count / TOTAL:.1%})" for label, value, count in rows]
    lines += [
        f"  {'LM Head:':<16}tied (0 extra)",
        "  " + "─" * 31,
        f"  {'TOTAL:':<16}{TOTAL / 1e6:>7.1f}M",
    ]
    return "\n".join(lines)


def _set_limits(ax, positions, margin, label_room):
    """Fit the axes around all nodes with `margin` padding, plus `label_room` extra on +x for text."""
    xs, ys, zs = (np.concatenate(coords) for coords in zip(*positions))
    ax.set_xlim(xs.min() - margin, xs.max() + margin + label_room)
    ax.set_ylim(ys.min() - margin, ys.max() + margin)
    ax.set_zlim(zs.min() - margin, zs.max() + margin)


def _save_figure(fig, save_path):
    """Write fig to save_path at 200 dpi on the dark background, then close that figure."""
    fig.savefig(save_path, dpi=200, bbox_inches="tight", facecolor=BG_COLOR, edgecolor="none")
    print(f"Saved: {save_path}")
    plt.close(fig)


def generate_3d_network(save_path="neural_network_3d.png", elev=22, azim=-65):
    """Generate a 3D neural network with nodes, connections, and parameter labels.

    Each row of LAYERS is drawn as an arc of nodes at increasing depth (y),
    joined to the previous row by a subsampled fan of edges, and labelled with
    its name, its own parameter count, the running parameter total, and how
    many real neurons the displayed nodes stand in for.

    Args:
        save_path: Where to write the PNG.
        elev: Camera elevation in degrees.
        azim: Camera azimuth in degrees.
    """
    fig = plt.figure(figsize=(28, 28), facecolor=BG_COLOR)
    ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
    _style_dark_axes(ax, elev, azim)

    n_layers = len(LAYERS)
    y_positions = np.linspace(0, n_layers * 4.0, n_layers)  # depth of each row
    all_positions = []  # (xs, ys, zs) node coordinates, one entry per drawn row
    running_params = 0
    mono = {"fontfamily": "monospace", "zorder": 10}

    for (name, n_nodes, actual, params, color_hex), y in zip(LAYERS, y_positions):
        running_params += params
        xs, ys, zs = _layout_nodes(n_nodes, y, spread_per_node=0.5, max_spread=7.0,
                                   curvature=0.1)

        # ── Connections to the previous row ─────────────────────
        if all_positions:
            _draw_connections(ax, all_positions[-1], (xs, ys, zs),
                              (*hex_to_rgb(color_hex), 0.18), linewidth=0.7)
        all_positions.append((xs, ys, zs))

        # ── Nodes (spheres + glow) ──────────────────────────────
        if n_nodes == 1:
            node_size = 600
        else:
            node_size = 200 if n_nodes > 12 else 280
        _draw_nodes(ax, xs, ys, zs, color_hex, node_size, glow_alpha=0.08)

        # ── Labels ──────────────────────────────────────────────
        label_x = xs[-1] + 1.8 if n_nodes > 1 else 2.0
        ax.text(label_x, y, 0, name, fontsize=9.5, fontweight="bold",
                color=TEXT_COLOR, **mono)
        if params > 0:
            ax.text(label_x, y, -1.0, _format_params(params), fontsize=8,
                    color=color_hex, fontweight="bold", **mono)
        if running_params > 0:
            ax.text(label_x, y, -1.8, f"Σ {running_params / 1e6:.1f}M", fontsize=6,
                    color=MUTED_COLOR, **mono)
        # Overflow indicator: how many real neurons the displayed nodes omit.
        if actual > n_nodes > 1:
            ax.text(xs[-1] + 0.5, y, zs[-1], f"(+{actual - n_nodes:,})", fontsize=6,
                    color=MUTED_COLOR, **mono)

    # ── Title ──────────────────────────────────────────────────────
    ax.text2D(
        0.5, 0.96, "GPT-300M • 3D Neural Network Architecture",
        transform=fig.transFigure, fontsize=22, fontweight="bold",
        color=TEXT_COLOR, ha="center", fontfamily="monospace",
    )
    ax.text2D(
        0.5, 0.94,
        f"{TOTAL:,} parameters | {N_LAYERS} layers | {HEADS} heads | d_model={D} | d_ff={D_FF}",
        transform=fig.transFigure, fontsize=10, color=MUTED_COLOR,
        ha="center", fontfamily="monospace",
    )

    # ── Parameter summary ──────────────────────────────────────────
    ax.text2D(
        0.02, 0.06, _param_summary(), transform=fig.transFigure,
        fontsize=8, color="#58A6FF", fontfamily="monospace",
        verticalalignment="bottom",
        bbox=dict(boxstyle="round,pad=0.6", facecolor="#161B22",
                  edgecolor="#30363D", linewidth=1),
    )

    # ── Legend ─────────────────────────────────────────────────────
    legend_items = [
        ("#4CAF50", "Input"),
        ("#2196F3", "Embeddings"),
        ("#FF9800", "Attention"),
        ("#8BC34A", "FFN"),
        ("#9C27B0", f"×{N_LAYERS - 2} Layers"),
        ("#E91E63", "Norm"),
        ("#F44336", "Output"),
    ]
    for j, (color, label) in enumerate(legend_items):
        ax.text2D(
            0.92, 0.30 - j * 0.025, f"● {label}", transform=fig.transFigure,
            fontsize=8, color=color, fontfamily="monospace",
        )

    _set_limits(ax, all_positions, margin=4, label_room=8)
    _save_figure(fig, save_path)


def generate_3d_single_layer(save_path="layer_3d.png", elev=18, azim=-55):
    """3D view of a single transformer layer's internals.

    The sub-layer parameter counts (Q/K/V/out projections, two RMSNorms,
    FFN up/down) sum to layer_p.

    Args:
        save_path: Where to write the PNG.
        elev: Camera elevation in degrees.
        azim: Camera azimuth in degrees.
    """
    fig = plt.figure(figsize=(22, 18), facecolor=BG_COLOR)
    ax = fig.add_subplot(111, projection="3d", computed_zorder=False)
    _style_dark_axes(ax, elev, azim)

    # (name, num_display_nodes, actual_neurons, params, color_hex)
    sub_layers = [
        (f"Input (d={D})", 10, D, 0, "#2196F3"),
        (f"Query (d={D})", 10, D, D * D, "#FF6B6B"),
        (f"Key (d={D})", 10, D, D * D, "#4ECDC4"),
        (f"Value (d={D})", 10, D, D * D, "#45B7D1"),
        (f"{HEADS} Attention Heads", 16, D, 0, "#FF9800"),
        (f"Attn Output (d={D})", 10, D, D * D, "#FFA726"),
        ("⊕ Residual + RMSNorm", 10, D, D, "#E91E63"),
        (f"FFN Up → GELU (d={D_FF})", 16, D_FF, D * D_FF, "#8BC34A"),
        (f"FFN Down (d={D})", 10, D, D_FF * D, "#7CB342"),
        ("⊕ Residual + RMSNorm", 10, D, D, "#E91E63"),
        (f"Layer Output (d={D})", 10, D, 0, "#2196F3"),
    ]

    n = len(sub_layers)
    y_positions = np.linspace(0, n * 3, n)
    all_pos = []
    mono = {"fontfamily": "monospace", "zorder": 10}

    for (name, n_nodes, actual, params, chex), y in zip(sub_layers, y_positions):
        xs, ys, zs = _layout_nodes(n_nodes, y, spread_per_node=0.45, max_spread=5.5,
                                   curvature=0.12)

        # Connections to the previous sub-layer
        if all_pos:
            _draw_connections(ax, all_pos[-1], (xs, ys, zs), (*hex_to_rgb(chex), 0.15),
                              linewidth=0.6)
        all_pos.append((xs, ys, zs))

        # Nodes
        _draw_nodes(ax, xs, ys, zs, chex, 130 if n_nodes > 12 else 180, glow_alpha=0.07)

        # Labels
        lx = xs[-1] + 1.0
        ax.text(lx, y, 0, name, fontsize=9, fontweight="bold", color=TEXT_COLOR, **mono)
        if params > 0:
            ax.text(lx, y, -0.8, f"{params:,} params", fontsize=7, color=chex,
                    fontweight="bold", **mono)
        if actual > n_nodes:
            ax.text(xs[-1] + 0.4, y, zs[-1], f"(+{actual - n_nodes:,})", fontsize=6,
                    color=MUTED_COLOR, **mono)

    ax.text2D(0.5, 0.96, "Single Transformer Layer — 3D View",
              transform=fig.transFigure, fontsize=20, fontweight="bold",
              color=TEXT_COLOR, ha="center", fontfamily="monospace")
    ax.text2D(0.5, 0.935,
              f"{layer_p:,} params/layer × {N_LAYERS} layers = {all_layers_p:,} total",
              transform=fig.transFigure, fontsize=10, color=MUTED_COLOR,
              ha="center", fontfamily="monospace")

    _set_limits(ax, all_pos, margin=2, label_room=6)
    _save_figure(fig, save_path)


def generate_3d_rotating_views(base_path="viz"):
    """Render the full network from three camera angles plus the single-layer detail into base_path."""
    os.makedirs(base_path, exist_ok=True)
    network_views = [
        ("nn_3d_main.png", 12, -15),  # main dramatic angle — more front-facing
        ("nn_3d_top.png", 35, -25),   # higher, angled view
        ("nn_3d_side.png", 8, -45),   # low side angle
    ]
    for filename, elev, azim in network_views:
        generate_3d_network(f"{base_path}/{filename}", elev=elev, azim=azim)
    # Single layer detail
    generate_3d_single_layer(f"{base_path}/nn_3d_layer.png", elev=18, azim=-55)


def main():
    """Print the model's parameter summary and render every view into ./viz."""
    print("=" * 55)
    print(" GPT-300M • 3D Visualization Generator")
    print("=" * 55)
    print(f" Total parameters: {TOTAL:,}")
    print(f" Per layer: {layer_p:,}")
    print(f" Layers: {N_LAYERS}")
    print("=" * 55)
    generate_3d_rotating_views("viz")  # creates ./viz itself
    print("\nAll 3D views generated!")


if __name__ == "__main__":
    main()