# i_like_purple / visualize_nn.py
# (Hugging Face upload header — dasdasddds, "Upload 16 files", commit 93783dd verified)
"""
GPT-300M Neural Network Visualizer
====================================
Generates detailed architectural diagrams of the GPT-300M model
using matplotlib, showing:
- Full model architecture flow
- Detailed transformer block internals
- Attention head visualization
- Parameter distribution charts
Usage:
python visualize_nn.py
python visualize_nn.py --output architecture.png
"""
import argparse
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch
import numpy as np
from config import GPT300MConfig, gpt_300m
# =======================================================================
# COLOR SCHEME
# =======================================================================
# GitHub-dark-inspired palette. Keys are referenced by name throughout
# the drawing helpers and figure builders below.
COLORS = {
    "bg": "#0D1117",        # figure/axes background
    "text": "#E6EDF3",      # primary label text
    "text_dim": "#8B949E",  # secondary/sublabel text
    "embed": "#58A6FF",     # Blue
    "attn": "#F78166",      # Orange
    "ffn": "#7EE787",       # Green
    "norm": "#D2A8FF",      # Purple
    "residual": "#FFA657",  # Yellow-orange
    "output": "#FF7B72",    # Red
    "arrow": "#484F58",     # default arrow color
    "highlight": "#1F6FEB", # emphasized borders/labels
    "border": "#30363D",    # neutral borders
    "card_bg": "#161B22",   # container/card fill
    "accent1": "#79C0FF",
    "accent2": "#BB9AF7",
}
def draw_rounded_box(ax, x, y, w, h, color, label, fontsize=10,
                     text_color=None, alpha=0.9, sublabel=None):
    """Render a w x h rounded card centred on (x, y) with a bold label.

    When *sublabel* is given (and non-empty), the main label shifts up
    slightly and the sublabel is drawn beneath it in a smaller, dimmer
    font.
    """
    card = FancyBboxPatch(
        (x - w / 2, y - h / 2), w, h,
        boxstyle="round,pad=0.1",
        facecolor=color,
        edgecolor="white",
        linewidth=0.5,
        alpha=alpha,
        zorder=3,
    )
    ax.add_patch(card)
    # Nudge the main label upward only when a sublabel needs the room.
    label_y = y + 0.15 if sublabel else y
    ax.text(
        x, label_y,
        label,
        ha="center", va="center",
        fontsize=fontsize,
        fontweight="bold",
        color=text_color or COLORS["text"],
        zorder=4,
    )
    if sublabel:
        ax.text(
            x, y - 0.25,
            sublabel,
            ha="center", va="center",
            fontsize=fontsize - 2,
            color=COLORS["text_dim"],
            zorder=4,
        )
def draw_arrow(ax, x1, y1, x2, y2, color=None):
    """Draw a straight, solid arrow from (x1, y1) to (x2, y2)."""
    props = dict(
        arrowstyle="->",
        color=color or COLORS["arrow"],
        lw=1.5,
        connectionstyle="arc3,rad=0",  # rad=0 keeps the arrow straight
    )
    ax.annotate("", xy=(x2, y2), xytext=(x1, y1),
                arrowprops=props, zorder=2)
def draw_residual_connection(ax, x_start, y_start, x_end, y_end, offset=1.8):
    """Draw a dashed, curved skip-connection arrow between two points.

    Args:
        ax: Target axes.
        x_start, y_start: Tail of the arrow.
        x_end, y_end: Head of the arrow.
        offset: Accepted for backward compatibility but currently
            unused; the arc curvature is fixed at rad=0.3.
    """
    ax.annotate(
        "",
        xy=(x_end, y_end), xytext=(x_start, y_start),
        arrowprops=dict(
            arrowstyle="->",
            color=COLORS["residual"],
            lw=1.2,
            linestyle="--",
            # Plain string: the original used an f-string with no
            # placeholders (ruff F541).
            connectionstyle="arc3,rad=0.3",
        ),
        zorder=1,
    )
# =======================================================================
# FULL ARCHITECTURE DIAGRAM
# =======================================================================
def draw_full_architecture(config: GPT300MConfig, save_path: "str | None" = None):
    """Draw the complete GPT-300M architecture.

    Renders a single tall figure as a top-to-bottom flow: token IDs ->
    embedding -> RoPE -> dropout -> one expanded transformer block
    (annotated as repeated config.n_layers times) -> final norm ->
    LM head -> softmax.

    Args:
        config: Model hyperparameters; used only for the text labels.
        save_path: If given, the figure is also written to this path as
            a 200-dpi PNG.

    Returns:
        The matplotlib Figure (returned whether or not it was saved).
    """
    fig, ax = plt.subplots(1, 1, figsize=(14, 24), facecolor=COLORS["bg"])
    ax.set_facecolor(COLORS["bg"])
    # Manual layout in fixed data coordinates; all positions below are
    # hand-placed within these limits.
    ax.set_xlim(-4, 4)
    ax.set_ylim(-1, 22)
    ax.axis("off")
    # Title and one-line configuration summary.
    ax.text(0, 21.5, "GPT-300M Architecture", ha="center", va="center",
            fontsize=22, fontweight="bold", color=COLORS["text"],
            fontfamily="monospace")
    ax.text(0, 21.0,
            f"{config.total_params_estimate:,} parameters β€’ "
            f"{config.n_layers} layers β€’ "
            f"{config.n_heads} heads β€’ "
            f"d={config.d_model}",
            ha="center", va="center", fontsize=10, color=COLORS["text_dim"],
            fontfamily="monospace")
    y = 19.5  # Vertical cursor; decremented as each element is drawn.
    # -- Input ----------------------------------------------------------
    draw_rounded_box(ax, 0, y, 3.5, 0.7, COLORS["card_bg"], "Input Token IDs",
                     sublabel=f"[batch, seq_len]", fontsize=11)
    y -= 1.1
    draw_arrow(ax, 0, y + 0.8, 0, y + 0.4)
    # -- Token Embedding ------------------------------------------------
    draw_rounded_box(ax, 0, y, 3.5, 0.7, COLORS["embed"],
                     "Token Embedding", text_color="#000",
                     sublabel=f"{config.vocab_size:,} Γ— {config.d_model}")
    y -= 1.1
    draw_arrow(ax, 0, y + 0.8, 0, y + 0.4)
    # -- RoPE -----------------------------------------------------------
    draw_rounded_box(ax, 0, y, 3.5, 0.6, COLORS["accent2"],
                     "Rotary Position Embeddings (RoPE)",
                     text_color="#000", fontsize=9,
                     sublabel=f"ΞΈ = {config.rope_theta:.0f}")
    y -= 1.0
    draw_arrow(ax, 0, y + 0.7, 0, y + 0.4)
    # -- Dropout --------------------------------------------------------
    draw_rounded_box(ax, 0, y, 2.5, 0.5, COLORS["border"],
                     f"Dropout (p={config.dropout})", fontsize=9)
    y -= 1.0
    draw_arrow(ax, 0, y + 0.7, 0, y + 0.35)
    # -- Transformer block (drawn once, annotated as repeated) ----------
    block_height = 3.2
    block_y_start = y
    block_y_end = y - block_height
    # Container box outlining the expanded block's internals.
    block_box = FancyBboxPatch(
        (-3.3, block_y_end - 0.1), 6.6, block_height + 0.2,
        boxstyle="round,pad=0.15",
        facecolor=COLORS["card_bg"],
        edgecolor=COLORS["highlight"],
        linewidth=1.5,
        alpha=0.8,
        zorder=1,
    )
    ax.add_patch(block_box)
    ax.text(-3.0, block_y_start + 0.05,
            f"Transformer Block Γ— {config.n_layers}",
            fontsize=10, fontweight="bold", color=COLORS["highlight"],
            fontfamily="monospace", zorder=5)
    # Inner vertical cursor for elements inside the block container.
    by = block_y_start - 0.4
    # Pre-norm before attention (RMSNorm 1).
    draw_rounded_box(ax, 0, by, 2.8, 0.45, COLORS["norm"],
                     "RMSNorm", text_color="#000", fontsize=9)
    by -= 0.7
    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)
    # Multi-head attention sub-layer.
    draw_rounded_box(ax, 0, by, 2.8, 0.7, COLORS["attn"],
                     "Multi-Head Attention", text_color="#000", fontsize=10,
                     sublabel=f"{config.n_heads} heads Γ— {config.head_dim}d")
    # Skip connection around norm + attention (left side of the block).
    draw_residual_connection(ax, -1.6, block_y_start - 0.2, -1.6, by)
    ax.text(-2.5, by + 0.3, "βŠ• residual", fontsize=7,
            color=COLORS["residual"], ha="center")
    by -= 0.8
    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)
    # Pre-norm before the feed-forward network (RMSNorm 2).
    draw_rounded_box(ax, 0, by, 2.8, 0.45, COLORS["norm"],
                     "RMSNorm", text_color="#000", fontsize=9)
    by -= 0.7
    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)
    # Feed-forward sub-layer.
    draw_rounded_box(ax, 0, by, 2.8, 0.7, COLORS["ffn"],
                     "Feed-Forward Network", text_color="#000", fontsize=10,
                     sublabel=f"{config.d_model} β†’ {config.d_ff} β†’ {config.d_model}")
    # Skip connection around norm + FFN (right side of the block).
    draw_residual_connection(ax, 1.6, by + 1.5, 1.6, by)
    ax.text(2.5, by + 0.7, "βŠ• residual", fontsize=7,
            color=COLORS["residual"], ha="center")
    y = block_y_end - 0.4
    # -- Repeated-blocks indicator --------------------------------------
    draw_arrow(ax, 0, y + 0.2, 0, y - 0.1)
    ax.text(0, y - 0.3, f"Γ— {config.n_layers} layers", ha="center",
            fontsize=11, fontweight="bold", color=COLORS["text_dim"],
            fontfamily="monospace",
            bbox=dict(boxstyle="round,pad=0.3", facecolor=COLORS["card_bg"],
                      edgecolor=COLORS["border"]))
    y -= 0.9
    draw_arrow(ax, 0, y + 0.3, 0, y + 0.05)
    # -- Final RMSNorm --------------------------------------------------
    draw_rounded_box(ax, 0, y - 0.2, 3.5, 0.5, COLORS["norm"],
                     "Final RMSNorm", text_color="#000", fontsize=10)
    y -= 1.0
    draw_arrow(ax, 0, y + 0.5, 0, y + 0.2)
    # -- LM Head (weight-tied with the token embedding per label) -------
    draw_rounded_box(ax, 0, y - 0.1, 3.5, 0.7, COLORS["output"],
                     "Linear (LM Head)", text_color="#000", fontsize=11,
                     sublabel=f"{config.d_model} β†’ {config.vocab_size:,} (weight-tied)")
    y -= 1.1
    draw_arrow(ax, 0, y + 0.7, 0, y + 0.35)
    # -- Softmax / output distribution ----------------------------------
    draw_rounded_box(ax, 0, y, 3.5, 0.6, COLORS["card_bg"],
                     "Softmax β†’ Next Token Probabilities", fontsize=10,
                     sublabel=f"[batch, seq_len, {config.vocab_size:,}]")
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=COLORS["bg"], edgecolor="none")
        print(f"Saved architecture diagram: {save_path}")
    return fig
# =======================================================================
# PARAMETER DISTRIBUTION CHART
# =======================================================================
def draw_parameter_chart(config: GPT300MConfig, save_path: "str | None" = None):
    """Draw a parameter distribution breakdown.

    Left: donut chart of parameters by component class. Right: stacked
    bar chart of per-layer parameters. Counts are analytic estimates
    from the config (weight matrices only); the embedding is counted
    once, consistent with the weight-tied LM head shown elsewhere.

    Args:
        config: Hyperparameters the counts are derived from.
        save_path: If given, the figure is also written to this path as
            a 200-dpi PNG.

    Returns:
        The matplotlib Figure (returned whether or not it was saved).
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 7), facecolor=COLORS["bg"])
    # Analytic parameter counts per component.
    emb_params = config.vocab_size * config.d_model
    # Q, K, V and output projections: 4 square matrices per layer.
    attn_params = 4 * config.d_model * config.d_model * config.n_layers
    # Up- and down-projection per layer.
    ffn_params = 2 * config.d_model * config.d_ff * config.n_layers
    # Two norms per layer plus the final norm.
    norm_params = 2 * config.d_model * config.n_layers + config.d_model
    total = emb_params + attn_params + ffn_params + norm_params
    # -- Donut chart ----------------------------------------------------
    ax = axes[0]
    ax.set_facecolor(COLORS["bg"])
    sizes = [emb_params, attn_params, ffn_params, norm_params]
    colors = [COLORS["embed"], COLORS["attn"], COLORS["ffn"], COLORS["norm"]]
    # Wedge labels are suppressed (labels=None); names appear in the
    # legend instead. (A previously unused `labels` list was removed.)
    wedges, texts, autotexts = ax.pie(
        sizes, labels=None, autopct=lambda p: f"{p:.1f}%",
        colors=colors, startangle=90, pctdistance=0.7,
        wedgeprops=dict(width=0.5, edgecolor=COLORS["bg"], linewidth=2),
        textprops=dict(color=COLORS["text"], fontsize=10),
    )
    for at in autotexts:
        at.set_fontweight("bold")
        at.set_color("#000")
    # Legend with absolute counts in millions. Labelled "RMSNorm" to
    # match the architecture diagram (was "LayerNorm", which
    # contradicted the RMSNorm blocks drawn elsewhere in this file).
    legend_labels = [
        f"{name}\n({count/1e6:.1f}M)" for name, count in zip(
            ["Token Embedding", "Attention", "Feed-Forward", "RMSNorm"],
            sizes
        )
    ]
    ax.legend(
        wedges, legend_labels, loc="center left", bbox_to_anchor=(1.05, 0.5),
        fontsize=9, frameon=False, labelcolor=COLORS["text"],
    )
    ax.set_title("Parameter Distribution", fontsize=14, fontweight="bold",
                 color=COLORS["text"], pad=15)
    # -- Per-layer stacked bar chart ------------------------------------
    ax = axes[1]
    ax.set_facecolor(COLORS["bg"])
    # Every layer is identical, so each stack repeats the same values.
    layer_attn = 4 * config.d_model * config.d_model
    layer_ffn = 2 * config.d_model * config.d_ff
    layer_norm = 2 * config.d_model
    layers = range(1, config.n_layers + 1)
    bar_width = 0.8
    ax.bar(layers, [layer_attn / 1e6] * config.n_layers, bar_width,
           label="Attention", color=COLORS["attn"], alpha=0.9)
    ax.bar(layers, [layer_ffn / 1e6] * config.n_layers, bar_width,
           bottom=[layer_attn / 1e6] * config.n_layers,
           label="Feed-Forward", color=COLORS["ffn"], alpha=0.9)
    ax.bar(layers, [layer_norm / 1e6] * config.n_layers, bar_width,
           bottom=[(layer_attn + layer_ffn) / 1e6] * config.n_layers,
           label="Norm", color=COLORS["norm"], alpha=0.9)
    ax.set_xlabel("Layer", fontsize=11, color=COLORS["text"])
    ax.set_ylabel("Parameters (M)", fontsize=11, color=COLORS["text"])
    ax.set_title("Parameters Per Layer", fontsize=14, fontweight="bold",
                 color=COLORS["text"], pad=15)
    ax.legend(fontsize=9, frameon=False, labelcolor=COLORS["text"])
    ax.tick_params(colors=COLORS["text_dim"])
    ax.spines["bottom"].set_color(COLORS["border"])
    ax.spines["left"].set_color(COLORS["border"])
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    # Overall title with the grand total.
    fig.suptitle(
        f"GPT-300M β€’ {total:,} Total Parameters",
        fontsize=16, fontweight="bold", color=COLORS["text"],
        fontfamily="monospace", y=1.02,
    )
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=COLORS["bg"], edgecolor="none")
        print(f"Saved parameter chart: {save_path}")
    return fig
# =======================================================================
# ATTENTION HEAD VISUALIZATION
# =======================================================================
def draw_attention_heads(config: GPT300MConfig, save_path: "str | None" = None):
    """Visualize the multi-head attention mechanism.

    Draws the Q/K/V projections fanning out to up to 8 head boxes, the
    scaled dot-product stage, and the concat + output projection.

    Args:
        config: Hyperparameters used for labels and the head count.
        save_path: If given, the figure is also written to this path as
            a 200-dpi PNG.

    Returns:
        The matplotlib Figure (returned whether or not it was saved).
    """
    fig, ax = plt.subplots(1, 1, figsize=(14, 10), facecolor=COLORS["bg"])
    ax.set_facecolor(COLORS["bg"])
    # Manual layout in fixed data coordinates.
    ax.set_xlim(-1, 11)
    ax.set_ylim(-1, 8)
    ax.axis("off")
    ax.text(5, 7.5, "Multi-Head Self-Attention", ha="center",
            fontsize=18, fontweight="bold", color=COLORS["text"],
            fontfamily="monospace")
    ax.text(5, 7.0,
            f"{config.n_heads} heads Γ— {config.head_dim}d per head = {config.d_model}d total",
            ha="center", fontsize=10, color=COLORS["text_dim"])
    # Input tensor box.
    draw_rounded_box(ax, 5, 6.2, 4, 0.5, COLORS["embed"],
                     f"Input: [B, T, {config.d_model}]", text_color="#000", fontsize=9)
    # Q, K, V projection boxes fan out from the input.
    for i, (name, color) in enumerate(zip(["Q", "K", "V"],
                                          ["#FF6B6B", "#4ECDC4", "#45B7D1"])):
        x = 2 + i * 3
        draw_arrow(ax, 5, 5.9, x, 5.4)
        draw_rounded_box(ax, x, 5.1, 1.8, 0.5, color,
                         f"W_{name}", text_color="#000", fontsize=10,
                         sublabel=f"{config.d_model}Γ—{config.d_model}")
    # Head boxes; cap the drawing at 8 so the row stays readable.
    head_y = 3.8
    n_show = min(config.n_heads, 8)
    head_spacing = 9.0 / n_show
    for h in range(n_show):
        hx = 1 + h * head_spacing
        box = FancyBboxPatch(
            (hx - 0.4, head_y - 0.3), 0.8, 0.6,
            boxstyle="round,pad=0.05",
            facecolor=COLORS["attn"],
            edgecolor="white",
            linewidth=0.5,
            alpha=0.8,
            zorder=3,
        )
        ax.add_patch(box)
        ax.text(hx, head_y, f"H{h+1}", ha="center", va="center",
                fontsize=8, fontweight="bold", color="#000", zorder=4)
        # Faint fan-in lines from each projection box (x = 2, 5, 8) to
        # this head. (The original enumerate() index was unused.)
        for qx in (2, 5, 8):
            ax.annotate("", xy=(hx, head_y + 0.3), xytext=(qx, 4.8),
                        arrowprops=dict(arrowstyle="-", color=COLORS["arrow"],
                                        lw=0.3, alpha=0.3), zorder=1)
    if config.n_heads > 8:
        ax.text(5, head_y - 0.6, f"... ({config.n_heads} heads total)",
                ha="center", fontsize=9, color=COLORS["text_dim"])
    # Scaled dot-product attention stage.
    draw_rounded_box(ax, 5, 2.5, 6, 0.6, COLORS["card_bg"],
                     "Scaled Dot-Product: softmax(QK^T / √d_k) Γ— V",
                     fontsize=10)
    for h in range(n_show):
        hx = 1 + h * head_spacing
        draw_arrow(ax, hx, head_y - 0.3, 5, 2.85)
    # Concatenate heads and apply the output projection.
    draw_arrow(ax, 5, 2.15, 5, 1.75)
    draw_rounded_box(ax, 5, 1.5, 4, 0.5, COLORS["accent1"],
                     "Concat β†’ W_O projection", text_color="#000", fontsize=10)
    # Output tensor box.
    draw_arrow(ax, 5, 1.2, 5, 0.8)
    draw_rounded_box(ax, 5, 0.5, 4, 0.5, COLORS["ffn"],
                     f"Output: [B, T, {config.d_model}]", text_color="#000", fontsize=9)
    plt.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=COLORS["bg"], edgecolor="none")
        print(f"Saved attention diagram: {save_path}")
    return fig
# =======================================================================
# MAIN
# =======================================================================
if __name__ == "__main__":
    import os

    parser = argparse.ArgumentParser(description="Visualize GPT-300M Architecture")
    parser.add_argument("--output", type=str, default="./viz",
                        help="Output directory for images")
    args = parser.parse_args()
    os.makedirs(args.output, exist_ok=True)
    config = gpt_300m()
    print(f"Generating visualizations for GPT-300M ({config.total_params_estimate:,} params)...")
    # Render each diagram into the output directory, in order.
    jobs = [
        (draw_full_architecture, "architecture.png"),
        (draw_parameter_chart, "parameters.png"),
        (draw_attention_heads, "attention.png"),
    ]
    for draw_fn, filename in jobs:
        draw_fn(config, os.path.join(args.output, filename))
    print("Done! All visualizations saved.")