|
|
""" |
|
|
GPT-300M Neural Network Visualizer |
|
|
==================================== |
|
|
Generates detailed architectural diagrams of the GPT-300M model |
|
|
using matplotlib, showing: |
|
|
- Full model architecture flow |
|
|
- Detailed transformer block internals |
|
|
- Attention head visualization |
|
|
- Parameter distribution charts |
|
|
|
|
|
Usage: |
|
|
python visualize_nn.py |
|
|
python visualize_nn.py --output ./viz |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import matplotlib |
|
|
matplotlib.use("Agg") |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.patches as patches |
|
|
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch |
|
|
import numpy as np |
|
|
|
|
|
from config import GPT300MConfig, gpt_300m |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared dark-theme palette (hex RGB strings) used by every drawing helper
# in this module. Keys are looked up by name, so they must stay stable.
COLORS = {
    # Canvas and typography
    "bg": "#0D1117", "text": "#E6EDF3", "text_dim": "#8B949E",
    # Model components (embedding, attention, feed-forward, norm, ...)
    "embed": "#58A6FF", "attn": "#F78166", "ffn": "#7EE787",
    "norm": "#D2A8FF", "residual": "#FFA657", "output": "#FF7B72",
    # Diagram chrome: connectors, outlines and neutral box fills
    "arrow": "#484F58", "highlight": "#1F6FEB",
    "border": "#30363D", "card_bg": "#161B22",
    # Extra accent fills
    "accent1": "#79C0FF", "accent2": "#BB9AF7",
}
|
|
|
|
|
|
|
|
def draw_rounded_box(ax, x, y, w, h, color, label, fontsize=10,
                     text_color=None, alpha=0.9, sublabel=None):
    """Add a rounded, labelled box centred on ``(x, y)`` to ``ax``.

    Args:
        ax: Axes-like object providing ``add_patch`` and ``text``.
        x, y: Centre of the box.
        w, h: Width and height of the box (before the 0.1 rounding pad).
        color: Face colour of the box.
        label: Bold main caption.
        fontsize: Caption size; the sublabel is drawn two points smaller.
        text_color: Caption colour; falls back to ``COLORS["text"]``.
        alpha: Opacity of the box patch.
        sublabel: Optional dimmed second line. When given, the caption is
            nudged up by 0.15 and the sublabel sits 0.25 below centre.
    """
    lower_left = (x - w/2, y - h/2)
    ax.add_patch(FancyBboxPatch(
        lower_left, w, h,
        boxstyle="round,pad=0.1",
        facecolor=color,
        edgecolor="white",
        linewidth=0.5,
        alpha=alpha,
        zorder=3,
    ))

    # Shift the caption upward only when a second line has to fit below it.
    caption_shift = 0.15 if sublabel else 0
    caption_color = text_color if text_color else COLORS["text"]
    ax.text(
        x, y + caption_shift, label,
        ha="center", va="center",
        fontsize=fontsize, fontweight="bold",
        color=caption_color, zorder=4,
    )

    if not sublabel:
        return

    ax.text(
        x, y - 0.25, sublabel,
        ha="center", va="center",
        fontsize=fontsize - 2,
        color=COLORS["text_dim"], zorder=4,
    )
|
|
|
|
|
|
|
|
def draw_arrow(ax, x1, y1, x2, y2, color=None):
    """Draw a straight arrow from ``(x1, y1)`` to ``(x2, y2)`` on ``ax``.

    Args:
        ax: Axes-like object providing ``annotate``.
        x1, y1: Tail of the arrow.
        x2, y2: Head of the arrow.
        color: Line colour; falls back to ``COLORS["arrow"]`` when falsy.
    """
    tail, head = (x1, y1), (x2, y2)
    arrow_style = {
        "arrowstyle": "->",
        "color": color or COLORS["arrow"],
        "lw": 1.5,
        "connectionstyle": "arc3,rad=0",
    }
    # An empty annotation text leaves only the arrow visible.
    ax.annotate("", xy=head, xytext=tail, arrowprops=arrow_style, zorder=2)
|
|
|
|
|
|
|
|
def draw_residual_connection(ax, x_start, y_start, x_end, y_end, offset=1.8):
    """Draw a dashed, curved residual/skip-connection arrow on ``ax``.

    Args:
        ax: Axes-like object providing ``annotate``.
        x_start, y_start: Tail of the arrow.
        x_end, y_end: Head of the arrow.
        offset: Currently unused. It is kept in the signature so callers
            that pass it keep working; the curvature is fixed at
            ``rad=0.3``.
    """
    ax.annotate(
        "",
        xy=(x_end, y_end),
        xytext=(x_start, y_start),
        arrowprops=dict(
            arrowstyle="->",
            color=COLORS["residual"],
            lw=1.2,
            linestyle="--",
            # Plain string: there is nothing to interpolate here.
            connectionstyle="arc3,rad=0.3",
        ),
        # Below the boxes (zorder 3) and the straight arrows (zorder 2).
        zorder=1,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_full_architecture(config: GPT300MConfig, save_path: str = None):
    """Draw the complete GPT-300M architecture as a top-to-bottom flow.

    The diagram runs from the input token IDs through the embedding, RoPE,
    dropout, one representative transformer block, the final norm and the
    LM head, down to the softmax output.

    Args:
        config: Model configuration. Reads ``total_params_estimate``,
            ``n_layers``, ``n_heads``, ``head_dim``, ``d_model``, ``d_ff``,
            ``vocab_size``, ``rope_theta`` and ``dropout``.
        save_path: If truthy, the figure is written there at 200 dpi.

    Returns:
        The matplotlib figure holding the diagram.
    """
    fig, ax = plt.subplots(1, 1, figsize=(14, 24), facecolor=COLORS["bg"])
    ax.set_facecolor(COLORS["bg"])
    ax.set_xlim(-4, 4)
    ax.set_ylim(-1, 22)
    ax.axis("off")

    # Title and one-line model summary.
    ax.text(0, 21.5, "GPT-300M Architecture", ha="center", va="center",
            fontsize=22, fontweight="bold", color=COLORS["text"],
            fontfamily="monospace")
    ax.text(0, 21.0,
            f"{config.total_params_estimate:,} parameters • "
            f"{config.n_layers} layers • "
            f"{config.n_heads} heads • "
            f"d={config.d_model}",
            ha="center", va="center", fontsize=10, color=COLORS["text_dim"],
            fontfamily="monospace")

    # ``y`` is a cursor that walks down the page as stages are drawn.
    y = 19.5

    # Input token IDs.
    draw_rounded_box(ax, 0, y, 3.5, 0.7, COLORS["card_bg"], "Input Token IDs",
                     sublabel="[batch, seq_len]", fontsize=11)
    y -= 1.1
    draw_arrow(ax, 0, y + 0.8, 0, y + 0.4)

    # Token embedding.
    draw_rounded_box(ax, 0, y, 3.5, 0.7, COLORS["embed"],
                     "Token Embedding", text_color="#000",
                     sublabel=f"{config.vocab_size:,} × {config.d_model}")
    y -= 1.1
    draw_arrow(ax, 0, y + 0.8, 0, y + 0.4)

    # Rotary position embeddings.
    draw_rounded_box(ax, 0, y, 3.5, 0.6, COLORS["accent2"],
                     "Rotary Position Embeddings (RoPE)",
                     text_color="#000", fontsize=9,
                     sublabel=f"θ = {config.rope_theta:.0f}")
    y -= 1.0
    draw_arrow(ax, 0, y + 0.7, 0, y + 0.4)

    # Embedding dropout.
    draw_rounded_box(ax, 0, y, 2.5, 0.5, COLORS["border"],
                     f"Dropout (p={config.dropout})", fontsize=9)
    y -= 1.0
    draw_arrow(ax, 0, y + 0.7, 0, y + 0.35)

    # Frame around one representative transformer block.
    block_height = 3.2
    block_y_start = y
    block_y_end = y - block_height

    block_box = FancyBboxPatch(
        (-3.3, block_y_end - 0.1), 6.6, block_height + 0.2,
        boxstyle="round,pad=0.15",
        facecolor=COLORS["card_bg"],
        edgecolor=COLORS["highlight"],
        linewidth=1.5,
        alpha=0.8,
        zorder=1,
    )
    ax.add_patch(block_box)
    ax.text(-3.0, block_y_start + 0.05,
            f"Transformer Block × {config.n_layers}",
            fontsize=10, fontweight="bold", color=COLORS["highlight"],
            fontfamily="monospace", zorder=5)

    # ``by`` is the cursor for the stages inside the block frame.
    by = block_y_start - 0.4

    # Pre-attention norm.
    draw_rounded_box(ax, 0, by, 2.8, 0.45, COLORS["norm"],
                     "RMSNorm", text_color="#000", fontsize=9)
    by -= 0.7
    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)

    # Multi-head attention with its residual path on the left.
    draw_rounded_box(ax, 0, by, 2.8, 0.7, COLORS["attn"],
                     "Multi-Head Attention", text_color="#000", fontsize=10,
                     sublabel=f"{config.n_heads} heads × {config.head_dim}d")

    draw_residual_connection(ax, -1.6, block_y_start - 0.2, -1.6, by)
    # NOTE(review): the glyph before "residual" was mis-encoded in the
    # source and could not be recovered; "⊕" is a best guess - confirm.
    ax.text(-2.5, by + 0.3, "⊕ residual", fontsize=7,
            color=COLORS["residual"], ha="center")

    by -= 0.8
    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)

    # Pre-FFN norm.
    draw_rounded_box(ax, 0, by, 2.8, 0.45, COLORS["norm"],
                     "RMSNorm", text_color="#000", fontsize=9)
    by -= 0.7
    draw_arrow(ax, 0, by + 0.5, 0, by + 0.25)

    # Feed-forward network with its residual path on the right.
    draw_rounded_box(ax, 0, by, 2.8, 0.7, COLORS["ffn"],
                     "Feed-Forward Network", text_color="#000", fontsize=10,
                     sublabel=f"{config.d_model} → {config.d_ff} → {config.d_model}")

    draw_residual_connection(ax, 1.6, by + 1.5, 1.6, by)
    # NOTE(review): same unrecoverable glyph as the attention label above.
    ax.text(2.5, by + 0.7, "⊕ residual", fontsize=7,
            color=COLORS["residual"], ha="center")

    # Resume the outer cursor just below the block frame.
    y = block_y_end - 0.4

    # Repetition marker for the stacked layers.
    draw_arrow(ax, 0, y + 0.2, 0, y - 0.1)
    ax.text(0, y - 0.3, f"× {config.n_layers} layers", ha="center",
            fontsize=11, fontweight="bold", color=COLORS["text_dim"],
            fontfamily="monospace",
            bbox=dict(boxstyle="round,pad=0.3", facecolor=COLORS["card_bg"],
                      edgecolor=COLORS["border"]))
    y -= 0.9
    draw_arrow(ax, 0, y + 0.3, 0, y + 0.05)

    # Final norm.
    draw_rounded_box(ax, 0, y - 0.2, 3.5, 0.5, COLORS["norm"],
                     "Final RMSNorm", text_color="#000", fontsize=10)
    y -= 1.0
    draw_arrow(ax, 0, y + 0.5, 0, y + 0.2)

    # LM head.
    draw_rounded_box(ax, 0, y - 0.1, 3.5, 0.7, COLORS["output"],
                     "Linear (LM Head)", text_color="#000", fontsize=11,
                     sublabel=f"{config.d_model} → {config.vocab_size:,} (weight-tied)")
    y -= 1.1
    draw_arrow(ax, 0, y + 0.7, 0, y + 0.35)

    # Output distribution.
    draw_rounded_box(ax, 0, y, 3.5, 0.6, COLORS["card_bg"],
                     "Softmax → Next Token Probabilities", fontsize=10,
                     sublabel=f"[batch, seq_len, {config.vocab_size:,}]")

    # Lay out this figure explicitly rather than pyplot's current figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=COLORS["bg"], edgecolor="none")
        print(f"Saved architecture diagram: {save_path}")

    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_parameter_chart(config: GPT300MConfig, save_path: str = None):
    """Draw a parameter distribution breakdown.

    The left panel is a donut chart of the total parameter count split
    into embedding, attention, feed-forward and norm weights. The right
    panel is a stacked bar per layer with the same three per-layer parts.

    Counts are estimates computed here from ``config``: four
    ``d_model x d_model`` attention matrices and two ``d_model x d_ff``
    feed-forward matrices per layer, two norm vectors per layer plus one
    final norm vector, and a ``vocab_size x d_model`` embedding.

    Args:
        config: Model configuration. Reads ``vocab_size``, ``d_model``,
            ``d_ff`` and ``n_layers``.
        save_path: If truthy, the figure is written there at 200 dpi.

    Returns:
        The matplotlib figure holding both panels.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 7), facecolor=COLORS["bg"])

    # Per-layer counts, computed once and reused for totals and bars.
    layer_attn = 4 * config.d_model * config.d_model
    layer_ffn = 2 * config.d_model * config.d_ff
    layer_norm = 2 * config.d_model

    # Whole-model totals; the extra d_model is the final norm.
    emb_params = config.vocab_size * config.d_model
    attn_params = layer_attn * config.n_layers
    ffn_params = layer_ffn * config.n_layers
    norm_params = layer_norm * config.n_layers + config.d_model
    total = emb_params + attn_params + ffn_params + norm_params

    # --- Left panel: donut chart of the totals ---
    ax = axes[0]
    ax.set_facecolor(COLORS["bg"])
    # Labelled "RMSNorm" to match the architecture diagram in this module
    # and the single-weight-vector count used above.
    names = ["Token Embedding", "Attention", "Feed-Forward", "RMSNorm"]
    sizes = [emb_params, attn_params, ffn_params, norm_params]
    colors = [COLORS["embed"], COLORS["attn"], COLORS["ffn"], COLORS["norm"]]

    wedges, _texts, autotexts = ax.pie(
        sizes, labels=None, autopct="%.1f%%",
        colors=colors, startangle=90, pctdistance=0.7,
        wedgeprops=dict(width=0.5, edgecolor=COLORS["bg"], linewidth=2),
        textprops=dict(color=COLORS["text"], fontsize=10),
    )
    for pct_text in autotexts:
        pct_text.set_fontweight("bold")
        pct_text.set_color("#000")

    legend_labels = [
        f"{name}\n({size/1e6:.1f}M)" for name, size in zip(names, sizes)
    ]
    ax.legend(
        wedges, legend_labels, loc="center left", bbox_to_anchor=(1.05, 0.5),
        fontsize=9, frameon=False, labelcolor=COLORS["text"],
    )
    ax.set_title("Parameter Distribution", fontsize=14, fontweight="bold",
                 color=COLORS["text"], pad=15)

    # --- Right panel: stacked per-layer bars ---
    ax = axes[1]
    ax.set_facecolor(COLORS["bg"])

    layers = range(1, config.n_layers + 1)
    bar_width = 0.8

    # Every layer has the same shape, so scalar heights are broadcast
    # across all bars. ``stacked`` tracks the running base in raw counts.
    stacked = 0
    for part_label, part_params, color_key in (
        ("Attention", layer_attn, "attn"),
        ("Feed-Forward", layer_ffn, "ffn"),
        ("Norm", layer_norm, "norm"),
    ):
        ax.bar(layers, part_params / 1e6, bar_width,
               bottom=stacked / 1e6,
               label=part_label, color=COLORS[color_key], alpha=0.9)
        stacked += part_params

    ax.set_xlabel("Layer", fontsize=11, color=COLORS["text"])
    ax.set_ylabel("Parameters (M)", fontsize=11, color=COLORS["text"])
    ax.set_title("Parameters Per Layer", fontsize=14, fontweight="bold",
                 color=COLORS["text"], pad=15)
    ax.legend(fontsize=9, frameon=False, labelcolor=COLORS["text"])
    ax.tick_params(colors=COLORS["text_dim"])
    ax.spines["bottom"].set_color(COLORS["border"])
    ax.spines["left"].set_color(COLORS["border"])
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    fig.suptitle(
        f"GPT-300M • {total:,} Total Parameters",
        fontsize=16, fontweight="bold", color=COLORS["text"],
        fontfamily="monospace", y=1.02,
    )

    # Lay out this figure explicitly rather than pyplot's current figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=COLORS["bg"], edgecolor="none")
        print(f"Saved parameter chart: {save_path}")

    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_attention_heads(config: GPT300MConfig, save_path: str = None):
    """Visualize the multi-head attention mechanism.

    The flow runs from the input through the Q/K/V projections, fans out
    to at most eight head boxes, then merges into the scaled dot-product,
    the output projection and the output.

    Args:
        config: Model configuration. Reads ``n_heads``, ``head_dim`` and
            ``d_model``.
        save_path: If truthy, the figure is written there at 200 dpi.

    Returns:
        The matplotlib figure holding the diagram.
    """
    fig, ax = plt.subplots(1, 1, figsize=(14, 10), facecolor=COLORS["bg"])
    ax.set_facecolor(COLORS["bg"])
    ax.set_xlim(-1, 11)
    ax.set_ylim(-1, 8)
    ax.axis("off")

    # Title and dimension summary.
    ax.text(5, 7.5, "Multi-Head Self-Attention", ha="center",
            fontsize=18, fontweight="bold", color=COLORS["text"],
            fontfamily="monospace")
    ax.text(5, 7.0,
            f"{config.n_heads} heads × {config.head_dim}d per head = {config.d_model}d total",
            ha="center", fontsize=10, color=COLORS["text_dim"])

    # Input.
    draw_rounded_box(ax, 5, 6.2, 4, 0.5, COLORS["embed"],
                     f"Input: [B, T, {config.d_model}]", text_color="#000", fontsize=9)

    # Q/K/V projections; their x positions are reused for the fan-out.
    projection_xs = (2, 5, 8)
    projection_colors = ("#FF6B6B", "#4ECDC4", "#45B7D1")
    for name, color, x in zip("QKV", projection_colors, projection_xs):
        draw_arrow(ax, 5, 5.9, x, 5.4)
        draw_rounded_box(ax, x, 5.1, 1.8, 0.5, color,
                         f"W_{name}", text_color="#000", fontsize=10,
                         sublabel=f"{config.d_model}×{config.d_model}")

    # Attention heads: draw at most eight, spread across a 9-unit band.
    head_y = 3.8
    n_show = min(config.n_heads, 8)
    # max(..., 1) avoids dividing by zero when there are no heads to draw.
    head_spacing = 9.0 / max(n_show, 1)
    head_xs = [1 + h * head_spacing for h in range(n_show)]

    for h, hx in enumerate(head_xs):
        box = FancyBboxPatch(
            (hx - 0.4, head_y - 0.3), 0.8, 0.6,
            boxstyle="round,pad=0.05",
            facecolor=COLORS["attn"],
            edgecolor="white",
            linewidth=0.5,
            alpha=0.8,
            zorder=3,
        )
        ax.add_patch(box)
        ax.text(hx, head_y, f"H{h+1}", ha="center", va="center",
                fontsize=8, fontweight="bold", color="#000", zorder=4)

        # Faint lines from each of the Q/K/V projections to this head.
        for qx in projection_xs:
            ax.annotate("", xy=(hx, head_y + 0.3), xytext=(qx, 4.8),
                        arrowprops=dict(arrowstyle="-", color=COLORS["arrow"],
                                        lw=0.3, alpha=0.3), zorder=1)

    if config.n_heads > 8:
        ax.text(5, head_y - 0.6, f"... ({config.n_heads} heads total)",
                ha="center", fontsize=9, color=COLORS["text_dim"])

    # Scaled dot-product attention, fed by every drawn head.
    draw_rounded_box(ax, 5, 2.5, 6, 0.6, COLORS["card_bg"],
                     "Scaled Dot-Product: softmax(QK^T / √d_k) × V",
                     fontsize=10)
    for hx in head_xs:
        draw_arrow(ax, hx, head_y - 0.3, 5, 2.85)

    # Concatenation and output projection.
    draw_arrow(ax, 5, 2.15, 5, 1.75)
    draw_rounded_box(ax, 5, 1.5, 4, 0.5, COLORS["accent1"],
                     "Concat → W_O projection", text_color="#000", fontsize=10)

    # Output.
    draw_arrow(ax, 5, 1.2, 5, 0.8)
    draw_rounded_box(ax, 5, 0.5, 4, 0.5, COLORS["ffn"],
                     f"Output: [B, T, {config.d_model}]", text_color="#000", fontsize=9)

    # Lay out this figure explicitly rather than pyplot's current figure.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=200, bbox_inches="tight",
                    facecolor=COLORS["bg"], edgecolor="none")
        print(f"Saved attention diagram: {save_path}")

    return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Render every GPT-300M visualization into the ``--output`` directory.

    Parses the command line, creates the output directory if needed, then
    writes ``architecture.png``, ``parameters.png`` and ``attention.png``
    into it. Each figure is closed once saved.
    """
    # Only the CLI entry point needs ``os``, so it is imported locally.
    import os

    parser = argparse.ArgumentParser(description="Visualize GPT-300M Architecture")
    parser.add_argument("--output", type=str, default="./viz",
                        help="Output directory for images")
    args = parser.parse_args()

    os.makedirs(args.output, exist_ok=True)

    config = gpt_300m()
    print(f"Generating visualizations for GPT-300M ({config.total_params_estimate:,} params)...")

    renderers = (
        (draw_full_architecture, "architecture.png"),
        (draw_parameter_chart, "parameters.png"),
        (draw_attention_heads, "attention.png"),
    )
    for render, filename in renderers:
        fig = render(config, os.path.join(args.output, filename))
        # pyplot keeps a reference to every figure it creates; release it
        # now that the image is on disk.
        plt.close(fig)

    print("Done! All visualizations saved.")


if __name__ == "__main__":
    main()
|
|
|