i_like_purple / visual_nn_3d.py
dasdasddds's picture
Upload 16 files
93783dd verified
"""
GPT-300M 3D Neural Network Visualization
==========================================
A 3D node-and-connection neural network diagram with depth,
perspective, and accurate parameter counts.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Line3DCollection
import numpy as np
# ═══════════════════════════════════════════════════════════════════════
# ACCURATE GPT-300M PARAMETERS
# ═══════════════════════════════════════════════════════════════════════
# Model hyper-parameters.
VOCAB = 32_000
D = 1_024
HEADS = 16
HEAD_D = D // HEADS  # 64
D_FF = 4_096
N_LAYERS = 24

# Parameter counts derived from the hyper-parameters above.
embed_p = VOCAB * D  # 32,768,000
attn_p = 4 * D * D  # 4,194,304 per layer (Q, K, V and output projections)
ffn_p = 2 * D * D_FF  # 8,388,608 per layer (up and down projections)
norm_p = 2 * D  # 2,048 per layer (two norms of width D)
layer_p = attn_p + ffn_p + norm_p  # 12,584,960 per layer
all_layers_p = layer_p * N_LAYERS  # 302,039,040
final_norm_p = D  # 1,024
TOTAL = embed_p + all_layers_p + final_norm_p  # 334,808,064

# Layer definitions: (name, num_display_nodes, actual_neurons, params, color_hex)
#
# The ``params`` column must add up to TOTAL because the network diagram
# accumulates it into a running total. Layer 1 has no dedicated norm row, so
# its two norms are counted on the row that closes the layer; without them the
# running total ends ``norm_p`` short of TOTAL.
LAYERS = [
    ("Input Tokens", 10, VOCAB, 0, "#4CAF50"),
    ("Token Embedding", 12, D, embed_p, "#2196F3"),
    ("RoPE Positions", 12, D, 0, "#00BCD4"),
    ("Layer 1: Attention QKV", 14, D, attn_p * 3 // 4, "#FF9800"),
    ("Layer 1: Attention Out", 12, D, attn_p * 1 // 4, "#FF9800"),
    ("Layer 1: FFN Up (GELU)", 16, D_FF, ffn_p // 2, "#8BC34A"),
    ("Layer 1: FFN Down + Norms", 12, D, ffn_p // 2 + norm_p, "#8BC34A"),
    # Every layer between the first and the last, collapsed into one row.
    (
        f"Layers 2–{N_LAYERS - 1} (×{N_LAYERS - 2})",
        14,
        D,
        layer_p * (N_LAYERS - 2),
        "#9C27B0",
    ),
    (f"Layer {N_LAYERS}: Attention", 14, D, attn_p, "#FF5722"),
    (f"Layer {N_LAYERS}: FFN", 16, D_FF, ffn_p, "#009688"),
    (f"Layer {N_LAYERS}: Norm + Out", 12, D, norm_p + final_norm_p, "#E91E63"),
    # Weight-tied with the token embedding, so it adds no parameters.
    ("LM Head (weight-tied)", 12, VOCAB, 0, "#F44336"),
    ("Output Probabilities", 1, VOCAB, 0, "#FF1744"),
]
def hex_to_rgb(h):
    """Convert a ``#RRGGBB`` hex colour string into an ``(r, g, b)`` tuple.

    Leading ``#`` characters are stripped, then each two-character pair is
    parsed as a base-16 integer and scaled to a float by dividing by 255.
    """
    digits = h.lstrip("#")
    channels = []
    for start in (0, 2, 4):
        pair = digits[start:start + 2]
        channels.append(int(pair, 16) / 255.0)
    return tuple(channels)
def generate_3d_network(save_path="neural_network_3d.png", elev=22, azim=-65):
    """Render the whole-model 3D node-and-connection diagram to an image file.

    Every entry of the module-level ``LAYERS`` table becomes one row of nodes
    placed along the y axis. Consecutive rows are joined by a sampled set of
    translucent lines, and each row is annotated with its name, its parameter
    count and the running parameter total. A title, a parameter breakdown box
    and a colour legend are drawn in figure coordinates.

    Args:
        save_path: Destination path passed to ``Figure.savefig``.
        elev: Camera elevation in degrees, passed to ``view_init``.
        azim: Camera azimuth in degrees, passed to ``view_init``.
    """
    fig = plt.figure(figsize=(28, 28), facecolor="#0a0e17")
    ax = fig.add_subplot(111, projection="3d", computed_zorder=False)

    # Dark theme for 3D axes: hide panes, grid and axis decorations.
    ax.set_facecolor("#0a0e17")
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor("#0a0e17")
    ax.grid(False)
    ax.set_axis_off()
    ax.view_init(elev=elev, azim=azim)

    n_layers = len(LAYERS)
    y_positions = np.linspace(0, n_layers * 4.0, n_layers)  # depth (layer position)
    all_positions = []  # one (xs, ys, zs) tuple of node coordinates per layer
    running_params = 0

    for i, (name, n_nodes, actual, params, color_hex) in enumerate(LAYERS):
        y = y_positions[i]
        running_params += params
        rgb = hex_to_rgb(color_hex)

        # Node coordinates: a single node sits at the origin of its row,
        # otherwise nodes are spread along x on a shallow downward arc.
        if n_nodes == 1:
            xs = np.array([0.0])
            zs = np.array([0.0])
        else:
            spread = min(n_nodes * 0.5, 7.0)
            xs = np.linspace(-spread, spread, n_nodes)
            # Slight arc for 3D depth perception
            zs = -0.1 * (xs ** 2)
        ys = np.full_like(xs, y)
        all_positions.append((xs, ys, zs))

        # ── Draw connections to previous layer ──────────────────
        if i > 0:
            prev_xs, prev_ys, prev_zs = all_positions[i - 1]
            # Sample connections to avoid clutter
            n_prev = len(prev_xs)
            n_curr = len(xs)
            step_p = max(1, n_prev // 8)
            step_c = max(1, n_curr // 8)
            lines = []
            colors_lines = []
            for pi in range(0, n_prev, step_p):
                for ci in range(0, n_curr, step_c):
                    lines.append([
                        (prev_xs[pi], prev_ys[pi], prev_zs[pi]),
                        (xs[ci], ys[ci], zs[ci]),
                    ])
                    colors_lines.append((*rgb, 0.18))
            if lines:
                lc = Line3DCollection(lines, colors=colors_lines, linewidths=0.7)
                ax.add_collection3d(lc)

        # ── Draw nodes (spheres) ────────────────────────────────
        node_size = 200 if n_nodes > 12 else 280
        if n_nodes == 1:
            node_size = 600
        ax.scatter(
            xs, ys, zs,
            c=[color_hex], s=node_size,
            alpha=0.95, edgecolors="white", linewidths=0.5,
            depthshade=True, zorder=5,
        )

        # ── Glow effect (larger transparent scatter behind) ─────
        ax.scatter(
            xs, ys, zs,
            c=[color_hex], s=node_size * 3,
            alpha=0.08, edgecolors="none",
            depthshade=True, zorder=4,
        )

        # ── Labels ──────────────────────────────────────────────
        label_x = xs[-1] + 1.8 if n_nodes > 1 else 2.0
        ax.text(
            label_x, y, 0,
            name,
            fontsize=9.5, fontweight="bold",
            color="#E6EDF3", fontfamily="monospace",
            zorder=10,
        )

        # Param count (only for rows that own parameters)
        if params > 0:
            if params >= 1_000_000:
                ptxt = f"{params/1e6:.1f}M params"
            else:
                ptxt = f"{params:,} params"
            ax.text(
                label_x, y, -1.0,
                ptxt,
                fontsize=8, color=color_hex,
                fontfamily="monospace", fontweight="bold",
                zorder=10,
            )

        # Running total
        if running_params > 0:
            ax.text(
                label_x, y, -1.8,
                f"Σ {running_params/1e6:.1f}M",
                fontsize=6, color="#8B949E",
                fontfamily="monospace",
                zorder=10,
            )

        # Overflow indicator: how many real units the drawn nodes stand in for
        if actual > n_nodes and n_nodes > 1:
            ax.text(
                xs[-1] + 0.5, y, zs[-1],
                f"(+{actual - n_nodes:,})",
                fontsize=6, color="#8B949E",
                fontfamily="monospace",
                zorder=10,
            )

    # ── Title ──────────────────────────────────────────────────────
    ax.text2D(
        0.5, 0.96,
        "GPT-300M • 3D Neural Network Architecture",
        transform=fig.transFigure,
        fontsize=22, fontweight="bold", color="#E6EDF3",
        ha="center", fontfamily="monospace",
    )
    ax.text2D(
        0.5, 0.94,
        f"{TOTAL:,} parameters | {N_LAYERS} layers | {HEADS} heads | d_model={D} | d_ff={D_FF}",
        transform=fig.transFigure,
        fontsize=10, color="#8B949E",
        ha="center", fontfamily="monospace",
    )

    # ── Parameter summary ──────────────────────────────────────────
    attn_total = attn_p * N_LAYERS
    ffn_total = ffn_p * N_LAYERS
    norms_total = norm_p * N_LAYERS + final_norm_p
    summary = (
        f"Parameter Breakdown:\n"
        f" Embedding: {embed_p/1e6:>7.1f}M ({embed_p/TOTAL*100:.1f}%)\n"
        f" Attention ×{N_LAYERS}: {attn_total/1e6:>7.1f}M ({attn_total/TOTAL*100:.1f}%)\n"
        f" FFN ×{N_LAYERS}: {ffn_total/1e6:>7.1f}M ({ffn_total/TOTAL*100:.1f}%)\n"
        f" Norms: {norms_total/1e6:>7.3f}M ({norms_total/TOTAL*100:.1f}%)\n"
        f" LM Head: tied (0 extra)\n"
        f" ───────────────────────\n"
        f" TOTAL: {TOTAL/1e6:>7.1f}M"
    )
    ax.text2D(
        0.02, 0.06, summary,
        transform=fig.transFigure,
        fontsize=8, color="#58A6FF",
        fontfamily="monospace", verticalalignment="bottom",
        bbox=dict(boxstyle="round,pad=0.6", facecolor="#161B22",
                  edgecolor="#30363D", linewidth=1),
    )

    # ── Legend ──────────────────────────────────────────────────────
    legend_items = [
        ("#4CAF50", "Input"), ("#2196F3", "Embeddings"), ("#FF9800", "Attention"),
        ("#8BC34A", "FFN"), ("#9C27B0", f"×{N_LAYERS - 2} Layers"), ("#E91E63", "Norm"),
        ("#F44336", "Output"),
    ]
    for j, (legend_color, legend_label) in enumerate(legend_items):
        ax.text2D(
            0.92, 0.30 - j * 0.025, f"● {legend_label}",
            transform=fig.transFigure,
            fontsize=8, color=legend_color, fontfamily="monospace",
        )

    # Set axis limits; the extra room on +x leaves space for the row labels.
    all_x = np.concatenate([p[0] for p in all_positions])
    all_y = np.concatenate([p[1] for p in all_positions])
    all_z = np.concatenate([p[2] for p in all_positions])
    margin = 4
    ax.set_xlim(all_x.min() - margin, all_x.max() + margin + 8)
    ax.set_ylim(all_y.min() - margin, all_y.max() + margin)
    ax.set_zlim(all_z.min() - margin, all_z.max() + margin)

    # Save and close this figure explicitly rather than relying on pyplot's
    # notion of the "current" figure.
    fig.savefig(save_path, dpi=200, bbox_inches="tight",
                facecolor="#0a0e17", edgecolor="none")
    print(f"Saved: {save_path}")
    plt.close(fig)
def generate_3d_single_layer(save_path="layer_3d.png", elev=18, azim=-55):
    """Render a 3D diagram of one transformer layer's internals to an image.

    A local table of sub-layers (input, Q/K/V, heads, attention output, norms,
    FFN up/down, output) is drawn as rows of nodes along the y axis, with
    sampled connections between consecutive rows and a name / parameter-count
    label beside each row.

    Args:
        save_path: Destination path passed to ``Figure.savefig``.
        elev: Camera elevation in degrees, passed to ``view_init``.
        azim: Camera azimuth in degrees, passed to ``view_init``.
    """
    fig = plt.figure(figsize=(22, 18), facecolor="#0a0e17")
    ax = fig.add_subplot(111, projection="3d", computed_zorder=False)

    # Dark theme: hide panes, grid and axis decorations.
    ax.set_facecolor("#0a0e17")
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor("#0a0e17")
    ax.grid(False)
    ax.set_axis_off()
    ax.view_init(elev=elev, azim=azim)

    # (name, num_display_nodes, actual_neurons, params, color_hex)
    # Dimensions in the labels come from the module constants so the text
    # cannot drift from the numbers used for the parameter counts.
    sub_layers = [
        (f"Input (d={D})", 10, D, 0, "#2196F3"),
        (f"Query (d={D})", 10, D, D * D, "#FF6B6B"),
        (f"Key (d={D})", 10, D, D * D, "#4ECDC4"),
        (f"Value (d={D})", 10, D, D * D, "#45B7D1"),
        (f"{HEADS} Attention Heads", 16, D, 0, "#FF9800"),
        (f"Attn Output (d={D})", 10, D, D * D, "#FFA726"),
        ("⊕ Residual + RMSNorm", 10, D, D, "#E91E63"),
        (f"FFN Up → GELU (d={D_FF})", 16, D_FF, D * D_FF, "#8BC34A"),
        (f"FFN Down (d={D})", 10, D, D_FF * D, "#7CB342"),
        ("⊕ Residual + RMSNorm", 10, D, D, "#E91E63"),
        (f"Layer Output (d={D})", 10, D, 0, "#2196F3"),
    ]

    n = len(sub_layers)
    y_positions = np.linspace(0, n * 3, n)
    all_pos = []  # one (xs, ys, zs) tuple of node coordinates per sub-layer

    for i, (name, n_nodes, actual, params, chex) in enumerate(sub_layers):
        y = y_positions[i]
        rgb = hex_to_rgb(chex)

        # Nodes spread along x on a shallow downward arc.
        spread = min(n_nodes * 0.45, 5.5)
        xs = np.linspace(-spread, spread, n_nodes)
        zs = -0.12 * (xs ** 2)
        ys = np.full_like(xs, y)
        all_pos.append((xs, ys, zs))

        # Connections (sampled so at most ~8 nodes per side are linked)
        if i > 0:
            pxs, pys, pzs = all_pos[i - 1]
            sp = max(1, len(pxs) // 8)
            sc = max(1, len(xs) // 8)
            lines = []
            cols = []
            for pi in range(0, len(pxs), sp):
                for ci in range(0, len(xs), sc):
                    lines.append([(pxs[pi], pys[pi], pzs[pi]), (xs[ci], ys[ci], zs[ci])])
                    cols.append((*rgb, 0.15))
            if lines:
                ax.add_collection3d(Line3DCollection(lines, colors=cols, linewidths=0.6))

        # Nodes, plus a larger faint copy behind them as a glow
        sz = 130 if n_nodes > 12 else 180
        ax.scatter(xs, ys, zs, c=[chex], s=sz, alpha=0.95,
                   edgecolors="white", linewidths=0.5, depthshade=True, zorder=5)
        ax.scatter(xs, ys, zs, c=[chex], s=sz * 3, alpha=0.07,
                   edgecolors="none", depthshade=True, zorder=4)

        # Labels
        lx = xs[-1] + 1.0
        ax.text(lx, y, 0, name, fontsize=9, fontweight="bold",
                color="#E6EDF3", fontfamily="monospace", zorder=10)
        if params > 0:
            ax.text(lx, y, -0.8, f"{params:,} params",
                    fontsize=7, color=chex, fontfamily="monospace",
                    fontweight="bold", zorder=10)
        # How many real units the drawn nodes stand in for
        if actual > n_nodes:
            ax.text(xs[-1] + 0.4, y, zs[-1], f"(+{actual-n_nodes:,})",
                    fontsize=6, color="#8B949E", fontfamily="monospace", zorder=10)

    ax.text2D(0.5, 0.96, "Single Transformer Layer — 3D View",
              transform=fig.transFigure, fontsize=20, fontweight="bold",
              color="#E6EDF3", ha="center", fontfamily="monospace")
    ax.text2D(0.5, 0.935,
              f"{layer_p:,} params/layer × {N_LAYERS} layers = {all_layers_p:,} total",
              transform=fig.transFigure, fontsize=10, color="#8B949E",
              ha="center", fontfamily="monospace")

    # Axis limits; the extra room on +x leaves space for the row labels.
    all_x = np.concatenate([p[0] for p in all_pos])
    all_y = np.concatenate([p[1] for p in all_pos])
    all_z = np.concatenate([p[2] for p in all_pos])
    ax.set_xlim(all_x.min() - 2, all_x.max() + 8)
    ax.set_ylim(all_y.min() - 2, all_y.max() + 2)
    ax.set_zlim(all_z.min() - 2, all_z.max() + 2)

    # Save and close this figure explicitly rather than relying on pyplot's
    # notion of the "current" figure.
    fig.savefig(save_path, dpi=200, bbox_inches="tight",
                facecolor="#0a0e17", edgecolor="none")
    print(f"Saved: {save_path}")
    plt.close(fig)
def generate_3d_rotating_views(base_path="viz"):
    """Write the full-network diagram from three camera angles plus the
    single-layer detail view into ``base_path`` (created if missing)."""
    import os

    os.makedirs(base_path, exist_ok=True)

    # (file name, elevation, azimuth) for each whole-network render.
    network_views = (
        ("nn_3d_main.png", 12, -15),  # main dramatic angle, more front-facing
        ("nn_3d_top.png", 35, -25),   # angled view
        ("nn_3d_side.png", 8, -45),   # side angle
    )
    for file_name, elevation, azimuth in network_views:
        generate_3d_network(f"{base_path}/{file_name}", elev=elevation, azim=azimuth)

    # Single layer detail
    generate_3d_single_layer(f"{base_path}/nn_3d_layer.png", elev=18, azim=-55)
if __name__ == "__main__":
    # Script entry point: print a short banner with the headline parameter
    # counts, then render every view into the "viz" directory.
    import os

    os.makedirs("viz", exist_ok=True)
    print("=" * 55)
    # The bullet was stored as mis-decoded bytes and printed as garbage.
    print(" GPT-300M • 3D Visualization Generator")
    print("=" * 55)
    print(f" Total parameters: {TOTAL:,}")
    print(f" Per layer: {layer_p:,}")
    print(f" Layers: {N_LAYERS}")
    print("=" * 55)
    generate_3d_rotating_views("viz")
    print("\nAll 3D views generated!")