| | """ |
| | GPT-300M 3D Neural Network Visualization |
| | ========================================== |
| | A 3D node-and-connection neural network diagram with depth, |
| | perspective, and accurate parameter counts. |
| | """ |
| |
|
| | import matplotlib |
| | matplotlib.use("Agg") |
| |
|
| | import matplotlib.pyplot as plt |
| | from mpl_toolkits.mplot3d import Axes3D |
| | from mpl_toolkits.mplot3d.art3d import Line3DCollection |
| | import numpy as np |
| |
|
| | |
| | |
| | |
| |
|
# Model hyper-parameters.
VOCAB = 32_000   # vocabulary size (rows of the token embedding table)
D = 1_024        # model width, d_model
HEADS = 16       # attention heads
HEAD_D = 64      # width of each head; HEADS * HEAD_D equals D for these values
D_FF = 4_096     # feed-forward hidden width
N_LAYERS = 24    # number of transformer layers

# Parameter counts: weight matrices and norm weight vectors only (no bias terms counted).
embed_p = VOCAB * D                       # token embedding table
attn_p = 4 * D ** 2                       # Q, K, V and output projections, each D x D
ffn_p = 2 * (D * D_FF)                    # up (D -> D_FF) and down (D_FF -> D) projections
norm_p = 2 * D                            # two length-D norm weight vectors per layer
layer_p = sum((attn_p, ffn_p, norm_p))    # one complete transformer layer
all_layers_p = N_LAYERS * layer_p
final_norm_p = D                          # norm after the last layer
TOTAL = embed_p + all_layers_p + final_norm_p  # LM head is weight-tied, so it adds nothing
| |
|
| | |
# Rows of the full-network diagram, drawn in order along the y axis by
# generate_3d_network. Each entry is:
#   (label, nodes drawn, actual width the row represents,
#    parameters attributed to the row, hex colour)
# Layer 1 and the last layer are expanded; the middle layers are collapsed into
# one row. Every per-layer parameter (including Layer 1's two norms) is
# attributed to exactly one row, so the params column sums to TOTAL and the
# running "Σ" shown in the figure ends at the model total.
LAYERS = [
    ("Input Tokens", 10, VOCAB, 0, "#4CAF50"),
    ("Token Embedding", 12, D, embed_p, "#2196F3"),
    ("RoPE Positions", 12, D, 0, "#00BCD4"),
    ("Layer 1: Attention QKV", 14, D, attn_p * 3 // 4, "#FF9800"),
    ("Layer 1: Attention Out", 12, D, attn_p * 1 // 4, "#FF9800"),
    ("Layer 1: FFN Up (GELU)", 16, D_FF, ffn_p // 2, "#8BC34A"),
    # Layer 1's two norms are attributed here so no parameters go uncounted.
    ("Layer 1: FFN Down + Norms", 12, D, ffn_p // 2 + norm_p, "#8BC34A"),
    (f"Layers 2–{N_LAYERS - 1} (×{N_LAYERS - 2})", 14, D,
     layer_p * (N_LAYERS - 2), "#9C27B0"),
    (f"Layer {N_LAYERS}: Attention", 14, D, attn_p, "#FF5722"),
    (f"Layer {N_LAYERS}: FFN", 16, D_FF, ffn_p, "#009688"),
    # The last layer's two norms plus the final norm.
    (f"Layer {N_LAYERS}: Norm + Out", 12, D, norm_p + final_norm_p, "#E91E63"),
    ("LM Head (weight-tied)", 12, VOCAB, 0, "#F44336"),
    ("Output Probabilities", 1, VOCAB, 0, "#FF1744"),
]
| |
|
| |
|
def hex_to_rgb(h):
    """Convert a ``"#RRGGBB"`` (or ``"RRGGBB"``) colour string to floats.

    Returns an ``(r, g, b)`` tuple where each channel is its two-digit hex
    value divided by 255. Characters after the sixth hex digit are ignored.
    """
    digits = h.lstrip("#")
    channels = []
    for offset in range(0, 6, 2):
        channels.append(int(digits[offset:offset + 2], 16) / 255.0)
    return tuple(channels)
| |
|
| |
|
def generate_3d_network(save_path="neural_network_3d.png", elev=22, azim=-65):
    """Render the whole model (the rows of ``LAYERS``) as a 3D node diagram.

    Each ``LAYERS`` row becomes a line of nodes on its own ``y`` plane, bent
    into a downward parabola along ``z``. Consecutive rows are joined by a
    sub-sampled fan of connection lines. Every row is labelled with its name,
    its own parameter count and the running parameter total so far. A title,
    a parameter-breakdown box and a colour legend are overlaid in figure
    coordinates.

    Args:
        save_path: Destination path of the PNG image.
        elev: Camera elevation in degrees, passed to ``view_init``.
        azim: Camera azimuth in degrees, passed to ``view_init``.
    """
    bg = "#0a0e17"
    fig = plt.figure(figsize=(28, 28), facecolor=bg)
    try:
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)

        # Dark, axis-free canvas: hide the panes, grid and axis lines.
        ax.set_facecolor(bg)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.fill = False
            axis.pane.set_edgecolor(bg)
        ax.grid(False)
        ax.set_axis_off()
        ax.view_init(elev=elev, azim=azim)

        n_layers = len(LAYERS)
        y_positions = np.linspace(0, n_layers * 4.0, n_layers)

        all_positions = []
        running_params = 0

        for i, (name, n_nodes, actual, params, color_hex) in enumerate(LAYERS):
            y = y_positions[i]
            running_params += params
            rgb = hex_to_rgb(color_hex)

            # A single centred node, or a row bent into the parabola
            # z = -0.1 x^2 so the layer reads as curved in 3D.
            if n_nodes == 1:
                xs = np.array([0.0])
                zs = np.array([0.0])
            else:
                spread = min(n_nodes * 0.5, 7.0)
                xs = np.linspace(-spread, spread, n_nodes)
                zs = -0.1 * (xs ** 2)
            ys = np.full_like(xs, y)
            all_positions.append((xs, ys, zs))

            # Connections to the previous row. The stride (len // 8, at least
            # 1) thins out wide rows so the figure stays legible.
            if i > 0:
                prev_xs, prev_ys, prev_zs = all_positions[i - 1]
                step_p = max(1, len(prev_xs) // 8)
                step_c = max(1, len(xs) // 8)
                lines = [
                    [(prev_xs[pi], prev_ys[pi], prev_zs[pi]),
                     (xs[ci], ys[ci], zs[ci])]
                    for pi in range(0, len(prev_xs), step_p)
                    for ci in range(0, len(xs), step_c)
                ]
                if lines:
                    ax.add_collection3d(Line3DCollection(
                        lines, colors=[(*rgb, 0.18)] * len(lines), linewidths=0.7,
                    ))

            # Wide rows get smaller nodes; the lone output node is drawn large.
            if n_nodes == 1:
                node_size = 600
            elif n_nodes > 12:
                node_size = 200
            else:
                node_size = 280

            # Solid nodes drawn above a faint halo three times their size.
            ax.scatter(xs, ys, zs, c=[color_hex], s=node_size, alpha=0.95,
                       edgecolors="white", linewidths=0.5, depthshade=True, zorder=5)
            ax.scatter(xs, ys, zs, c=[color_hex], s=node_size * 3, alpha=0.08,
                       edgecolors="none", depthshade=True, zorder=4)

            label_x = xs[-1] + 1.8 if n_nodes > 1 else 2.0
            ax.text(label_x, y, 0, name, fontsize=9.5, fontweight="bold",
                    color="#E6EDF3", fontfamily="monospace", zorder=10)

            if params > 0:
                if params >= 1_000_000:
                    ptxt = f"{params / 1e6:.1f}M params"
                else:
                    ptxt = f"{params:,} params"
                ax.text(label_x, y, -1.0, ptxt, fontsize=8, color=color_hex,
                        fontfamily="monospace", fontweight="bold", zorder=10)

            if running_params > 0:
                ax.text(label_x, y, -1.8, f"Σ {running_params / 1e6:.1f}M",
                        fontsize=6, color="#8B949E", fontfamily="monospace", zorder=10)

            # Only a sample of each row's units is drawn; note how many are hidden.
            if actual > n_nodes > 1:
                ax.text(xs[-1] + 0.5, y, zs[-1], f"(+{actual - n_nodes:,})",
                        fontsize=6, color="#8B949E", fontfamily="monospace", zorder=10)

        ax.text2D(
            0.5, 0.96,
            "GPT-300M • 3D Neural Network Architecture",
            transform=fig.transFigure,
            fontsize=22, fontweight="bold", color="#E6EDF3",
            ha="center", fontfamily="monospace",
        )
        ax.text2D(
            0.5, 0.94,
            f"{TOTAL:,} parameters | {N_LAYERS} layers | {HEADS} heads | d_model={D} | d_ff={D_FF}",
            transform=fig.transFigure,
            fontsize=10, color="#8B949E",
            ha="center", fontfamily="monospace",
        )

        def _row(label, value, digits=1):
            # One aligned breakdown line: label padded to 15, value in millions, share of TOTAL.
            return f"  {label:<15}{value / 1e6:>7.{digits}f}M ({value / TOTAL * 100:.1f}%)"

        norms_total = norm_p * N_LAYERS + final_norm_p
        summary = "\n".join([
            "Parameter Breakdown:",
            _row("Embedding:", embed_p),
            _row(f"Attention ×{N_LAYERS}:", attn_p * N_LAYERS),
            _row(f"FFN ×{N_LAYERS}:", ffn_p * N_LAYERS),
            _row("Norms:", norms_total, digits=3),  # tiny, so show more precision
            f"  {'LM Head:':<15}tied (0 extra)",
            "  " + "─" * 23,
            f"  {'TOTAL:':<15}{TOTAL / 1e6:>7.1f}M",
        ])
        ax.text2D(
            0.02, 0.06, summary,
            transform=fig.transFigure,
            fontsize=8, color="#58A6FF",
            fontfamily="monospace", verticalalignment="bottom",
            bbox=dict(boxstyle="round,pad=0.6", facecolor="#161B22",
                      edgecolor="#30363D", linewidth=1),
        )

        legend_items = [
            ("#4CAF50", "Input"), ("#2196F3", "Embeddings"), ("#FF9800", "Attention"),
            ("#8BC34A", "FFN"), ("#9C27B0", f"×{N_LAYERS - 2} Layers"), ("#E91E63", "Norm"),
            ("#F44336", "Output"),
        ]
        for j, (legend_color, legend_label) in enumerate(legend_items):
            ax.text2D(
                0.92, 0.30 - j * 0.025, f"● {legend_label}",
                transform=fig.transFigure,
                fontsize=8, color=legend_color, fontfamily="monospace",
            )

        # Fit the view to all nodes; extra room on +x leaves space for labels.
        all_x = np.concatenate([p[0] for p in all_positions])
        all_y = np.concatenate([p[1] for p in all_positions])
        all_z = np.concatenate([p[2] for p in all_positions])
        margin = 4
        ax.set_xlim(all_x.min() - margin, all_x.max() + margin + 8)
        ax.set_ylim(all_y.min() - margin, all_y.max() + margin)
        ax.set_zlim(all_z.min() - margin, all_z.max() + margin)

        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=bg, edgecolor="none")
        print(f"Saved: {save_path}")
    finally:
        # Release this specific figure even if rendering or saving fails.
        plt.close(fig)
| |
|
| |
|
def generate_3d_single_layer(save_path="layer_3d.png", elev=18, azim=-55):
    """Render the internals of one transformer layer as a 3D node diagram.

    The layer's stages are stacked along ``y``:
    - input;
    - Q/K/V projections;
    - attention heads;
    - output projection;
    - residual + norm;
    - FFN up/down;
    - residual + norm;
    - layer output.

    Each stage is a parabola-bent line of nodes, linked to the previous stage
    by a sub-sampled fan of connections and labelled with its parameter count.

    Args:
        save_path: Destination path of the PNG image.
        elev: Camera elevation in degrees, passed to ``view_init``.
        azim: Camera azimuth in degrees, passed to ``view_init``.
    """
    bg = "#0a0e17"
    fig = plt.figure(figsize=(22, 18), facecolor=bg)
    try:
        ax = fig.add_subplot(111, projection="3d", computed_zorder=False)

        # Dark, axis-free canvas: hide the panes, grid and axis lines.
        ax.set_facecolor(bg)
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.pane.fill = False
            axis.pane.set_edgecolor(bg)
        ax.grid(False)
        ax.set_axis_off()
        ax.view_init(elev=elev, azim=azim)

        # (label, nodes drawn, actual width, parameters, hex colour).
        # The parameter column sums to one layer's count (4*D*D + 2*D*D_FF + 2*D).
        sub_layers = [
            (f"Input (d={D})", 10, D, 0, "#2196F3"),
            (f"Query (d={D})", 10, D, D * D, "#FF6B6B"),
            (f"Key (d={D})", 10, D, D * D, "#4ECDC4"),
            (f"Value (d={D})", 10, D, D * D, "#45B7D1"),
            (f"{HEADS} Attention Heads", HEADS, D, 0, "#FF9800"),
            (f"Attn Output (d={D})", 10, D, D * D, "#FFA726"),
            ("⊕ Residual + RMSNorm", 10, D, D, "#E91E63"),
            (f"FFN Up → GELU (d={D_FF})", 16, D_FF, D * D_FF, "#8BC34A"),
            (f"FFN Down (d={D})", 10, D, D_FF * D, "#7CB342"),
            ("⊕ Residual + RMSNorm", 10, D, D, "#E91E63"),
            (f"Layer Output (d={D})", 10, D, 0, "#2196F3"),
        ]

        n = len(sub_layers)
        y_positions = np.linspace(0, n * 3, n)
        all_pos = []

        for i, (name, n_nodes, actual, params, chex) in enumerate(sub_layers):
            y = y_positions[i]
            rgb = hex_to_rgb(chex)

            # Row of nodes bent into the parabola z = -0.12 x^2.
            spread = min(n_nodes * 0.45, 5.5)
            xs = np.linspace(-spread, spread, n_nodes)
            zs = -0.12 * (xs ** 2)
            ys = np.full_like(xs, y)
            all_pos.append((xs, ys, zs))

            # Connections to the previous row, thinned with a stride of len // 8.
            if i > 0:
                pxs, pys, pzs = all_pos[i - 1]
                sp = max(1, len(pxs) // 8)
                sc = max(1, len(xs) // 8)
                lines = [
                    [(pxs[pi], pys[pi], pzs[pi]), (xs[ci], ys[ci], zs[ci])]
                    for pi in range(0, len(pxs), sp)
                    for ci in range(0, len(xs), sc)
                ]
                if lines:
                    ax.add_collection3d(Line3DCollection(
                        lines, colors=[(*rgb, 0.15)] * len(lines), linewidths=0.6,
                    ))

            # Solid nodes drawn above a faint halo three times their size.
            sz = 130 if n_nodes > 12 else 180
            ax.scatter(xs, ys, zs, c=[chex], s=sz, alpha=0.95,
                       edgecolors="white", linewidths=0.5, depthshade=True, zorder=5)
            ax.scatter(xs, ys, zs, c=[chex], s=sz * 3, alpha=0.07,
                       edgecolors="none", depthshade=True, zorder=4)

            lx = xs[-1] + 1.0
            ax.text(lx, y, 0, name, fontsize=9, fontweight="bold",
                    color="#E6EDF3", fontfamily="monospace", zorder=10)
            if params > 0:
                ax.text(lx, y, -0.8, f"{params:,} params",
                        fontsize=7, color=chex, fontfamily="monospace",
                        fontweight="bold", zorder=10)

            # Only a sample of each row's units is drawn; note how many are hidden.
            if actual > n_nodes:
                ax.text(xs[-1] + 0.4, y, zs[-1], f"(+{actual - n_nodes:,})",
                        fontsize=6, color="#8B949E", fontfamily="monospace", zorder=10)

        ax.text2D(0.5, 0.96, "Single Transformer Layer — 3D View",
                  transform=fig.transFigure, fontsize=20, fontweight="bold",
                  color="#E6EDF3", ha="center", fontfamily="monospace")
        # Derived from the module constants instead of a hard-coded string.
        ax.text2D(0.5, 0.935,
                  f"{layer_p:,} params/layer × {N_LAYERS} layers = {all_layers_p:,} total",
                  transform=fig.transFigure, fontsize=10, color="#8B949E",
                  ha="center", fontfamily="monospace")

        # Fit the view to all nodes; extra room on +x leaves space for labels.
        all_x = np.concatenate([p[0] for p in all_pos])
        all_y = np.concatenate([p[1] for p in all_pos])
        all_z = np.concatenate([p[2] for p in all_pos])
        ax.set_xlim(all_x.min() - 2, all_x.max() + 8)
        ax.set_ylim(all_y.min() - 2, all_y.max() + 2)
        ax.set_zlim(all_z.min() - 2, all_z.max() + 2)

        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=bg, edgecolor="none")
        print(f"Saved: {save_path}")
    finally:
        # Release this specific figure even if rendering or saving fails.
        plt.close(fig)
| |
|
| |
|
def generate_3d_rotating_views(base_path="viz"):
    """Render several camera angles of the diagrams into ``base_path``.

    Creates ``base_path`` if needed. Writes three full-network views (main,
    top, side) followed by one single-layer view, all as PNG files.
    """
    import os

    os.makedirs(base_path, exist_ok=True)

    # (file stem, elevation, azimuth) for each full-network camera angle.
    network_views = (
        ("nn_3d_main", 12, -15),
        ("nn_3d_top", 35, -25),
        ("nn_3d_side", 8, -45),
    )
    for stem, elevation, azimuth in network_views:
        generate_3d_network(f"{base_path}/{stem}.png", elev=elevation, azim=azimuth)

    generate_3d_single_layer(f"{base_path}/nn_3d_layer.png", elev=18, azim=-55)
| |
|
| |
|
if __name__ == "__main__":
    # No makedirs here: generate_3d_rotating_views creates the output directory.
    banner = "=" * 55
    print(banner)
    print(" GPT-300M • 3D Visualization Generator")
    print(banner)
    print(f" Total parameters: {TOTAL:,}")
    print(f" Per layer: {layer_p:,}")
    print(f" Layers: {N_LAYERS}")
    print(banner)

    generate_3d_rotating_views("viz")
    print("\nAll 3D views generated!")
| |
|