tiny-torch-viz / static /index.html
Adrian Gabriel
Adapt index.html
4da3869
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<!-- Without a viewport meta, mobile browsers lay the page out at a ~980px virtual width. -->
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>TinyTorch Visualizer</title>
<!-- Libraries are loaded synchronously (no defer/async) on purpose: the inline script at the
     end of <body> checks window.pdfjsLib as soon as it runs. -->
<!-- KaTeX for LaTeX -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/katex.min.css">
<script src="https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/katex.min.js"></script>
<!-- Markdown.
     NOTE(review): this URL is unpinned and always resolves to the latest marked release, whose
     API/build layout can change between major versions. Pin an explicit version, as is done
     for KaTeX and PDF.js. -->
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
<!-- PDF.js. The worker URL configured in the page script must use this same version (3.11.174). -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.11.174/pdf.min.js"></script>
<style>
/* ==================== BASE & APP SHELL ==================== */
* { box-sizing: border-box; }
body {
margin: 0;
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
background: #0a0e27;
color: #e2e8f0;
}
/* Full-viewport flex row: canvas | resizer | editor | resizer | PDF sidebar.
   The page never scrolls; each panel scrolls on its own. */
#app { display: flex; height: 100vh; overflow: hidden; }
/* Scrollable drawing area. The two 1px gradients, tiled at 30px, draw the background grid. */
#canvas {
flex: 1 1 auto;
min-width: 320px;
position: relative;
overflow: auto;
background:
linear-gradient(#1a1f3a 1px, transparent 1px),
linear-gradient(90deg, #1a1f3a 1px, transparent 1px);
background-size: 30px 30px;
padding: 40px;
}
/* Oversized inner surface (at least 4000x2400) inside #canvas that absolutely positioned items
   are placed on.
   NOTE(review): #canvas-content is not in the static markup; presumably the script creates it.
   transform-origin 0 0 suggests zoom is applied as a CSS transform from the top-left.
   Both are unconfirmed from this view. */
#canvas-content {
position: relative;
transform-origin: 0 0;
min-width: 4000px;
min-height: 2400px;
padding-top: 72px;
}
/* Vertical drag handle between panels. The drag logic is not in this view; data-left in the
   markup presumably names the panel being resized. */
.v-resizer {
width: 6px;
min-width: 6px;
flex-shrink: 0;
background: #1e293b;
cursor: col-resize;
z-index: 100;
transition: background 0.2s;
}
.v-resizer:hover {
background: #3b82f6;
}
/* ==================== EDITOR PANEL ==================== */
/* Fixed 460px starting width (never narrower than 340px). Children stack vertically. */
#editor {
flex: 0 0 460px;
min-width: 340px;
background: #0f1419;
border-left: 2px solid #1e293b;
display: flex;
flex-direction: column;
padding: 20px;
gap: 12px;
}
#editor h2 {
margin: 0;
font-size: 13px;
text-transform: uppercase;
letter-spacing: 0.1em;
color: #64748b;
display: flex;
align-items: center;
gap: 8px;
}
/* Backend connection dot. Green by default; connectWebSocket() swaps in the
   .disconnected (red) and .connecting (amber) modifiers. */
.status-indicator {
width: 8px; height: 8px; border-radius: 50%;
background: #22c55e;
}
.status-indicator.disconnected { background: #ef4444; }
.status-indicator.connecting { background: #f59e0b; }
/* Code textarea. flex: 1 makes it absorb whatever height the other editor children leave free. */
#code {
flex: 1;
background: #1a1f3a;
border: 1px solid #334155;
border-radius: 8px;
padding: 12px;
color: #e2e8f0;
font-family: 'Courier New', monospace;
font-size: 14px;
line-height: 1.5;
resize: none;
}
#run {
padding: 10px 24px;
background: #3b82f6;
color: white;
border: none;
border-radius: 6px;
cursor: pointer;
font-weight: 600;
}
/* :hover and :disabled have equal specificity, so :disabled must stay below :hover to win. */
#run:hover { background: #2563eb; }
#run:disabled { background: #475569; cursor: not-allowed; }
/* Error text from the backend; pre-wrap keeps its line breaks (e.g. tracebacks). */
#error { color: #ef4444; font-size: 12px; min-height: 20px; font-family: monospace; white-space: pre-wrap; }
/* Console output area */
/* Starts 150px tall with a 60px minimum. flex-shrink: 0 means #code, not the console, gives up
   space when the editor column is short. The height is presumably changed by dragging
   #console-resizer; that handler is not in view. */
#console-container {
margin-top: 12px;
border: 1px solid #334155;
border-radius: 6px;
background: #0a0f1a;
overflow: hidden;
display: flex;
flex-direction: column;
min-height: 60px;
height: 150px;
flex-shrink: 0;
}
/* Transparent 6px strip above the console. Its ::after pseudo-element draws the centred grip bar. */
#console-resizer {
height: 6px;
background: transparent;
cursor: ns-resize;
position: relative;
flex-shrink: 0;
}
#console-resizer::after {
content: '';
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
width: 40px;
height: 3px;
background: #475569;
border-radius: 2px;
transition: background 0.2s;
}
#console-resizer:hover::after {
background: #64748b;
}
/* Title row with the "clear" button pushed to the right. */
#console-header {
display: flex;
justify-content: space-between;
align-items: center;
padding: 6px 10px;
background: #1e293b;
border-bottom: 1px solid #334155;
font-size: 11px;
font-weight: 600;
color: #94a3b8;
letter-spacing: 0.5px;
}
#console-clear {
background: none;
border: none;
color: #64748b;
cursor: pointer;
font-size: 12px;
padding: 2px 6px;
border-radius: 3px;
}
#console-clear:hover {
background: #334155;
color: #f1f5f9;
}
/* Scrolling line list. appendToConsole() keeps it scrolled to the newest line. */
#console-output {
flex: 1;
overflow-y: auto;
padding: 8px 10px;
font-family: 'Fira Code', 'JetBrains Mono', 'Consolas', monospace;
font-size: 12px;
line-height: 1.5;
color: #e2e8f0;
white-space: pre-wrap;
word-break: break-word;
}
/* Per-line colours. The modifier class is the `type` argument of appendToConsole(). */
#console-output .console-line {
margin: 2px 0;
}
#console-output .console-line.error {
color: #f87171;
}
#console-output .console-line.success {
color: #4ade80;
}
#console-output .console-line.info {
color: #60a5fa;
}
/* ==================== PDF SIDEBAR ==================== */
/* Right-hand panel: a header with controls above a scrolling page view. */
#pdfSidebar {
  flex: 1 1 340px;
  min-width: 260px;
  background: #020617;
  border-left: 2px solid #1e293b;
  display: flex;
  flex-direction: column;
}
#pdfSidebarHeader {
  padding: 8px 10px;
  border-bottom: 1px solid #1f2937;
  display: flex;
  flex-wrap: wrap;
  align-items: center;
  gap: 8px;
}
.pdf-title { font-size: 11px; text-transform: uppercase; color: #9ca3af; }
/* Pill-shaped "Load PDF" label. The real <input type="file"> is stretched over it and made
   transparent, so clicks land on the input while the pill is what the user sees. */
.pdf-file-btn {
  position: relative;
  border-radius: 999px;
  background: #111827;
  color: #e5e7eb;
  font-size: 11px;
  padding: 5px 10px;
  cursor: pointer;
}
.pdf-file-btn input { position: absolute; inset: 0; opacity: 0; cursor: pointer; }
/* The input is invisible, so surface its keyboard focus on the visible pill. */
.pdf-file-btn:focus-within { outline: 2px solid #3b82f6; outline-offset: 2px; }
.pdf-controls, .pdf-zoom { display: inline-flex; align-items: center; gap: 4px; }
/* Push page/zoom controls to the right edge of the header. */
.pdf-controls { margin-left: auto; }
/* Small pill buttons (prev/next page, zoom). A single border declaration: an earlier
   `border: none` here was dead, since the solid border below always overrode it. */
.pdf-btn {
  border: 1px solid #1f2937;
  border-radius: 999px;
  background: #020617;
  color: #e5e7eb;
  padding: 3px 7px;
  font-size: 11px;
  cursor: pointer;
}
/* Hover feedback, matching the other button styles on the page. */
.pdf-btn:hover { background: #111827; }
#pdfPageInfo, #pdfZoomValue { font-size: 11px; color: #9ca3af; min-width: 52px; text-align: center; }
#pdfViewerWrapper { flex: 1; overflow: auto; padding: 8px 10px 16px 10px; }
/* Hidden until renderPdfPage() draws a page into it (it also hides the placeholder). */
#pdfCanvas { display: none; background: #020617; border-radius: 8px; }
#pdfPlaceholder { font-size: 11px; color: #6b7280; padding: 10px; text-align: center; }
/* ==================== GROUPS ==================== */
/* Absolutely positioned wrapper per visualised operation. z-index 2 stacks it above
   .layer-box (1) and below notes (4). */
.group-container {
  position: absolute;
  z-index: 2;
}
/* Shared base for every op layout; the specific grid templates follow further down. */
.layout-grid {
  display: inline-grid;
  gap: 8px;
}
/* Drag handle and title of a group. */
.group-label {
  font-size: 11px;
  text-transform: uppercase;
  letter-spacing: 0.15em;
  color: #64748b;
  margin-bottom: 8px;
  font-weight: 600;
  cursor: grab;
  user-select: none;
  white-space: nowrap; /* keep long labels on one line */
}
.group-label:active { cursor: grabbing; }
.group-selected { box-shadow: 0 0 0 2px #3b82f6; border-radius: 8px; }

/* ==================== MATRIX CARDS ==================== */
.matrix-card {
  background: #1e293b;
  border: 2px solid #334155;
  border-radius: 10px;
  padding: 10px 12px;
  box-shadow: 0 4px 20px rgba(0,0,0,0.4);
  display: inline-block;
  min-width: 100px;
}
/* Title row: name on the left, shape on the right. */
.matrix-title {
  font-size: 11px;
  font-weight: 700;
  color: #3b82f6;
  margin-bottom: 8px;
  display: flex;
  justify-content: space-between;
}
.matrix-shape { color: #64748b; font-weight: 500; }
.matrix-table { border-collapse: collapse; font-size: 13px; }
.matrix-table td {
  border: 1px solid #334155;
  padding: 6px 8px;
  text-align: right;
  background: #0f1419;
  min-width: 50px;
  transition: background-color 0.15s ease, box-shadow 0.15s ease;
}
/* Highlight styles for matmul hover (row / column / cell classes are toggled by script). */
.matrix-table td.highlight-row {
  background: rgba(59, 130, 246, 0.3);
  box-shadow: inset 0 0 0 1px #3b82f6;
}
.matrix-table td.highlight-col {
  background: rgba(16, 185, 129, 0.3);
  box-shadow: inset 0 0 0 1px #10b981;
}
.matrix-table td.highlight-cell {
  background: rgba(251, 191, 36, 0.4);
  box-shadow: inset 0 0 0 2px #fbbf24;
}
/* Tensor flow highlighting - same tensor used in multiple places */
.matrix-card.tensor-flow-highlight {
  box-shadow: 0 0 0 3px #f59e0b, 0 0 20px rgba(245, 158, 11, 0.5);
  border-color: #f59e0b;
  transition: box-shadow 0.2s ease, border-color 0.2s ease;
}
.matrix-card[data-tensor-id] {
  cursor: pointer;
}
/* Retro variant. The body. prefix matches every other retro rule, and at (0,3,1) this still
   outranks body.retro-mode .matrix-card (0,2,1). */
body.retro-mode .matrix-card.tensor-flow-highlight {
  box-shadow: 0 0 0 3px #00ff00, 0 0 25px rgba(0, 255, 0, 0.6);
  border-color: #00ff00;
}
/* Retro mode highlights */
body.retro-mode .matrix-table td.highlight-row {
  background: rgba(0, 255, 255, 0.2);
  box-shadow: inset 0 0 0 1px #00ffff, 0 0 8px rgba(0, 255, 255, 0.3);
}
body.retro-mode .matrix-table td.highlight-col {
  background: rgba(255, 0, 255, 0.2);
  box-shadow: inset 0 0 0 1px #ff00ff, 0 0 8px rgba(255, 0, 255, 0.3);
}
body.retro-mode .matrix-table td.highlight-cell {
  background: rgba(255, 255, 0, 0.3);
  box-shadow: inset 0 0 0 2px #ffff00, 0 0 12px rgba(255, 255, 0, 0.5);
}

/* ==================== OP LAYOUTS ==================== */
.layout-binary {
  grid-template-columns: auto auto;
  grid-template-rows: auto auto;
  gap: 8px;
}
.layout-unary { grid-template-columns: auto; }
/* Matmul layout: a on left, b on top-right, result on bottom-right */
.pos-left { grid-row: 2; grid-column: 1; align-self: start; }
.pos-top { grid-row: 1; grid-column: 2; align-self: end; }
.pos-result { grid-row: 2; grid-column: 2; align-self: start; }
/* Element-wise operation layout: inputs side by side with the operator between them, result below left */
.layout-elementwise {
  display: inline-grid;
  grid-template-columns: auto auto auto;
  grid-template-rows: auto auto;
  gap: 10px;
  align-items: center;
}
.elem-left { grid-row: 1; grid-column: 1; }
.elem-op { grid-row: 1; grid-column: 2; align-self: center; justify-self: center; }
.elem-right { grid-row: 1; grid-column: 3; }
.elem-result { grid-row: 2; grid-column: 1; }
.op-symbol-large {
  font-size: 28px;
  font-weight: 300;
  color: #64748b;
  padding: 0 8px;
  user-select: none;
}
/* Linear layer layout: matmul-style with X on top, W on left, result at intersection */
.layout-linear {
  display: inline-grid;
  grid-template-columns: auto auto;
  grid-template-rows: auto auto;
  gap: 8px;
}
.linear-empty { grid-row: 1; grid-column: 1; } /* Empty top-left cell */
.linear-input { grid-row: 1; grid-column: 2; align-self: end; } /* X on top, aligned to bottom */
.linear-weight { grid-row: 2; grid-column: 1; align-self: start; } /* W on left, aligned to top */
.linear-output { grid-row: 2; grid-column: 2; align-self: start; } /* Result aligned to top with W */
/* TOOLBAR */
/* Floating pill toolbar. position: sticky pins it to the top-left of its scroll container.
   NOTE(review): #toolbar is not in the static markup; presumably the script inserts it
   into #canvas. Confirm. */
#toolbar {
position: sticky;
top: 0; left: 0;
z-index: 100;
display: inline-flex;
flex-wrap: wrap;
align-items: center;
gap: 8px;
padding: 8px 12px;
margin-bottom: 12px;
border-radius: 999px;
background: rgba(15,23,42,0.95);
box-shadow: 0 10px 30px rgba(0,0,0,0.6);
}
.tool-btn {
border: none;
border-radius: 999px;
padding: 6px 14px;
font-size: 12px;
font-weight: 600;
cursor: pointer;
background: #111827;
color: #e5e7eb;
}
.tool-btn:hover { background: #1f2937; }
/* Highlights the currently chosen tool/toggle. */
.tool-btn.active { background: #3b82f6; color: white; }
/* Cluster of related controls, separated from the previous cluster by a left divider line. */
.tool-group {
display: inline-flex;
align-items: center;
gap: 6px;
margin-left: 8px;
padding-left: 8px;
border-left: 1px solid #1f2937;
}
.tool-label { font-size: 11px; text-transform: uppercase; color: #9ca3af; }
.tool-select {
background: #020617;
color: #e5e7eb;
border-radius: 999px;
border: 1px solid #1f2937;
padding: 3px 8px;
font-size: 11px;
}
.tool-toggle { display: inline-flex; align-items: center; gap: 4px; font-size: 11px; color: #9ca3af; }
/* LAYER BOXES */
/* Framed regions drawn behind groups (z-index 1 vs 2). `border: 2px solid` sets no colour, so
   the border is currentColor unless a colour is supplied elsewhere (presumably inline, per box
   colour scheme). */
.layer-box {
position: absolute;
border-radius: 12px;
border: 2px solid;
z-index: 1;
}
/* Label pinned inside the box's top edge; also acts as the drag handle. */
.layer-box-label {
position: absolute;
top: 6px; left: 12px; right: 12px;
font-size: 10px;
text-transform: uppercase;
letter-spacing: 0.12em;
font-weight: 600;
cursor: grab;
user-select: none;
}
.layer-box-label:active { cursor: grabbing; }
.layer-box-selected { box-shadow: 0 0 0 2px #f97316; }
/* NOTES */
/* Free-floating annotation cards; z-index 4 keeps them above groups and layer boxes. */
.note-container {
position: absolute;
max-width: 300px;
background: rgba(15,23,42,0.97);
border-radius: 8px;
border: 1px solid #38bdf8;
padding: 8px 10px;
font-size: 12px;
z-index: 4;
cursor: grab;
}
.note-header { font-size: 10px; text-transform: uppercase; color: #7dd3fc; margin-bottom: 4px; cursor: grab; }
.note-body { font-size: 12px; line-height: 1.45; color: #e5e7eb; }
/* Full-screen dimmed overlay for the note editor. It is hidden until the .show class is added,
   and centres the dialog with flexbox. */
.note-modal-backdrop {
position: fixed; inset: 0; display: none;
background: rgba(0,0,0,0.5); z-index: 200;
align-items: center; justify-content: center;
}
.note-modal-backdrop.show { display: flex; }
.note-modal {
width: min(640px, 90vw);
background: #020617;
border-radius: 12px;
padding: 20px;
}
.note-modal-title { font-size: 14px; font-weight: 600; color: #e5e7eb; margin-bottom: 12px; }
#note-modal-textarea {
width: 100%; min-height: 200px;
border-radius: 8px; border: 1px solid #1f2937;
background: #0f1419; color: #e5e7eb;
padding: 12px; font-family: monospace; font-size: 13px;
margin-bottom: 12px;
}
.note-modal-actions { display: flex; gap: 10px; justify-content: flex-end; }
.note-modal-btn { border: none; border-radius: 6px; padding: 8px 16px; font-size: 13px; cursor: pointer; }
#note-modal-cancel { background: #374151; color: #e5e7eb; }
#note-modal-delete { background: #dc2626; color: white; }
#note-modal-save { background: #3b82f6; color: white; }
/* ==================== RETRO/CRT MODE ==================== */
/* Everything below applies only while <body> carries the .retro-mode class. */
:root {
  --retro-glass-bg: rgba(15, 15, 20, 0.65);
  --retro-glass-border: rgba(255, 255, 255, 0.08);
  --retro-neon-green: #00ff41;
  --retro-neon-blue: #00f3ff;
  --retro-text-primary: #e0e0e0;
}
/* Scanline background for body */
body.retro-mode {
  background-color: #050505;
  background-image:
    linear-gradient(rgba(18, 16, 16, 0) 50%, rgba(0, 0, 0, 0.15) 50%),
    linear-gradient(90deg, rgba(255,0,0,0.03), rgba(0,255,0,0.02), rgba(0,0,255,0.03));
  background-size: 100% 2px, 3px 100%;
}
/* Canvas gets a CRT scanline overlay on top of the 30px grid */
body.retro-mode #canvas {
  background:
    linear-gradient(rgba(18, 16, 16, 0) 50%, rgba(0, 0, 0, 0.08) 50%),
    linear-gradient(#0a0f0a 1px, transparent 1px),
    linear-gradient(90deg, #0a0f0a 1px, transparent 1px);
  background-size: 100% 2px, 30px 30px, 30px 30px;
}
/* Glassmorphism for editor panel */
body.retro-mode #editor {
  background: var(--retro-glass-bg);
  backdrop-filter: blur(12px);
  -webkit-backdrop-filter: blur(12px);
  border-left: 1px solid var(--retro-glass-border);
  box-shadow: -8px 0 32px rgba(0, 0, 0, 0.4);
}
/* Glassmorphism for toolbar */
body.retro-mode #toolbar {
  background: var(--retro-glass-bg);
  backdrop-filter: blur(12px);
  -webkit-backdrop-filter: blur(12px);
  border: 1px solid var(--retro-glass-border);
  box-shadow: 0 4px 24px rgba(0, 0, 0, 0.4);
}
/* Matrix cards get glass effect */
body.retro-mode .matrix-card {
  background: var(--retro-glass-bg);
  backdrop-filter: blur(8px);
  -webkit-backdrop-filter: blur(8px);
  border: 1px solid var(--retro-glass-border);
  box-shadow: 0 4px 16px rgba(0, 0, 0, 0.3);
}
/* Group containers */
body.retro-mode .group-container {
  background: rgba(10, 15, 10, 0.5);
  border: 1px solid rgba(0, 255, 65, 0.15);
  box-shadow: 0 0 20px rgba(0, 255, 65, 0.05);
}
/* Neon green for tensor values */
body.retro-mode .matrix-cell {
  color: var(--retro-neon-green);
  text-shadow: 0 0 3px rgba(0, 255, 65, 0.5);
  font-family: 'Fira Code', 'JetBrains Mono', 'Courier New', monospace;
}
/* Matrix labels glow blue */
body.retro-mode .matrix-label {
  color: var(--retro-neon-blue);
  text-shadow: 0 0 4px rgba(0, 243, 255, 0.4);
}
/* Shape badges */
body.retro-mode .shape-badge {
  background: rgba(0, 255, 65, 0.15);
  color: var(--retro-neon-green);
  border: 1px solid rgba(0, 255, 65, 0.3);
  text-shadow: 0 0 2px rgba(0, 255, 65, 0.5);
}
/* Code editor styling */
body.retro-mode #code {
  background: rgba(5, 10, 5, 0.8);
  color: var(--retro-neon-green);
  text-shadow: 0 0 1px rgba(0, 255, 65, 0.3);
  font-family: 'Fira Code', 'JetBrains Mono', 'Courier New', monospace;
  border: 1px solid rgba(0, 255, 65, 0.2);
}
/* Run button gets gradient glow */
body.retro-mode #run {
  background: linear-gradient(135deg, rgba(0, 100, 30, 0.8), rgba(0, 60, 20, 0.9));
  border: 1px solid rgba(0, 255, 65, 0.4);
  box-shadow: 0 0 15px rgba(0, 255, 65, 0.2);
  text-shadow: 0 0 4px rgba(0, 255, 65, 0.5);
}
body.retro-mode #run:hover {
  background: linear-gradient(135deg, rgba(0, 130, 40, 0.9), rgba(0, 80, 25, 0.95));
  box-shadow: 0 0 25px rgba(0, 255, 65, 0.4);
}
/* Disabled Run (backend not connected). The retro rules above (1,1,1)/(1,2,1) outrank the base
   #run:disabled (1,1,0), which made a disabled button look live. Restate the disabled look
   here, after :hover, so it also wins while hovered. */
body.retro-mode #run:disabled {
  background: rgba(40, 45, 40, 0.8);
  border-color: rgba(0, 255, 65, 0.1);
  box-shadow: none;
  text-shadow: none;
  color: #6b7280;
  cursor: not-allowed;
}
/* Buttons get subtle glow */
body.retro-mode .tool-btn {
  background: rgba(20, 30, 20, 0.7);
  border: 1px solid rgba(0, 255, 65, 0.2);
  color: var(--retro-text-primary);
}
body.retro-mode .tool-btn:hover {
  background: rgba(0, 255, 65, 0.1);
  border-color: rgba(0, 255, 65, 0.4);
  box-shadow: 0 0 10px rgba(0, 255, 65, 0.2);
}
body.retro-mode .tool-btn.active {
  background: rgba(0, 255, 65, 0.2);
  border-color: var(--retro-neon-green);
  color: var(--retro-neon-green);
  box-shadow: 0 0 12px rgba(0, 255, 65, 0.3);
}
/* Layer boxes get a frosted backdrop */
body.retro-mode .layer-box {
  backdrop-filter: blur(4px);
  -webkit-backdrop-filter: blur(4px);
}
/* Op symbols */
body.retro-mode .op-symbol,
body.retro-mode .op-symbol-large {
  color: var(--retro-neon-blue);
  text-shadow: 0 0 6px rgba(0, 243, 255, 0.6);
}
/* Group labels */
body.retro-mode .group-label {
  color: var(--retro-neon-blue);
  text-shadow: 0 0 3px rgba(0, 243, 255, 0.4);
  font-family: 'Fira Code', 'JetBrains Mono', 'Courier New', monospace;
}
/* Error display */
body.retro-mode #error {
  color: #ff4444;
  text-shadow: 0 0 4px rgba(255, 68, 68, 0.5);
}
/* Console in retro mode */
body.retro-mode #console-resizer::after {
  background: rgba(0, 255, 65, 0.3);
}
body.retro-mode #console-resizer:hover::after {
  background: rgba(0, 255, 65, 0.6);
}
body.retro-mode #console-container {
  border-color: rgba(0, 255, 65, 0.3);
  background: rgba(0, 10, 5, 0.8);
}
body.retro-mode #console-header {
  background: rgba(0, 40, 20, 0.6);
  border-bottom-color: rgba(0, 255, 65, 0.2);
  color: #00ff41;
  text-shadow: 0 0 4px rgba(0, 255, 65, 0.5);
}
body.retro-mode #console-output {
  color: #00ff41;
  text-shadow: 0 0 2px rgba(0, 255, 65, 0.3);
}
body.retro-mode #console-output .console-line.error {
  color: #ff4444;
  text-shadow: 0 0 4px rgba(255, 68, 68, 0.5);
}
/* Status dot glow. The connection dot in the markup is .status-indicator; .status-dot is kept
   for any script-created dots. The glow uses currentColor, so each state sets `color` to the
   same value as its fill, making the halo match the dot instead of the inherited grey
   heading colour. */
body.retro-mode .status-dot,
body.retro-mode .status-indicator {
  box-shadow: 0 0 8px currentColor;
}
body.retro-mode .status-indicator { color: #22c55e; }
body.retro-mode .status-indicator.disconnected { color: #ef4444; }
body.retro-mode .status-indicator.connecting { color: #f59e0b; }
/* Running animation: pulses group glows; paused unless <body> also has .running */
@keyframes retroDataFlow {
  0% { box-shadow: 0 0 5px rgba(0, 255, 65, 0.3); }
  50% { box-shadow: 0 0 20px rgba(0, 255, 65, 0.6), 0 0 30px rgba(0, 243, 255, 0.3); }
  100% { box-shadow: 0 0 5px rgba(0, 255, 65, 0.3); }
}
body.retro-mode .group-container {
  animation: retroDataFlow 2s ease-in-out infinite;
  animation-play-state: paused;
}
body.retro-mode.running .group-container {
  animation-play-state: running;
}
/* Selected groups. .group-selected's ring (0,1,0) loses both to the retro box-shadow above and
   to the animation (animated values override normal declarations). Drop the pulse and restate
   the ring in the retro palette. */
body.retro-mode .group-container.group-selected {
  animation: none;
  box-shadow: 0 0 0 2px var(--retro-neon-green);
  border-radius: 8px;
}
/* Respect users who ask for less motion: no pulsing; the static glow above still applies. */
@media (prefers-reduced-motion: reduce) {
  body.retro-mode .group-container {
    animation: none;
  }
}
</style>
</head>
<body>
<div id="app">
<div id="canvas"></div>
<div class="v-resizer" data-left="canvas"></div>
<div id="editor">
<!-- The dot shows backend connection state by colour; role/label give the otherwise empty span
     an accessible name, and the title provides a hover tooltip. -->
<h2><span class="status-indicator" id="wsStatus" role="img" aria-label="Backend connection status" title="Backend connection status"></span> TINYTORCH PYTHON CODE</h2>
<textarea id="code" spellcheck="false">
# TabPFN
import numpy
# training data
X_train = Tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
Y_train = Tensor([1, 0])
X_test = Tensor([[9, 10, 11, 12]])
box("X_train", [X_train, Y_train, X_test], "1")
# Feature Encoder - Feature Embeddings
W_enc = Tensor([[1, 0.5], [0.5, 1], [0.3, 0.7], [0.7, 0.3]])
W_enc_transpose = W_enc.transpose()
b_enc = Tensor([[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4]])
box("Feature Encoder", W_enc_transpose, "2")
# Feature/group embeddings
E_feat = Tensor([[0.1, 0.0, 0.0, 0.0], [0.0, 0.1, 0.0, 0.0]])
box("Group embedding", E_feat, "6")
# Step 1: Combine Training and Test Samples
X_combined = X_combined = Tensor(np.vstack([X_train.data, X_test.data]))
box("Training and Test Samples grouped", X_combined, "4")
# Step 1: Group Features
def group(X):
groups = X.shape[0] * W_enc.shape[1]
X_encoded = np.zeros((3, 2, 4))
# print(X_encoded)
idx = 0
col = 0
for (group_idx, row) in enumerate(X.data):
rt_ptr = 0
for rt_ptr in range(0, len(row), 2):
group_window = Tensor(row[rt_ptr:rt_ptr + 2])
group_matmul = group_window.matmul(W_enc_transpose) + b_enc[group_idx]
# group 1
if col == 0:
X_encoded[idx][0] = group_matmul.data + E_feat.data[0]
col = 1
# group 2
else:
X_encoded[idx][1] = group_matmul.data + + E_feat.data[1]
col = 0
box(f"grouping: group {col}", [group_window, group_matmul])
idx += 1
X_encoded_tensor = Tensor(X_encoded)
return X_encoded_tensor
X_encoded = group(X_combined)
box("X_encoded", X_encoded, "4")
# Label Encoder - Label Embeddings
W_y = Tensor([[1, -1, 0, 0], [0, 0, 1, 1]])
b_y = Tensor([0, 0, 0, 0])
y_padded = Tensor([1, 0, np.nan]) # we wan't to mask y_test with nan
y_clean = Tensor([[1, 0, 0], [0, 0, 1]]).reshape(3, 2)
box("y_clean", y_clean, "4")
def label_embeddings(Y_in: Tensor):
"""
Y_in: (n_total, 2) where columns are [y_clean, is_nan]
W_y: (2, D)
b_y: (D,)
returns: (n_total, D)
"""
D = b_y.shape[0]
if Y_in.shape[-1] != 2:
raise ValueError(f"Expected Y_in last dim=2, got shape {Y_in.shape}")
E = Y_in.matmul(W_y) + b_y # (n_total, D)
# optional: show each row embedding
for i in range(E.shape[0]):
box("Label Embedding", Tensor(E.data[i]), "5")
return E
label_embeds = label_embeddings(y_clean)
# print(label_embeds)
# Step 3: Add Thinking Tokens
Thinking_Tokens = Tensor([
[[0.01, 0.02, 0.03, 0.04],
[0.01, 0.02, 0.03, 0.04],
[0.01, 0.02, 0.03, 0.04]],
[[0.05, 0.06, 0.07, 0.08],
[0.05, 0.06, 0.07, 0.08],
[0.05, 0.06, 0.07, 0.08]]
])
box("Thinking Tokens", Thinking_Tokens, "4")
# Computing full model input
labels_reshaped = label_embeds.data.reshape(3, 1, 4)
data_rows = np.concatenate([X_encoded.data, labels_reshaped], axis=1)
E_numpy = np.concatenate([Thinking_Tokens.data, data_rows], axis=0)
E = Tensor(E_numpy)
# we need to adapt positional embeddings!
# Create row positional embeddings
P_col_pos_embeds = Tensor([[[0.1, 0.1, 0.1, 0.1],
[0.2, 0.2, 0.2, 0.2],
[0.3, 0.3, 0.3, 0.3]]])
# Add positional embeddings
E = E + P_col_pos_embeds
box("Positional Embedding", E, "9")
# Attention
W_q = Tensor(np.diag([0.1, 0.2, 0.1, 0.2]))
W_k = Tensor(np.diag([0.1, 0.1, 0.1, 0.1]))
W_v = Tensor(np.diag([1, 1, 1, 1]))
box("Attention weights", [W_q, W_k, W_v], "9")
scaling_factor = np.sqrt(4)
# labels = [E[1][2], E[2][2], E[2][2]]
col_att_softmax = Softmax()
def layer_norm_inplace(E: Tensor, eps=1e-5):
"""
In-place LN over last dim D for every vector in E.
E: (S, Ttok, D)
"""
x = E.data
mean = x.mean(axis=-1, keepdims=True)
var = ((x - mean) * (x - mean)).mean(axis=-1, keepdims=True)
x_norm = (x - mean) / np.sqrt(var + eps)
box("Layer norn", [Tensor(x), Tensor(mean), Tensor(var), Tensor(x_norm)], "7")
E.data[:] = x_norm
def column_attention_inplace(E: Tensor):
"""
In-place column attention:
For each item s: X = E[s] has shape (Ttok=3, D=4)
Does self-attention across the 3 tokens and writes back:
E[s] <- E[s] + Attn(E[s])
"""
S, Ttok, D = E.shape
softmax = Softmax()
for s in range(S):
# Snapshot of current item (avoid in-place mixing during compute)
X = Tensor(E.data[s].copy()) # (3,4)
Q = X.matmul(W_q.transpose()) # (3,4)
K = X.matmul(W_k.transpose()) # (3,4)
V = X.matmul(W_v.transpose()) # (3,4)
scores = Q.matmul(K.transpose()) / math.sqrt(D) # (3,3)
A = softmax.forward(scores, dim=-1) # (3,3)
O = A.matmul(V) # (3,4)
box("column_attention", [Q, K, V, scores, A, O], "5")
# In-place residual update of ALL tokens
E.data[s] = E.data[s] + O.data
column_attention_inplace(E)
layer_norm_inplace(E)
box("Updated Logits", E + 0, "5")
def mlp_inplace(E: Tensor):
"""
Minimal hand-friendly MLP with residual:
x <- x + GELU(x)
In-place.
"""
gelu = GELU()
x = Tensor(E.data.copy())
gx = gelu.forward(x).data
E.data[:] = E.data + gx
def row_attention_inplace(E: Tensor, single_eval_pos: int):
"""
In-place row attention:
For each token slot t:
Q from all S items: E[:, t, :] -> (S, D)
K,V from first Klen rows E[:single_eval_pos, t, :] -> (Klen, D)
Writes:
E[:, t, :] <- E[:, t, :] + Attn_row(E[:, t, :])
"""
S, Ttok, D = E.shape
softmax = Softmax()
Klen = single_eval_pos
assert 0 < Klen <= S, "single_eval_pos must be between 1 and S"
for t in range(Ttok):
# Snapshot streams (avoid in-place mixing)
X_all = Tensor(E.data[:, t, :].copy()) # (S, D)
X_kv = Tensor(E.data[:Klen, t, :].copy()) # (Klen, D)
Q = X_all.matmul(W_q.transpose()) # (S, D)
K = X_kv.matmul(W_k.transpose()) # (Klen, D)
V = X_kv.matmul(W_v.transpose()) # (Klen, D)
scores = Q.matmul(K.transpose()) / math.sqrt(D) # (S, Klen)
A = softmax.forward(scores, dim=-1) # (S, Klen)
O = A.matmul(V) # (S, D)
# In-place residual update for this token slot
box("row_attention", [Q, K, V, scores, A, O], "5")
E.data[:, t, :] = E.data[:, t, :] + O.data
row_attention_inplace(E, single_eval_pos=4)
layer_norm_inplace(E)
# 3) MLP + LN
mlp_inplace(E) # x <- x + GELU(x)
layer_norm_inplace(E)
# ============================================================
# Readout: take test row label token -> logits
# In this layout: rows are [think1, think2, train1, train2, test1]
# test index = T + N_train = 4
# label token index = 2
# ============================================================
test_row_idx = 4 # 4
label_tok_idx = 2 # last token slot
h_test = Tensor(E.data[test_row_idx, label_tok_idx, :].reshape(1, 4)) # (1,4)
gelu = GELU()
z = gelu.forward(h_test) # (1,4)
# Simple head D->C (pick first 2 dims as logits)
W_out = Tensor([[1, 0],
[0, 1],
[0, 0],
[0, 0]]) # (4,2)
b_out = Tensor([0.0, 0.0])
logits = z.matmul(W_out) + b_out # (1,2)
print("h_test:", h_test.data)
print("z (GELU):", z.data)
print("logits:", logits.data)
</textarea>
<button id="run">Run</button>
<div id="console-resizer"></div>
<div id="console-container">
<div id="console-header">
<span>CONSOLE OUTPUT</span>
<button id="console-clear" title="Clear console"></button>
</div>
<div id="console-output"></div>
</div>
<div id="error"></div>
</div>
<div class="v-resizer" data-left="editor"></div>
<div id="pdfSidebar">
<div id="pdfSidebarHeader">
<span class="pdf-title">PDF REF</span>
<label class="pdf-file-btn">Load PDF<input type="file" id="pdfFileInput" accept="application/pdf"></label>
<div class="pdf-controls">
<button id="pdfPrev" class="pdf-btn"></button>
<span id="pdfPageInfo">- / -</span>
<button id="pdfNext" class="pdf-btn"></button>
</div>
<div class="pdf-zoom">
<button id="pdfZoomOut" class="pdf-btn">-</button>
<span id="pdfZoomValue">100%</span>
<button id="pdfZoomIn" class="pdf-btn">+</button>
</div>
</div>
<div id="pdfViewerWrapper">
<canvas id="pdfCanvas"></canvas>
<div id="pdfPlaceholder">Load a PDF to view it here.</div>
</div>
</div>
</div>
<input type="file" id="jsonFileInput" accept="application/json" style="display:none">
<div id="note-modal-backdrop" class="note-modal-backdrop">
<div class="note-modal">
<div class="note-modal-title">Edit Note</div>
<textarea id="note-modal-textarea"></textarea>
<div style="font-size: 11px; color: #64748b; margin-bottom: 12px;">Supports Markdown and LaTeX: use $...$ for inline and $$...$$ for display math</div>
<div class="note-modal-actions">
<button id="note-modal-cancel" class="note-modal-btn">Cancel</button>
<button id="note-modal-delete" class="note-modal-btn">Delete</button>
<button id="note-modal-save" class="note-modal-btn">Save</button>
</div>
</div>
</div>
<script>
// ==================== PDF.js ====================
// The worker must come from the same pdf.js release as the library loaded in <head>.
if (window.pdfjsLib) {
  pdfjsLib.GlobalWorkerOptions.workerSrc = 'https://cdnjs.cloudflare.com/ajax/libs/pdf.js/3.11.174/pdf.worker.min.js';
}
// Viewer state: the loaded document, the current 1-based page number, and the user's zoom
// multiplier applied on top of the fit-to-width scale.
let pdfDoc = null, pdfPageNum = 1, pdfZoomFactor = 1.0;
// In-flight pdf.js RenderTask. pdf.js refuses to start a second render() into a canvas that is
// still being drawn (rapid zoom/page clicks used to throw), so each new render cancels the
// previous one first.
let pdfRenderTask = null;

/**
 * Render page `num` of the loaded PDF into #pdfCanvas.
 * The page is scaled to the sidebar width (minus 20px of padding) times pdfZoomFactor.
 * Also shows the canvas, hides the placeholder, and updates the page and zoom labels.
 * Does nothing when no document is loaded.
 * @param {number} num 1-based page number
 */
function renderPdfPage(num) {
  if (!pdfDoc) return;
  const doc = pdfDoc;
  doc.getPage(num).then(page => {
    // Another PDF may have been loaded while getPage() was pending.
    if (doc !== pdfDoc) return;
    const wrapper = document.getElementById('pdfViewerWrapper');
    // Clamp so a collapsed sidebar cannot produce a zero or negative scale.
    const fitWidth = Math.max(1, wrapper.clientWidth - 20);
    const scale = fitWidth / page.getViewport({ scale: 1.0 }).width * pdfZoomFactor;
    const viewport = page.getViewport({ scale });
    const canvas = document.getElementById('pdfCanvas');
    // Free the canvas from any render still in progress before resizing and drawing into it.
    if (pdfRenderTask) pdfRenderTask.cancel();
    canvas.width = viewport.width;
    canvas.height = viewport.height;
    const task = page.render({ canvasContext: canvas.getContext('2d'), viewport });
    pdfRenderTask = task;
    task.promise
      .catch(err => {
        // Being superseded by a newer render is expected; report anything else.
        if (!err || err.name !== 'RenderingCancelledException') console.error(err);
      })
      .finally(() => { if (pdfRenderTask === task) pdfRenderTask = null; });
    canvas.style.display = 'block';
    document.getElementById('pdfPlaceholder').style.display = 'none';
    document.getElementById('pdfPageInfo').textContent = `${num} / ${doc.numPages}`;
    document.getElementById('pdfZoomValue').textContent = Math.round(scale * 100) + '%';
  }).catch(err => console.error(err));
}

// Load a user-chosen PDF. The file's bytes go to pdf.js directly: the previous
// URL.createObjectURL() approach leaked one never-revoked blob URL per load.
document.getElementById('pdfFileInput').onchange = e => {
  const file = e.target.files[0];
  if (!file || !window.pdfjsLib) return;
  file.arrayBuffer()
    .then(buf => pdfjsLib.getDocument({ data: buf }).promise)
    .then(doc => {
      // Stop drawing the old document and release its worker-side resources.
      if (pdfRenderTask) { pdfRenderTask.cancel(); pdfRenderTask = null; }
      if (pdfDoc) pdfDoc.destroy();
      pdfDoc = doc; pdfPageNum = 1; pdfZoomFactor = 1.0;
      renderPdfPage(1);
    })
    .catch(err => {
      console.error(err);
      // Only replace the placeholder text when no document is on screen yet.
      if (!pdfDoc) document.getElementById('pdfPlaceholder').textContent = 'Could not load this PDF.';
    });
};
// Page navigation and zoom (zoom multiplier clamped to [0.3, 3]).
document.getElementById('pdfPrev').onclick = () => { if (pdfDoc && pdfPageNum > 1) renderPdfPage(--pdfPageNum); };
document.getElementById('pdfNext').onclick = () => { if (pdfDoc && pdfPageNum < pdfDoc.numPages) renderPdfPage(++pdfPageNum); };
document.getElementById('pdfZoomIn').onclick = () => { pdfZoomFactor = Math.min(3, pdfZoomFactor * 1.1); renderPdfPage(pdfPageNum); };
document.getElementById('pdfZoomOut').onclick = () => { pdfZoomFactor = Math.max(0.3, pdfZoomFactor / 1.1); renderPdfPage(pdfPageNum); };
// ==================== WebSocket ====================
// Current socket to the backend, and whether it is open.
let ws = null, wsConnected = false;

/**
 * Connect to the backend WebSocket at /ws on the page's own host (wss: when the page is served
 * over https:). Each message is parsed as JSON and passed to handleMessage(); parse or handler
 * errors are logged, not rethrown. The status dot and the Run button track the connection, and
 * a new attempt is scheduled 2 s after the socket closes.
 */
function connectWebSocket() {
  // The dot's colour comes from its CSS class. title/aria-label also spell the state out, so
  // it is not conveyed by colour alone.
  const setStatus = (modifier, label) => {
    const dot = document.getElementById('wsStatus');
    dot.className = modifier ? `status-indicator ${modifier}` : 'status-indicator';
    dot.title = label;
    dot.setAttribute('aria-label', label);
  };
  const markDisconnected = () => {
    wsConnected = false;
    setStatus('disconnected', 'Backend disconnected');
    document.getElementById('run').disabled = true;
  };

  setStatus('connecting', 'Connecting to backend');
  try {
    ws = new WebSocket(`${location.protocol === 'https:' ? 'wss:' : 'ws:'}//${location.host}/ws`);
  } catch (err) {
    // The constructor throws synchronously (e.g. SyntaxError for an invalid URL). Without this,
    // the dot stayed amber forever and no retry was ever scheduled.
    console.error(err);
    markDisconnected();
    setTimeout(connectWebSocket, 2000);
    return;
  }
  ws.onopen = () => {
    wsConnected = true;
    setStatus('', 'Backend connected');
    document.getElementById('run').disabled = false;
  };
  ws.onclose = () => {
    markDisconnected();
    setTimeout(connectWebSocket, 2000);
  };
  // A failed connection fires 'error' and then 'close'; the close handler schedules the retry.
  ws.onerror = () => { setStatus('disconnected', 'Backend disconnected'); };
  ws.onmessage = e => {
    try { handleMessage(JSON.parse(e.data)); } catch (err) { console.error(err); }
  };
}
// ==================== State ====================
// --- Run data collected from backend events (cleared on 'reset') ---
const tensors = {};         // tensor id -> { id, shape, data, name }
const groups = [];          // one entry per 'op' event
const pendingBoxes = [];    // box() calls queued until the 'done' event triggers layout
const boxes = [];
const canvasNotes = [];
// --- Id counters for groups ('g' + n) and boxes ---
let nextGroupId = 1;
let nextBoxId = 1;
// --- Current selection ---
const selectedGroupIds = new Set();
const selectedBoxIds = new Set();
// --- View / tool settings ---
let currentTool = 'select';
let showMatrixData = true;
let zoomLevel = 1;
let skipLayout = false; // true while loading from JSON, so saved positions are preserved
// ==================== Message Handling ====================
/**
 * Apply one backend event to the client-side state.
 *
 * Handled events:
 *   'reset'  - start of a run: clear tensors, groups, boxes, id counters, error text and console
 *   'tensor' - record a tensor snapshot by id
 *   'op'     - record an operation as a new group (tensor ids are resolved at render time)
 *   'box'    - queue a box() call for layout
 *   'error'  - show the message in #error and in the console
 *   'print'  - append text to the console
 *   'done'   - lay out and render everything collected
 * Other events are ignored.
 *
 * @param {object} msg parsed JSON message; msg.event selects the branch
 */
function handleMessage(msg) {
  switch (msg.event) {
    case 'reset':
      Object.keys(tensors).forEach(k => delete tensors[k]);
      groups.length = 0;
      pendingBoxes.length = 0;
      boxes.length = 0;
      nextGroupId = 1;
      nextBoxId = 1;
      document.getElementById('error').textContent = '';
      clearConsole();
      break;
    case 'tensor':
      tensors[msg.id] = { id: msg.id, shape: msg.shape, data: msg.data, name: msg.name };
      break;
    case 'op':
      // Store tensor IDs, not resolved tensors - we'll resolve at render time
      // to get updated names from __auto_name__
      groups.push({
        id: 'g' + nextGroupId++,
        opType: msg.type,
        inputIds: msg.inputs,
        outputId: msg.output,
        meta: msg.meta || {},
        x: 0, y: 0
      });
      break;
    case 'box': {
      // Always store an array. The box layout iterates tensorIds without a guard, so a missing
      // list (or a single bare id) would otherwise crash rendering at 'done'.
      const ids = msg.tensors;
      pendingBoxes.push({
        label: msg.label,
        tensorIds: Array.isArray(ids) ? ids : (ids == null ? [] : [ids]),
        scheme: msg.scheme || '1',
        parent: msg.parentBox
      });
      break;
    }
    case 'error':
      document.getElementById('error').textContent = msg.message;
      appendToConsole(msg.message, 'error');
      break;
    case 'print':
      appendToConsole(msg.text, msg.type || 'info');
      break;
    case 'done':
      layoutAndRender();
      break;
  }
}
// ==================== Console Output ====================
/**
 * Add one line to the console panel and keep the newest line in view.
 * The text is assigned via textContent, so it is never interpreted as HTML.
 * @param {string} text line to display
 * @param {string} [type='info'] extra class on the line; 'error', 'success' and 'info' are styled
 */
function appendToConsole(text, type = 'info') {
  const entry = document.createElement('div');
  entry.className = `console-line ${type}`;
  entry.textContent = text;

  const panel = document.getElementById('console-output');
  panel.append(entry);
  panel.scrollTop = panel.scrollHeight;
}

/** Remove every line from the console panel. */
function clearConsole() {
  document.getElementById('console-output').replaceChildren();
}
// ==================== Layout ====================
// Invoked by handleMessage() on the backend's 'done' event. It only delegates to
// renderFirstPass(), which is defined outside this view. Per the inline comment, that pass
// creates elements so their sizes can be measured before boxes are built.
function layoutAndRender() {
// First render pass: create elements to measure sizes, then create boxes
renderFirstPass();
}
/**
 * Turn the queued box() requests in the global `pendingBoxes` into laid-out
 * entries in the global `boxes`, positioning each box's groups, then lay out
 * every group that no box claimed ("orphans") below the boxes.
 *
 * Side effects: pushes to `boxes`, increments `nextBoxId`, writes g.x/g.y and
 * element style.left/top for positioned groups, sets `parent` /
 * `detectedChildren` on pending entries, and finally empties `pendingBoxes`
 * (only when it was non-empty).
 *
 * @param {Object<string, HTMLElement>} groupElById - rendered group elements
 *   keyed by group id; their offsetWidth/offsetHeight drive the layout.
 */
function createPendingBoxes(groupElById) {
  const PAD = 20;         // Consistent padding on all sides
  const LABEL_H = 24;     // Height reserved for label
  const CHILD_GAP_X = 30; // Gap between sibling child boxes
  const COL_GAP = 30;     // Gap between columns within a box
  const CHILD_GAP_Y = 20; // Gap below the row of child boxes
  const ROW_GAP = 30;     // Vertical gap between groups in a column
  const LEFT_MARGIN = 80; // X of root boxes and orphan groups
  // Root boxes (and then orphans) are stacked vertically from here.
  let layoutCursorY = 100;
  // Groups already assigned to some box, so no group is placed twice.
  const claimedGroupIds = new Set();

  // Map tensor IDs to (unclaimed) group IDs. With includeAncestors, also walks
  // back through the groups that produced the inputs, ancestors first so ids
  // come out in execution order, but stops at tensors that are listed
  // explicitly in OTHER boxes (those belong to the other box).
  function getGroupIdsForTensors(tensorIds, includeAncestors = true) {
    const gids = [];
    const visited = new Set();
    const tensorIdSet = new Set(tensorIds);
    // Tensors explicitly listed by other boxes act as traversal boundaries.
    const otherBoxTensorIds = new Set();
    pendingBoxes.forEach(pb => {
      (pb.tensorIds || []).forEach(tid => {
        if (!tensorIdSet.has(tid)) otherBoxTensorIds.add(tid);
      });
    });
    function addGroup(tid, isExplicit = false) {
      if (visited.has(tid)) return;
      visited.add(tid);
      // Don't cross into another box's tensors unless we listed it ourselves.
      if (!isExplicit && otherBoxTensorIds.has(tid)) return;
      const g = groups.find(gr => gr.outputId === tid);
      if (g && !gids.includes(g.id) && !claimedGroupIds.has(g.id)) {
        if (includeAncestors && g.inputIds) {
          g.inputIds.forEach(inputId => addGroup(inputId, false));
        }
        gids.push(g.id);
      }
    }
    tensorIds.forEach(tid => addGroup(tid, true));
    return gids;
  }

  // Parent-child detection by tensor containment: box P is a container of box
  // C when P holds every tensor of C and strictly more distinct tensors. For
  // nested boxes (A ⊃ B ⊃ C) both A and B contain C; the immediate parent is
  // the SMALLEST container, chosen independently of declaration order (on a
  // size tie the later box wins). Strictly-larger sets also rule out cycles.
  const tensorSets = new Map(pendingBoxes.map(pb => [pb, new Set(pb.tensorIds || [])]));
  pendingBoxes.forEach(pb => { pb.detectedChildren = []; });
  pendingBoxes.forEach(potentialChild => {
    const childSet = tensorSets.get(potentialChild);
    if (childSet.size === 0) return;
    let bestParent = null;
    pendingBoxes.forEach(potentialParent => {
      if (potentialParent === potentialChild) return;
      const parentSet = tensorSets.get(potentialParent);
      if (parentSet.size <= childSet.size) return;
      for (const tid of childSet) {
        if (!parentSet.has(tid)) return;
      }
      if (!bestParent || parentSet.size <= tensorSets.get(bestParent).size) {
        bestParent = potentialParent;
      }
    });
    if (bestParent) {
      potentialChild.parent = bestParent.label;
      bestParent.detectedChildren.push(potentialChild.label);
    }
  });
  // Root boxes are those with no parent
  const rootBoxes = pendingBoxes.filter(pb => !pb.parent);

  // Create `pending` (and, recursively, its child boxes) with its label row at
  // startY and its left edge at startX; returns the created box.
  function createBoxHierarchy(pending, startY, startX) {
    const contentTop = startY + LABEL_H + PAD;
    const children = pendingBoxes.filter(p => p.parent === pending.label);
    // Computed before the children run, so it may include groups the children
    // will claim; those are removed via childGroupIdSet below.
    const allGroupIds = getGroupIdsForTensors(pending.tensorIds || []);
    let minX = Infinity, minY = Infinity, maxX = -Infinity, maxY = -Infinity;
    let cursorY = contentTop;
    let cursorX = startX + PAD;
    // Grow this box's content bounds to include a rectangle.
    function includeRect(x1, y1, x2, y2) {
      minX = Math.min(minX, x1);
      minY = Math.min(minY, y1);
      maxX = Math.max(maxX, x2);
      maxY = Math.max(maxY, y2);
    }
    // Stack groups `ids` top-down at x from y; returns next free y and max width.
    function placeColumn(ids, x, y) {
      let maxW = 0;
      ids.forEach(gid => {
        const g = groups.find(gr => gr.id === gid);
        const el = groupElById[gid];
        if (!g || !el) return;
        g.x = x;
        g.y = y;
        el.style.left = g.x + 'px';
        el.style.top = g.y + 'px';
        // offsetWidth/Height are unaffected by the zoom transform.
        const w = el.offsetWidth;
        const h = el.offsetHeight;
        maxW = Math.max(maxW, w);
        includeRect(g.x, g.y, g.x + w, g.y + h);
        y += h + ROW_GAP;
      });
      return { endY: y, maxW };
    }

    const box = {
      id: 'b' + nextBoxId++,
      label: pending.label,
      scheme: pending.scheme || '1',
      x: startX,
      y: startY,
      w: 100,
      h: LABEL_H + 50,
      groupIds: [],
      childBoxIds: [],
      fromCode: true
    };
    boxes.push(box);

    // 1) Children first, side by side; remember every group they (or their
    //    descendants) own so this box doesn't lay those out again.
    const childGroupIdSet = new Set();
    function collectChildGroups(b) {
      (b.groupIds || []).forEach(gid => childGroupIdSet.add(gid));
      (b.childBoxIds || []).forEach(cid => {
        const cb = boxes.find(bb => bb.id === cid);
        if (cb) collectChildGroups(cb);
      });
    }
    let childMaxHeight = 0;
    children.forEach(childPending => {
      const child = createBoxHierarchy(childPending, contentTop, cursorX);
      box.childBoxIds.push(child.id);
      collectChildGroups(child);
      includeRect(child.x, child.y, child.x + child.w, child.y + child.h);
      childMaxHeight = Math.max(childMaxHeight, child.h);
      cursorX = child.x + child.w + CHILD_GAP_X;
    });
    if (children.length > 0) {
      cursorY = contentTop + childMaxHeight + CHILD_GAP_Y;
    }

    // 2) This box's own groups, in two columns (left gets ceil(n/2)).
    const ownGroupIds = allGroupIds.filter(id => !childGroupIdSet.has(id));
    if (ownGroupIds.length > 0) {
      const nLeft = Math.ceil(ownGroupIds.length / 2);
      const leftX = startX + PAD;
      const left = placeColumn(ownGroupIds.slice(0, nLeft), leftX, cursorY);
      placeColumn(ownGroupIds.slice(nLeft), leftX + left.maxW + COL_GAP, cursorY);
      box.groupIds = ownGroupIds.slice();
      // Mark as claimed so later boxes don't reuse them.
      ownGroupIds.forEach(gid => claimedGroupIds.add(gid));
    }

    // 3) Shrink-wrap the box around its content (if it has any).
    if (isFinite(minX)) {
      box.x = minX - PAD;
      box.y = minY - PAD - LABEL_H;
      box.w = (maxX - minX) + 2 * PAD;
      box.h = (maxY - minY) + 2 * PAD + LABEL_H;
    }
    return box;
  }

  if (pendingBoxes.length > 0) {
    rootBoxes.forEach(rootPending => {
      const rootBox = createBoxHierarchy(rootPending, layoutCursorY, LEFT_MARGIN);
      layoutCursorY = rootBox.y + rootBox.h + 50;
    });
    pendingBoxes.length = 0;
  }

  // Groups not in any box are always laid out below the boxes.
  const groupsInBoxes = new Set();
  boxes.forEach(b => (b.groupIds || []).forEach(gid => groupsInBoxes.add(gid)));
  const orphanGroupIds = groups.filter(g => !groupsInBoxes.has(g.id)).map(g => g.id);
  if (orphanGroupIds.length > 0) {
    layoutGroupsTwoColumns(orphanGroupIds, groupElById, layoutCursorY, LEFT_MARGIN, true);
  }
}
/**
 * Place groups in two columns starting at (leftMargin, startY): the first
 * ceil(n/2) ids go in the left column, the rest in the right column, which
 * starts after the widest left-column group plus an 80px gap. Rows are
 * separated by 50px.
 *
 * @param {string[]} groupIds - ids of groups to place, in order.
 * @param {Object<string, HTMLElement>} groupElById - rendered group elements;
 *   a missing element is measured as 200×150.
 * @param {number} startY - top of both columns.
 * @param {number} leftMargin - x of the left column.
 * @param {boolean} [forceLayout=true] - when false, groups whose x is already
 *   set (not undefined and not 0) keep their position.
 * @returns {{minX:number,minY:number,maxX:number,maxY:number,lastY:number}|null}
 *   bounds of the placed groups plus the next free y, or null for no ids.
 */
function layoutGroupsTwoColumns(groupIds, groupElById, startY, leftMargin, forceLayout = true) {
  const COL_GAP = 80; // Gap between columns
  const ROW_GAP = 50; // Gap between rows
  if (!groupIds.length) return null;

  const half = Math.ceil(groupIds.length / 2);
  const columns = [groupIds.slice(0, half), groupIds.slice(half)];

  // Measure each group once.
  const sizes = {};
  for (const id of groupIds) {
    const el = groupElById[id];
    sizes[id] = el ? { w: el.offsetWidth, h: el.offsetHeight } : { w: 200, h: 150 };
  }

  const leftWidth = columns[0].reduce((widest, id) => Math.max(widest, sizes[id].w), 0);
  const columnX = [leftMargin, leftMargin + leftWidth + COL_GAP];

  // Walk each column top-down; result is the y just past its last group.
  const columnEndY = columns.map((ids, col) => {
    let y = startY;
    for (const id of ids) {
      const g = groups.find(gr => gr.id === id);
      if (!g) continue;
      const unplaced = g.x === undefined || g.x === 0;
      if (forceLayout || unplaced) {
        g.x = columnX[col];
        g.y = y;
      }
      const el = groupElById[id];
      if (el) {
        el.style.left = g.x + 'px';
        el.style.top = g.y + 'px';
      }
      y += sizes[id].h + ROW_GAP;
    }
    return y;
  });

  // Bounds use the groups' actual (possibly preserved) positions.
  let minX = Infinity, minY = Infinity, maxX = -Infinity, maxY = -Infinity;
  for (const id of groupIds) {
    const g = groups.find(gr => gr.id === id);
    const size = sizes[id];
    if (!g || !size) continue;
    minX = Math.min(minX, g.x);
    minY = Math.min(minY, g.y);
    maxX = Math.max(maxX, g.x + size.w);
    maxY = Math.max(maxY, g.y + size.h);
  }
  return { minX, minY, maxX, maxY, lastY: Math.max(columnEndY[0], columnEndY[1]) };
}
/**
 * First render stage: rebuild #canvas (toolbar + #canvas-content), render
 * every group hidden at (0, 0) so it can be measured, position the groups
 * (and boxes), reveal them, then hand off to renderFinal().
 *
 * Positioning mode is chosen by the global `skipLayout` flag:
 *  - true: reuse the saved g.x/g.y and box x/y/w/h already on the objects,
 *    shifting everything down when the topmost group is above MIN_Y; the
 *    flag is reset to false afterwards.
 *  - false: createPendingBoxes() computes group and box positions.
 */
function renderFirstPass() {
const canvas = document.getElementById('canvas');
// Discard the previous render entirely (toolbar, boxes, groups, notes).
canvas.innerHTML = '';
// Add toolbar
canvas.appendChild(createToolbar());
const content = document.createElement('div');
content.id = 'canvas-content';
canvas.appendChild(content);
// Group element by group id; passed on to createPendingBoxes and renderFinal.
const groupElById = {};
// Helper to generate label from group data using current tensor names.
// Produces "<output> = <opType>(<inputs>)", or "<output> = Linear(<x>)".
// NOTE(review): g.opType.toLowerCase() assumes opType is always a string.
function getGroupLabel(g) {
// Helper to get tensor or create placeholder for missing tensor
function getOrPlaceholder(id) {
if (tensors[id]) return tensors[id];
return { id, shape: [], data: [], name: id };
}
const output = getOrPlaceholder(g.outputId);
const inputs = (g.inputIds || []).map(id => getOrPlaceholder(id));
const outputName = output?.name || g.opType;
// For linear layers, only show the original input (not weight/bias)
// The instrumentation adds weight at index 0, so original input is at index 1
if (g.opType.toLowerCase() === 'linear' && inputs.length >= 2) {
const originalInput = inputs[1]; // Skip weight at index 0
const inputName = originalInput?.name || originalInput?.id || 'x';
return `${outputName} = Linear(${inputName})`;
}
// Helper to get display name for an input tensor.
// Placeholders always carry name === id, so the producer lookup below only
// matters for real tensors that have no name.
function getInputDisplayName(tensor) {
if (tensor.name) return tensor.name;
// Look for a group that produced this tensor
const producerGroup = groups.find(gr => gr.outputId === tensor.id);
if (producerGroup) {
// If producer is linear, use "y" (what linear displays its output as)
if (producerGroup.opType.toLowerCase() === 'linear') {
const producerOutput = tensors[producerGroup.outputId];
return producerOutput?.name || 'y';
}
// For other producers, try their output tensor name
const producerOutput = tensors[producerGroup.outputId];
if (producerOutput?.name) return producerOutput.name;
}
return tensor.id;
}
const inputNames = inputs.map(t => getInputDisplayName(t)).join(', ');
return `${outputName} = ${g.opType}(${inputNames})`;
}
// Helper to build the op object ({type, inputs, output, meta}) consumed by renderOp
function getGroupOp(g) {
// Create placeholder tensor if missing (for backward-compatible JSON loading)
function getOrPlaceholder(id) {
if (tensors[id]) return tensors[id];
// Create placeholder with just id (will render as empty matrix)
return { id, shape: [], data: [], name: id };
}
const output = getOrPlaceholder(g.outputId);
const inputs = (g.inputIds || []).map(id => getOrPlaceholder(id));
return { type: g.opType, inputs, output, meta: g.meta || {} };
}
// Render all groups at temporary positions to measure
groups.forEach(g => {
const el = document.createElement('div');
el.className = 'group-container';
el.dataset.groupId = g.id;
el.style.left = '0px';
el.style.top = '0px';
el.style.visibility = 'hidden'; // Hidden during measurement
const label = document.createElement('div');
label.className = 'group-label';
label.textContent = getGroupLabel(g);
el.appendChild(label);
renderOp(el, getGroupOp(g));
content.appendChild(el);
groupElById[g.id] = el;
});
// Force layout calculation (reading offsetHeight forces a synchronous reflow,
// so offsetWidth/offsetHeight are valid during the layout step below)
content.offsetHeight;
// If skipLayout is set, just position groups using their saved coordinates
if (skipLayout) {
// Find minimum Y to calculate offset if needed (ensure content is below toolbar)
const MIN_Y = 80; // Minimum Y to ensure content is visible below toolbar
let minLoadedY = Infinity;
groups.forEach(g => {
if (typeof g.y === 'number') minLoadedY = Math.min(minLoadedY, g.y);
});
// Shift down only when some group sits above MIN_Y; x is never shifted.
const yOffset = (minLoadedY < MIN_Y && minLoadedY !== Infinity) ? (MIN_Y - minLoadedY) : 0;
// Use saved positions - no layout calculation (missing coordinates become 0)
groups.forEach(g => {
const el = groupElById[g.id];
if (el) {
const xPos = typeof g.x === 'number' ? g.x : 0;
const yPos = (typeof g.y === 'number' ? g.y : 0) + yOffset;
el.style.left = xPos + 'px';
el.style.top = yPos + 'px';
// Update the group object with adjusted position
g.x = xPos;
g.y = yPos;
}
});
// Boxes already have their x, y, w, h from JSON - apply same offset;
// missing/invalid values fall back to 0 (position) or 100 (size)
boxes.forEach(box => {
if (typeof box.x !== 'number') box.x = 0;
if (typeof box.y !== 'number') box.y = 0;
else box.y = box.y + yOffset; // Apply same offset
if (typeof box.w !== 'number' || box.w <= 0) box.w = 100;
if (typeof box.h !== 'number' || box.h <= 0) box.h = 100;
});
skipLayout = false; // Reset flag
} else {
// Create boxes from pending box() calls - this handles all layout
createPendingBoxes(groupElById);
}
// Make groups visible at their final positions
groups.forEach(g => {
const el = groupElById[g.id];
if (el) {
el.style.visibility = 'visible';
el.style.left = g.x + 'px';
el.style.top = g.y + 'px';
}
});
// Now do final render with boxes (drag behaviors, notes, zoom, highlighting)
renderFinal(groupElById);
}
/**
 * Build the #toolbar element: Select / Add Note / Export JSON / Load JSON
 * buttons, followed by a "View" group holding the zoom select and the
 * Data and Retro checkbox toggles.
 *
 * @returns {HTMLDivElement} the toolbar (caller attaches it).
 */
function createToolbar() {
  const bar = document.createElement('div');
  bar.id = 'toolbar';

  // Only entries with `tool` can be shown as the active tool.
  const buttonSpecs = [
    {
      text: 'Select',
      tool: 'select',
      handler: () => { currentTool = 'select'; selectedGroupIds.clear(); selectedBoxIds.clear(); render(); }
    },
    { text: 'Add Note', handler: addNote },
    { text: 'Export JSON', handler: exportJson },
    { text: 'Load JSON', handler: () => document.getElementById('jsonFileInput').click() }
  ];
  for (const spec of buttonSpecs) {
    const button = document.createElement('button');
    button.textContent = spec.text;
    const isActiveTool = Boolean(spec.tool) && currentTool === spec.tool;
    button.className = isActiveTool ? 'tool-btn active' : 'tool-btn';
    button.onclick = spec.handler;
    bar.appendChild(button);
  }

  const viewGroup = document.createElement('div');
  viewGroup.className = 'tool-group';
  viewGroup.innerHTML = '<span class="tool-label">View</span>';

  // Zoom presets; the one matching the current zoom level is preselected.
  const zoomSelect = document.createElement('select');
  zoomSelect.className = 'tool-select';
  const currentPercent = Math.round(zoomLevel * 100);
  for (const percent of [25, 50, 75, 100, 125, 150]) {
    const option = document.createElement('option');
    option.value = percent;
    option.textContent = percent + '%';
    if (percent === currentPercent) option.selected = true;
    zoomSelect.appendChild(option);
  }
  zoomSelect.onchange = () => {
    zoomLevel = parseInt(zoomSelect.value) / 100;
    applyZoom();
  };
  viewGroup.appendChild(zoomSelect);

  // <label class="tool-toggle"><input type="checkbox"> text</label>
  function makeToggle(text, initiallyChecked, onToggle) {
    const wrapper = document.createElement('label');
    wrapper.className = 'tool-toggle';
    const checkbox = document.createElement('input');
    checkbox.type = 'checkbox';
    checkbox.checked = initiallyChecked;
    checkbox.onchange = () => onToggle(checkbox.checked);
    wrapper.appendChild(checkbox);
    wrapper.appendChild(document.createTextNode(text));
    return wrapper;
  }

  viewGroup.appendChild(makeToggle(' Data', showMatrixData, isOn => {
    showMatrixData = isOn;
    render();
  }));

  // Retro mode: toggles a body class and persists the choice.
  const retroToggle = makeToggle(' Retro', document.body.classList.contains('retro-mode'), isOn => {
    document.body.classList.toggle('retro-mode', isOn);
    localStorage.setItem('retroMode', isOn ? '1' : '0');
  });
  retroToggle.style.marginLeft = '12px';
  viewGroup.appendChild(retroToggle);

  bar.appendChild(viewGroup);
  return bar;
}
/**
 * Second render stage. It:
 *  - draws layer boxes behind the already-positioned group elements;
 *  - wires drag behaviour for boxes, groups and notes;
 *  - renders notes (markdown + LaTeX);
 *  - installs click-on-background-to-deselect;
 *  - applies zoom and enables tensor-flow highlighting.
 *
 * @param {Object<string, HTMLElement>} groupElById - group elements created by
 *   renderFirstPass, keyed by group id.
 */
function renderFinal(groupElById) {
  const content = document.getElementById('canvas-content');

  // Larger boxes (parents) first, so in DOM order they precede - and paint
  // behind - the smaller child boxes they contain.
  const sortedBoxes = [...boxes].sort((a, b) => {
    const areaA = (a.w || 0) * (a.h || 0);
    const areaB = (b.w || 0) * (b.h || 0);
    return areaB - areaA;
  });

  // Build all box elements in sorted order, then insert the whole layer once,
  // in front of the existing group elements. (Inserting each box at firstChild
  // one by one would reverse the order and put parents on top of children.)
  // Assumes box stacking follows DOM order, i.e. no per-box z-index; the CSS
  // is not visible here, so confirm.
  const boxLayer = document.createDocumentFragment();
  sortedBoxes.forEach(box => {
    if (box.w <= 0 || box.h <= 0) return;
    const colors = getBoxColors(box.scheme);
    const el = document.createElement('div');
    el.className = 'layer-box';
    el.dataset.boxId = box.id;
    el.style.left = box.x + 'px';
    el.style.top = box.y + 'px';
    el.style.width = box.w + 'px';
    el.style.height = box.h + 'px';
    el.style.borderColor = colors.border;
    el.style.background = colors.bg;
    if (selectedBoxIds.has(box.id)) el.classList.add('layer-box-selected');
    const lbl = document.createElement('div');
    lbl.className = 'layer-box-label';
    lbl.textContent = box.label;
    lbl.style.color = colors.label;
    el.appendChild(lbl);
    attachBoxDrag(box, el, lbl, groupElById);
    boxLayer.appendChild(el);
  });
  content.insertBefore(boxLayer, content.firstChild);

  // Attach drag behavior to groups (dragged by their label)
  groups.forEach(g => {
    const el = groupElById[g.id];
    if (!el) return;
    const labelEl = el.querySelector('.group-label');
    if (labelEl) attachGroupDrag(g, el, labelEl);
  });

  // Render notes: markdown (when `marked` is loaded), then KaTeX.
  canvasNotes.forEach(note => {
    const el = document.createElement('div');
    el.className = 'note-container';
    el.style.left = note.x + 'px';
    el.style.top = note.y + 'px';
    const header = document.createElement('div');
    header.className = 'note-header';
    header.textContent = 'NOTE';
    el.appendChild(header);
    const body = document.createElement('div');
    body.className = 'note-body';
    // Missing text renders as an empty note instead of throwing in marked.
    const source = note.text ?? '';
    // SECURITY NOTE(review): this HTML goes into innerHTML unsanitized; notes
    // from an untrusted JSON file could inject active markup. Consider a
    // sanitizer before assignment.
    body.innerHTML = window.marked ? marked.parse(source) : source;
    renderLatexInElement(body);
    el.appendChild(body);
    attachNoteDrag(note, el, header);
    el.ondblclick = () => openNoteModal(note);
    content.appendChild(el);
  });

  // Click on empty canvas (select tool only) clears the selection
  content.onclick = e => {
    if (e.target === content && currentTool === 'select') {
      selectedGroupIds.clear();
      selectedBoxIds.clear();
      render();
    }
  };

  applyZoom();
  // Set up tensor flow highlighting after render is complete
  setupTensorFlowHighlighting();
}
// ==================== Tensor to 2D ====================
/**
 * Flatten a tensor into a 2-D grid of display strings.
 *
 * @param {{data: *, shape?: number[]}|null} t - tensor-like object. `data` may
 *   be a (nested) array or a bare scalar; it is flattened in row-major order.
 * @param {'auto'|'row'|'col'} [orientation='auto'] - only affects 1-D tensors:
 *   'row' shows (N,) as 1×N; anything else (including 'auto') shows N×1.
 * @returns {{mat: string[][], shape: number[]}} `mat` holds the cells: finite
 *   numbers as toFixed(4), everything else via String, missing values as ''.
 *   `shape` is [1, 1] for rank 0, the oriented shape for rank 1, and the
 *   tensor's own shape for rank >= 2. A null/absent tensor or absent data
 *   gives { mat: [[]], shape: [] }.
 */
function tensorTo2D(t, orientation = 'auto') {
  // Only null/undefined mean "no data": a scalar 0 (e.g. a zero loss) is real.
  if (!t || t.data === undefined || t.data === null) return { mat: [[]], shape: [] };

  const flat = [];
  (function flatten(x) {
    if (Array.isArray(x)) x.forEach(flatten);
    else flat.push(x);
  })(t.data);

  let shape = t.shape || [];
  if (shape.length === 0) {
    shape = [1, 1];
  } else if (shape.length === 1) {
    // Row vector (1×N) on request, column vector (N×1) by default.
    shape = orientation === 'row' ? [1, shape[0]] : [shape[0], 1];
  }

  // The last dim becomes the columns and all leading dims are collapsed into
  // rows. For rank 2 this is just shape[0]×shape[1]. For rank > 2 every row of
  // the row-major buffer is shown; using only shape[0]×shape[1] would misalign
  // cells with the data.
  const cols = shape[shape.length - 1];
  const rows = shape.slice(0, -1).reduce((acc, d) => acc * d, 1);

  const mat = [];
  for (let r = 0, i = 0; r < rows; r++) {
    const row = [];
    for (let c = 0; c < cols; c++) {
      const v = flat[i++];
      row.push(typeof v === 'number' && isFinite(v) ? v.toFixed(4) : String(v ?? ''));
    }
    mat.push(row);
  }
  return { mat, shape };
}
/**
 * Build a `.matrix-card` element for one tensor. The card has a title row
 * (label + shape) and, when the global `showMatrixData` is on, a preview table
 * of at most 4 rows × 6 columns, with ellipsis cells marking truncation.
 * Data cells carry data-row / data-col for the hover-highlighting helpers.
 *
 * @param {string} label - title text (rendered as plain text, never as HTML).
 * @param {Object|null} tensor - tensor-like object passed to tensorTo2D.
 * @param {'auto'|'row'|'col'} [orientation='auto'] - 1-D display hint.
 * @param {string|null} [role=null] - optional layout role stored in data-role.
 * @returns {HTMLDivElement} the card element (not yet attached).
 */
function createMatrixCard(label, tensor, orientation = 'auto', role = null) {
  const { mat, shape: displayShape } = tensorTo2D(tensor, orientation);
  const shapeStr = displayShape.join('×');
  const card = document.createElement('div');
  card.className = 'matrix-card';
  if (role) card.dataset.role = role;
  // Tensor id and name are stored for flow highlighting / name-based matching.
  if (tensor && tensor.id !== undefined && tensor.id !== null) {
    card.dataset.tensorId = String(tensor.id);
  }
  if (tensor && tensor.name) {
    card.dataset.tensorName = tensor.name;
  }

  // Title built from nodes + textContent rather than an innerHTML template:
  // labels come from tensor names (user code / loaded JSON) and must not be
  // parsed as markup. Structure is unchanged:
  // div.matrix-title > span (label), span.matrix-shape.
  const title = document.createElement('div');
  title.className = 'matrix-title';
  const labelSpan = document.createElement('span');
  labelSpan.textContent = String(label);
  const shapeSpan = document.createElement('span');
  shapeSpan.className = 'matrix-shape';
  shapeSpan.textContent = shapeStr;
  title.appendChild(labelSpan);
  title.appendChild(shapeSpan);
  card.appendChild(title);

  if (showMatrixData && mat.length > 0 && mat[0].length > 0) {
    const table = document.createElement('table');
    table.className = 'matrix-table';
    const MAX_ROWS = 4, MAX_COLS = 6;
    const totalR = mat.length, totalC = mat[0].length;
    const showR = Math.min(totalR, MAX_ROWS), showC = Math.min(totalC, MAX_COLS);
    const truncatedR = totalR > MAX_ROWS, truncatedC = totalC > MAX_COLS;
    for (let r = 0; r < showR; r++) {
      const tr = document.createElement('tr');
      tr.dataset.row = r;
      for (let c = 0; c < showC; c++) {
        const td = document.createElement('td');
        td.textContent = mat[r][c];
        td.dataset.row = r;
        td.dataset.col = c;
        tr.appendChild(td);
      }
      // Column ellipsis if truncated
      if (truncatedC) {
        const td = document.createElement('td');
        td.textContent = '…';
        td.style.color = '#64748b';
        tr.appendChild(td);
      }
      table.appendChild(tr);
    }
    // Row ellipsis if truncated (cells carry no data-row/data-col)
    if (truncatedR) {
      const tr = document.createElement('tr');
      for (let c = 0; c < showC + (truncatedC ? 1 : 0); c++) {
        const td = document.createElement('td');
        td.textContent = c === 0 ? '⋮' : '';
        td.style.color = '#64748b';
        tr.appendChild(td);
      }
      table.appendChild(tr);
    }
    card.appendChild(table);
  }
  return card;
}
// ==================== Element-wise Hover Highlighting ====================
/**
 * Element-wise hover: hovering output[r][c] highlights the cell of each input
 * table that produced it. The output is the `.matrix-table` inside outputCard.
 *
 * Broadcasting follows NumPy: an input whose displayed table has a single row
 * always maps to row 0, and one with a single column always maps to col 0.
 *
 * @param {HTMLElement[]} inputCards - operand cards (cards without a table are ignored).
 * @param {HTMLElement} outputCard - result card whose cells get hover handlers.
 */
function setupElementwiseHover(inputCards, outputCard) {
  const resultTable = outputCard.querySelector('.matrix-table');
  if (!resultTable) return;

  const sourceTables = [];
  for (const card of inputCards) {
    const table = card.querySelector('.matrix-table');
    if (table) sourceTables.push(table);
  }
  if (!sourceTables.length) return;

  // Which axes of each displayed input are length 1 (and thus broadcast).
  const broadcastFlags = sourceTables.map(table => {
    const trs = table.querySelectorAll('tr');
    const firstRowCols = trs[0] ? trs[0].querySelectorAll('td[data-col]').length : 0;
    return { singleRow: trs.length === 1, singleCol: firstRowCols === 1 };
  });

  const clearSources = () => {
    sourceTables.forEach(table => {
      table.querySelectorAll('.highlight-cell').forEach(td => td.classList.remove('highlight-cell'));
    });
  };

  resultTable.querySelectorAll('td[data-row][data-col]').forEach(cell => {
    cell.style.cursor = 'pointer';
    cell.addEventListener('mouseenter', () => {
      const r = parseInt(cell.dataset.row);
      const c = parseInt(cell.dataset.col);
      cell.classList.add('highlight-cell');
      sourceTables.forEach((table, i) => {
        const { singleRow, singleCol } = broadcastFlags[i];
        const match = table.querySelector(`td[data-row="${singleRow ? 0 : r}"][data-col="${singleCol ? 0 : c}"]`);
        if (match) match.classList.add('highlight-cell');
      });
    });
    cell.addEventListener('mouseleave', () => {
      cell.classList.remove('highlight-cell');
      clearSources();
    });
  });
}
// ==================== Transpose Hover Highlighting ====================
/**
 * Transpose hover: hovering output[r][c] highlights input[c][r], the cell it
 * was copied from. Does nothing unless both cards contain a `.matrix-table`.
 *
 * @param {HTMLElement} inputCard - card of the tensor being transposed.
 * @param {HTMLElement} outputCard - card of the transposed result.
 */
function setupTransposeHover(inputCard, outputCard) {
  const dstTable = outputCard.querySelector('.matrix-table');
  const srcTable = inputCard.querySelector('.matrix-table');
  if (!dstTable || !srcTable) return;

  const unhighlightSource = () => {
    srcTable.querySelectorAll('td.highlight-cell').forEach(td => td.classList.remove('highlight-cell'));
  };

  for (const cell of dstTable.querySelectorAll('td[data-row][data-col]')) {
    cell.style.cursor = 'pointer';
    cell.addEventListener('mouseenter', () => {
      const { row, col } = cell.dataset;
      cell.classList.add('highlight-cell');
      // Indices swapped: output[row][col] came from input[col][row].
      const source = srcTable.querySelector(`td[data-row="${col}"][data-col="${row}"]`);
      source?.classList.add('highlight-cell');
    });
    cell.addEventListener('mouseleave', () => {
      cell.classList.remove('highlight-cell');
      unhighlightSource();
    });
  }
}
// ==================== Matmul Hover Highlighting ====================
/**
 * Matmul hover: hovering result[r][c] highlights row r of the left operand
 * ('highlight-row') and column c of the top operand ('highlight-col'), which
 * are the vectors whose dot product gives that cell.
 *
 * @param {HTMLElement} wrapper - layout container; accepted for call-site
 *   compatibility but not used.
 * @param {HTMLElement} leftCard - left operand card (rows highlighted).
 * @param {HTMLElement} topCard - top operand card (columns highlighted).
 * @param {HTMLElement} resultCard - result card whose cells get hover handlers.
 */
function setupMatmulHover(wrapper, leftCard, topCard, resultCard) {
  const resultTable = resultCard.querySelector('.matrix-table');
  const rowSource = leftCard.querySelector('.matrix-table');
  const colSource = topCard.querySelector('.matrix-table');
  if (!(resultTable && rowSource && colSource)) return;

  const addClass = (nodes, cls) => nodes.forEach(td => td.classList.add(cls));
  const clearClass = (table, cls) => {
    table.querySelectorAll('.' + cls).forEach(td => td.classList.remove(cls));
  };

  resultTable.querySelectorAll('td[data-row][data-col]').forEach(cell => {
    cell.style.cursor = 'pointer';
    cell.addEventListener('mouseenter', () => {
      const { row, col } = cell.dataset;
      cell.classList.add('highlight-cell');
      addClass(rowSource.querySelectorAll(`td[data-row="${row}"]`), 'highlight-row');
      addClass(colSource.querySelectorAll(`td[data-col="${col}"]`), 'highlight-col');
    });
    cell.addEventListener('mouseleave', () => {
      cell.classList.remove('highlight-cell');
      clearClass(rowSource, 'highlight-row');
      clearClass(colSource, 'highlight-col');
    });
  });
}
// ==================== Tensor Flow Highlighting ====================
/**
 * Tensor-flow highlighting: hovering any `.matrix-card` inside
 * #canvas-content adds 'tensor-flow-highlight' to every card showing the same
 * tensor, and leaving the card removes it.
 *
 * Cards are matched on their displayed title: the first span in
 * `.matrix-title`, trimmed, with a trailing ".T" dropped so a transposed view
 * (Q.T) matches its source (Q). Only labels shown on more than one card get
 * handlers; those labels are logged once to the console.
 */
function setupTensorFlowHighlighting() {
  const canvasContent = document.getElementById('canvas-content');
  if (!canvasContent) return;

  // Pair every titled card with its base label, computed once.
  const labelled = [];
  Array.from(canvasContent.querySelectorAll('.matrix-card')).forEach(card => {
    const titleSpan = card.querySelector('.matrix-title span');
    if (!titleSpan) return;
    labelled.push({ card, baseLabel: titleSpan.textContent.trim().replace(/\.T$/, '') });
  });

  // A Map, not a plain object: labels such as "constructor" or "__proto__"
  // would otherwise resolve to Object.prototype members and make .push throw.
  const labelCardMap = new Map();
  labelled.forEach(({ card, baseLabel }) => {
    if (!labelCardMap.has(baseLabel)) labelCardMap.set(baseLabel, []);
    labelCardMap.get(baseLabel).push(card);
  });

  const multiUseLabels = [...labelCardMap]
    .filter(([, cards]) => cards.length > 1)
    .map(([label, cards]) => `${label}(${cards.length})`);
  if (multiUseLabels.length > 0) {
    console.log('Tensor flow: hover to highlight:', multiUseLabels.join(', '));
  }

  labelled.forEach(({ card, baseLabel }) => {
    const matchingCards = labelCardMap.get(baseLabel);
    if (matchingCards.length <= 1) return;
    card.style.cursor = 'pointer';
    card.addEventListener('mouseenter', () => {
      matchingCards.forEach(c => c.classList.add('tensor-flow-highlight'));
    });
    card.addEventListener('mouseleave', () => {
      matchingCards.forEach(c => c.classList.remove('tensor-flow-highlight'));
    });
  });
}
// ==================== Render Operation ====================
function renderOp(container, op) {
const { type, inputs, output, meta } = op;
// Safe helper to get tensor name with fallback
function getName(t, fallback = '?') {
if (!t) return fallback;
return t.name || t.id || fallback;
}
// Ensure inputs is always an array
const safeInputs = inputs || [];
const safeOutput = output || { id: '?', name: type, shape: [], data: [] };
const isElementwise = ['add', 'sub', 'mul', 'div'].includes(type) && safeInputs.length >= 2;
const isMatmul = type === 'matmul' && safeInputs.length >= 2;
const isLoss = ['mseloss', 'crossentropyloss', 'bceloss'].includes(type) && safeInputs.length >= 2;
const isLinear = type === 'linear' && safeInputs.length >= 2 && meta?.has_weight;
// Determine output orientation for reduction operations
let outputOrientation = 'auto';
if (['sum', 'mean', 'max'].includes(type) && safeInputs[0] && safeOutput.shape && safeOutput.shape.length === 1) {
const inputShape = safeInputs[0].shape || [];
const axis = meta?.axis ?? meta?.arg0;
if (inputShape.length >= 2) {
let normalizedAxis = axis;
if (normalizedAxis < 0) normalizedAxis = inputShape.length + normalizedAxis;
if (normalizedAxis === 0) {
outputOrientation = 'row';
} else {
outputOrientation = 'col';
}
}
}
// Get operator symbol for element-wise operations
const opSymbols = { add: '+', sub: '−', mul: '×', div: '÷' };
if (isElementwise && safeInputs[0] && safeInputs[1]) {
// Element-wise: inputs side by side, result below left input
const wrapper = document.createElement('div');
wrapper.className = 'layout-elementwise';
// Determine orientation for 1D tensors based on NumPy broadcasting rules:
// A 1D tensor (N,) broadcasts as (1, N) - a ROW, not a column
// This affects how we display the second operand (typically bias)
const input0Shape = safeInputs[0].shape || [];
const input1Shape = safeInputs[1].shape || [];
// If second input is 1D and first is 2D, display second as ROW (how NumPy broadcasts it)
let rightOrientation = 'auto';
if (input1Shape.length === 1 && input0Shape.length >= 2) {
rightOrientation = 'row'; // 1D broadcasts as row in NumPy
}
// First operand (top-left)
const left = createMatrixCard(getName(safeInputs[0]), safeInputs[0], 'auto', 'input');
left.classList.add('elem-left');
wrapper.appendChild(left);
// Operator symbol (top-center)
const opSym = document.createElement('span');
opSym.className = 'op-symbol-large elem-op';
opSym.textContent = opSymbols[type] || '+';
wrapper.appendChild(opSym);
// Second operand (top-right) - use row orientation for 1D bias
const right = createMatrixCard(getName(safeInputs[1]), safeInputs[1], rightOrientation, 'input');
right.classList.add('elem-right');
wrapper.appendChild(right);
// Result (bottom-left, below first operand)
const result = createMatrixCard(getName(safeOutput, type), safeOutput, outputOrientation, 'output');
result.classList.add('elem-result');
wrapper.appendChild(result);
// Add element-wise hover highlighting
setupElementwiseHover([left, right], result);
container.appendChild(wrapper);
return;
}
if (isLoss && safeInputs[0] && safeInputs[1]) {
// Loss functions: predictions | targets, loss below
const wrapper = document.createElement('div');
wrapper.className = 'layout-elementwise';
// Predictions (top-left)
const preds = createMatrixCard(getName(safeInputs[0], 'predictions'), safeInputs[0], 'auto', 'input');
preds.classList.add('elem-left');
wrapper.appendChild(preds);
// Arrow or separator (optional visual)
const arrow = document.createElement('span');
arrow.className = 'op-symbol-large elem-op';
arrow.textContent = '→';
arrow.title = 'compared to';
wrapper.appendChild(arrow);
// Targets (top-right)
const targs = createMatrixCard(getName(safeInputs[1], 'targets'), safeInputs[1], 'auto', 'input');
targs.classList.add('elem-right');
wrapper.appendChild(targs);
// Loss value (bottom-left)
const lossCard = createMatrixCard(getName(safeOutput, 'loss'), safeOutput, 'auto', 'output');
lossCard.classList.add('elem-result');
wrapper.appendChild(lossCard);
// Add element-wise hover highlighting
setupElementwiseHover([preds, targs], lossCard);
container.appendChild(wrapper);
return;
}
if (isLinear && safeInputs[0] && safeInputs[1]) {
// Linear layer: show as matmul layout for W @ X.T
// Instrumentation inserts weight at front: inputs[0] = weight, inputs[1] = x
const weight = safeInputs[0];
const x = safeInputs[1];
// Create transposed version of x for visualization
// The computation is (W @ x.T).T, so we show x.T on top
const xT = {
...(x || {}),
name: getName(x, 'X') + '.T',
shape: x?.shape ? [...x.shape].reverse() : [],
data: x?.data || []
};
// Transpose the actual data for display
if (x?.data && Array.isArray(x.data) && x.data.length > 0) {
if (Array.isArray(x.data[0])) {
xT.data = x.data[0].map((_, colIdx) => x.data.map(row => row[colIdx]));
}
}
// Create transposed version of output for visualization
const outputT = {
...(safeOutput || {}),
name: getName(safeOutput, 'y'),
shape: safeOutput?.shape ? [...safeOutput.shape].reverse() : [],
data: safeOutput?.data || []
};
// Transpose the output data for display
if (safeOutput?.data && Array.isArray(safeOutput.data) && safeOutput.data.length > 0) {
if (Array.isArray(safeOutput.data[0])) {
outputT.data = safeOutput.data[0].map((_, colIdx) => safeOutput.data.map(row => row[colIdx]));
}
}
const wrapper = document.createElement('div');
wrapper.className = 'layout-linear';
// Empty cell for top-left (matmul alignment)
const emptyCell = document.createElement('div');
emptyCell.className = 'linear-empty';
wrapper.appendChild(emptyCell);
// Input X.T (top-right position) - with role for hover highlighting
const inputCard = createMatrixCard(xT.name, xT, 'auto', 'top');
inputCard.classList.add('linear-input');
wrapper.appendChild(inputCard);
// Weight matrix W (left position)
const weightCard = createMatrixCard('W', weight, 'auto', 'left');
weightCard.classList.add('linear-weight');
wrapper.appendChild(weightCard);
// Output (bottom-right) - show transposed to match W @ X.T dimensions
const outputCard = createMatrixCard(outputT.name, outputT, 'auto', 'result');
outputCard.classList.add('linear-output');
wrapper.appendChild(outputCard);
// Add hover highlighting for matmul visualization
setupMatmulHover(wrapper, weightCard, inputCard, outputCard);
container.appendChild(wrapper);
return;
}
if (isMatmul && safeInputs[0] && safeInputs[1]) {
// Matrix multiplication: keep grid layout for proper alignment
const grid = document.createElement('div');
grid.className = 'layout-grid layout-binary';
const left = createMatrixCard(getName(safeInputs[0]), safeInputs[0], 'auto', 'left');
left.classList.add('pos-left');
grid.appendChild(left);
const top = createMatrixCard(getName(safeInputs[1]), safeInputs[1], 'auto', 'top');
top.classList.add('pos-top');
grid.appendChild(top);
const res = createMatrixCard(getName(safeOutput, type), safeOutput, outputOrientation, 'result');
res.classList.add('pos-result');
grid.appendChild(res);
// Add hover highlighting for matmul visualization
setupMatmulHover(grid, left, top, res);
container.appendChild(grid);
return;
}
// Unary operations (sum, mean, max, transpose, relu, sigmoid, etc.)
const grid = document.createElement('div');
grid.className = 'layout-grid layout-unary';
// Helper to get display name for a tensor, falling back to producer's output name
function getInputDisplayName(tensor) {
if (!tensor) return '?';
if (tensor.name) return tensor.name;
// Look for a group that produced this tensor
const producerGroup = groups.find(g => g.outputId === tensor.id);
if (producerGroup) {
// If producer is linear, use "y" (what linear displays its output as)
if (producerGroup.opType.toLowerCase() === 'linear') {
const producerOutput = tensors[producerGroup.outputId];
return producerOutput?.name || 'y';
}
// For other producers, try their output tensor name
const producerOutput = tensors[producerGroup.outputId];
if (producerOutput?.name) return producerOutput.name;
}
return tensor.id;
}
const inputCards = [];
if (safeInputs[0]) {
const inputLabel = getInputDisplayName(safeInputs[0]);
const inputCard = createMatrixCard(inputLabel, safeInputs[0], 'auto', 'input');
grid.appendChild(inputCard);
inputCards.push(inputCard);
}
const outputCard = createMatrixCard(getName(safeOutput, type), safeOutput, outputOrientation, 'output');
grid.appendChild(outputCard);
// Add hover highlighting based on operation type
if (type === 'transpose' && inputCards.length > 0) {
// Transpose: output[i,j] corresponds to input[j,i]
setupTransposeHover(inputCards[0], outputCard);
} else {
// Element-wise activations: output[i,j] corresponds to input[i,j]
const isElementwiseUnary = ['relu', 'sigmoid', 'tanh', 'gelu', 'softmax', 'logsoftmax', 'rmsnorm', 'sqrt'].includes(type);
if (isElementwiseUnary && inputCards.length > 0) {
setupElementwiseHover(inputCards, outputCard);
}
}
container.appendChild(grid);
}
// ==================== Box Colors ====================
/**
 * Look up the colour palette for a layer box.
 *
 * @param {string|number} scheme - palette id, '1'..'6' (numbers are accepted and
 *   compared by their string form).
 * @returns {{border: string, bg: string, label: string}} the palette, or palette '1'
 *   for any unknown id.
 */
function getBoxColors(scheme) {
  const c = {
    '1': { border: 'rgba(56,189,248,0.9)', bg: 'rgba(15,23,42,0.45)', label: '#7dd3fc' },
    '2': { border: 'rgba(34,197,94,0.85)', bg: 'rgba(5,46,22,0.5)', label: '#bbf7d0' },
    '3': { border: 'rgba(249,115,22,0.9)', bg: 'rgba(67,20,7,0.55)', label: '#fed7aa' },
    '4': { border: 'rgba(129,140,248,0.9)', bg: 'rgba(30,64,175,0.5)', label: '#e0e7ff' },
    '5': { border: 'rgba(244,114,182,0.9)', bg: 'rgba(131,24,67,0.5)', label: '#fce7f3' },
    '6': { border: 'rgba(248,113,113,0.9)', bg: 'rgba(127,29,29,0.5)', label: '#fee2e2' },
  };
  // Only the palette's own keys count. A plain `c[scheme]` lookup would also
  // resolve inherited names such as "constructor" or "toString" (reachable via
  // the free-form scheme prompt or an imported JSON file). It would then return
  // a value with no border/bg/label fields.
  const key = String(scheme);
  return Object.prototype.hasOwnProperty.call(c, key) ? c[key] : c['1'];
}
// ==================== LaTeX Rendering ====================
/**
 * Render TeX math found in an element's HTML with KaTeX, in place.
 *
 * Recognises display math `$$...$$` and inline math `$...$`. Each match is
 * replaced by KaTeX's HTML output. A match for which KaTeX throws is left
 * unchanged. Does nothing when KaTeX has not loaded.
 *
 * @param {Element} el - element whose innerHTML is rewritten. Its child nodes
 *   are recreated, so listeners attached to them are lost.
 */
function renderLatexInElement(el) {
  if (!window.katex) return;
  // The TeX is read back from serialized HTML, where `&`, `<`, `>` and U+00A0
  // appear as entities. Decode them so KaTeX sees `a < b` and matrix `&`
  // separators rather than `&lt;` / `&amp;`. This is a single pass, so
  // `&amp;lt;` correctly becomes the literal text `&lt;`.
  const entities = { amp: '&', lt: '<', gt: '>', nbsp: ' ' };
  const decode = s => s.replace(/&(amp|lt|gt|nbsp);/g, (_, name) => entities[name]);
  // One pass with display math tried first. KaTeX output is never rescanned,
  // so `$` characters inside rendered markup cannot start a new match.
  el.innerHTML = el.innerHTML.replace(
    /\$\$([\s\S]*?)\$\$|\$([^$]+?)\$/g,
    (match, displayTex, inlineTex) => {
      const isDisplay = displayTex !== undefined;
      const tex = decode(isDisplay ? displayTex : inlineTex).trim();
      try {
        return katex.renderToString(tex, { displayMode: isDisplay, throwOnError: false });
      } catch (e) {
        return match;
      }
    }
  );
}
// ==================== Main Render ====================
/**
 * Redraw the whole canvas.
 *
 * This is the single entry point other code calls to refresh the view. It
 * currently delegates entirely to the first-pass layout renderer and returns
 * nothing.
 */
function render() {
  return void renderFirstPass();
}
/** Apply the global `zoomLevel` as a CSS scale on #canvas-content (no-op if that element is missing). */
function applyZoom() {
  const content = document.getElementById('canvas-content');
  if (!content) return;
  content.style.transform = 'scale(' + String(zoomLevel) + ')';
}
// ==================== Drag Behaviors ====================
/**
 * Wire mouse interaction on a rendered operation group.
 *
 * A mousedown anywhere in the group container, including its matrix cards:
 * - Box tool: toggles the group in `selectedGroupIds` (consumed by
 *   createBoxFromSelection) and mirrors that with the `group-selected` class.
 *   No dragging happens.
 * - Any other tool: makes this the only selected group and drags it until
 *   mouseup. Pointer deltas are divided by `zoomLevel` so the group follows the
 *   cursor at any zoom.
 *
 * @param {object} group - group model; its `x`/`y` are updated while dragging.
 * @param {HTMLElement} el - the group's container element.
 * @param {HTMLElement} labelEl - the group's label element. Currently unused:
 *   the whole container acts as the drag handle.
 */
function attachGroupDrag(group, el, labelEl) {
  el.onmousedown = function(e) {
    e.preventDefault();
    e.stopPropagation();
    if (currentTool === 'box') {
      if (selectedGroupIds.has(group.id)) {
        selectedGroupIds.delete(group.id);
        el.classList.remove('group-selected');
      } else {
        selectedGroupIds.add(group.id);
        el.classList.add('group-selected');
      }
      return;
    }
    // Select tool: this group becomes the only selection. Also drop the
    // highlight of previously selected groups. Otherwise they stay visually
    // selected even though deleteSelected() would no longer act on them.
    selectedGroupIds.clear();
    document.querySelectorAll('.group-selected').forEach(node => {
      if (node !== el) node.classList.remove('group-selected');
    });
    selectedGroupIds.add(group.id);
    el.classList.add('group-selected');
    const startX = e.clientX, startY = e.clientY;
    const origX = group.x, origY = group.y;
    function onMove(ev) {
      // Screen-space delta converted to canvas space (content is CSS-scaled).
      const dx = (ev.clientX - startX) / zoomLevel;
      const dy = (ev.clientY - startY) / zoomLevel;
      group.x = origX + dx;
      group.y = origY + dy;
      el.style.left = group.x + 'px';
      el.style.top = group.y + 'px';
    }
    function onUp() {
      window.removeEventListener('mousemove', onMove);
      window.removeEventListener('mouseup', onUp);
    }
    window.addEventListener('mousemove', onMove);
    window.addEventListener('mouseup', onUp);
  };
}
/**
 * Wire mouse interaction on a rendered layer box.
 *
 * - Mousedown on the label drags the box together with its nested child boxes
 *   and every group they (transitively) contain. Outside box-tool mode it also
 *   makes this box the only selected box.
 * - Click on the box: in box-tool mode it toggles the box in `selectedBoxIds`
 *   (for nesting into a new parent). Otherwise it makes the box the only
 *   selected box. Both paths re-render.
 *
 * @param {object} box - box model ({id, x, y, groupIds, childBoxIds, ...}).
 * @param {HTMLElement} el - the box element.
 * @param {HTMLElement} labelEl - the box label, used as the drag handle.
 * @param {Object<string, HTMLElement>} groupElById - rendered group elements keyed by group id.
 */
function attachBoxDrag(box, el, labelEl, groupElById) {
  labelEl.onmousedown = function(e) {
    e.preventDefault();
    e.stopPropagation();
    // In box-tool mode the click handler below owns selection (it toggles this
    // box). Selecting here too made that click toggle the box straight back
    // off, after every other selected box had already been cleared. A label
    // click could therefore never select a box for nesting. Dragging still works.
    if (currentTool !== 'box') {
      selectedBoxIds.clear();
      // Keep the highlight in sync with the cleared selection set.
      document.querySelectorAll('.layer-box-selected').forEach(node => {
        if (node !== el) node.classList.remove('layer-box-selected');
      });
      selectedBoxIds.add(box.id);
      el.classList.add('layer-box-selected');
    }
    const startX = e.clientX, startY = e.clientY;
    const boxStart = { x: box.x, y: box.y };
    // Snapshot everything that moves with this box: nested child boxes and the
    // groups of this box and all its descendants. Models and elements are
    // resolved once here instead of on every mousemove.
    const movingBoxes = [];   // { model, el, x, y } for each descendant box
    const movingGroups = [];  // { model, el, x, y } for each contained group
    const seenBoxIds = new Set([box.id]);  // stops recursion on cyclic childBoxIds (e.g. hand-edited JSON)
    const seenGroupIds = new Set();
    function collectFromBox(b) {
      (b.groupIds || []).forEach(gid => {
        if (seenGroupIds.has(gid)) return;
        const g = groups.find(gr => gr.id === gid);
        if (!g) return;
        seenGroupIds.add(gid);
        movingGroups.push({ model: g, el: groupElById[gid], x: g.x, y: g.y });
      });
      (b.childBoxIds || []).forEach(cid => {
        if (seenBoxIds.has(cid)) return;
        const childBox = boxes.find(cb => cb.id === cid);
        if (!childBox) return;
        seenBoxIds.add(cid);
        movingBoxes.push({
          model: childBox,
          // Escape the id: imported boxes may carry ids that are not safe in a selector.
          el: document.querySelector(`[data-box-id="${CSS.escape(String(cid))}"]`),
          x: childBox.x,
          y: childBox.y
        });
        collectFromBox(childBox);
      });
    }
    collectFromBox(box);
    function onMove(ev) {
      // Screen-space delta converted to canvas space (content is CSS-scaled).
      const dx = (ev.clientX - startX) / zoomLevel;
      const dy = (ev.clientY - startY) / zoomLevel;
      box.x = boxStart.x + dx;
      box.y = boxStart.y + dy;
      el.style.left = box.x + 'px';
      el.style.top = box.y + 'px';
      // Child box models always move; their elements only if rendered.
      movingBoxes.forEach(item => {
        item.model.x = item.x + dx;
        item.model.y = item.y + dy;
        if (item.el) {
          item.el.style.left = item.model.x + 'px';
          item.el.style.top = item.model.y + 'px';
        }
      });
      // Groups move only when they have a rendered element.
      movingGroups.forEach(item => {
        if (!item.el) return;
        item.model.x = item.x + dx;
        item.model.y = item.y + dy;
        item.el.style.left = item.model.x + 'px';
        item.el.style.top = item.model.y + 'px';
      });
    }
    function onUp() {
      window.removeEventListener('mousemove', onMove);
      window.removeEventListener('mouseup', onUp);
    }
    window.addEventListener('mousemove', onMove);
    window.addEventListener('mouseup', onUp);
  };
  // Click to select
  el.onclick = e => {
    e.stopPropagation();
    // In Box mode, toggle selection for nesting
    if (currentTool === 'box') {
      if (selectedBoxIds.has(box.id)) {
        selectedBoxIds.delete(box.id);
        el.classList.remove('layer-box-selected');
      } else {
        selectedBoxIds.add(box.id);
        el.classList.add('layer-box-selected');
      }
      render();
      return;
    }
    // Select mode
    selectedBoxIds.clear();
    selectedBoxIds.add(box.id);
    render();
  };
}
/**
 * Make a canvas note draggable by its header.
 *
 * While the mouse is held, the note model's x/y and the element's left/top
 * follow the pointer. Deltas are divided by `zoomLevel` so tracking is correct
 * at any zoom. Listeners are removed on mouseup.
 */
function attachNoteDrag(note, el, header) {
  header.onmousedown = function(downEvent) {
    downEvent.preventDefault();
    downEvent.stopPropagation();
    const origin = {
      mouseX: downEvent.clientX,
      mouseY: downEvent.clientY,
      noteX: note.x,
      noteY: note.y
    };
    const move = moveEvent => {
      note.x = origin.noteX + (moveEvent.clientX - origin.mouseX) / zoomLevel;
      note.y = origin.noteY + (moveEvent.clientY - origin.mouseY) / zoomLevel;
      el.style.left = `${note.x}px`;
      el.style.top = `${note.y}px`;
    };
    const release = () => {
      window.removeEventListener('mousemove', move);
      window.removeEventListener('mouseup', release);
    };
    window.addEventListener('mousemove', move);
    window.addEventListener('mouseup', release);
  };
}
// ==================== Notes ====================
// Note most recently opened in the note modal; null until one has been opened.
let editingNote = null;
/** Append a note at (100, 100) with placeholder text and re-render. */
function addNote() {
  const note = { id: `n${Date.now()}`, x: 100, y: 100, text: 'New note' };
  canvasNotes.push(note);
  render();
}
/** Show the note modal pre-filled with `note.text` and remember `note` as the one being edited. */
function openNoteModal(note) {
  editingNote = note;
  const textarea = document.getElementById('note-modal-textarea');
  const backdrop = document.getElementById('note-modal-backdrop');
  textarea.value = note.text;
  backdrop.classList.add('show');
}
// Note modal buttons:
// - Cancel only hides the modal.
// - Save copies the textarea into the note being edited.
// - Delete removes that note from canvasNotes.
// Save and Delete re-render. Wrapped in a block so the helper adds no global name.
{
  const hideNoteModal = () => {
    document.getElementById('note-modal-backdrop').classList.remove('show');
  };
  document.getElementById('note-modal-cancel').onclick = () => {
    hideNoteModal();
  };
  document.getElementById('note-modal-save').onclick = () => {
    if (editingNote) {
      editingNote.text = document.getElementById('note-modal-textarea').value;
    }
    hideNoteModal();
    render();
  };
  document.getElementById('note-modal-delete').onclick = () => {
    const index = editingNote
      ? canvasNotes.findIndex(n => n.id === editingNote.id)
      : -1;
    if (index !== -1) canvasNotes.splice(index, 1);
    hideNoteModal();
    render();
  };
}
// ==================== Box Creation & Deletion ====================
/**
 * Create a layer box around the current selection.
 *
 * Alerts and returns if nothing is selected. Otherwise it prompts for:
 * - a label (cancel or empty aborts);
 * - a colour scheme (cancel or empty falls back to '1').
 * Directly selected groups become the new box's `groupIds`. Selected boxes
 * become its `childBoxIds` and keep their own groups. x/y/w/h start at 0.
 * Afterwards the selection is cleared, the tool reverts to 'select' and the
 * canvas re-renders.
 */
function createBoxFromSelection() {
  if (!selectedGroupIds.size && !selectedBoxIds.size) {
    alert('Select groups or boxes first (use Box tool to click on them)');
    return;
  }
  const label = prompt('Box label:', `Box ${boxes.length + 1}`);
  if (!label) return;
  const scheme = prompt('Color scheme (1-6):', '1') || '1';
  // A parent box records only its directly selected groups; groups inside the
  // selected child boxes stay owned by those children.
  boxes.push({
    id: `b${nextBoxId++}`,
    label,
    scheme,
    groupIds: [...selectedGroupIds],
    childBoxIds: [...selectedBoxIds],
    x: 0, y: 0, w: 0, h: 0,
    fromCode: false
  });
  for (const selection of [selectedGroupIds, selectedBoxIds]) selection.clear();
  currentTool = 'select';
  render();
}
/**
 * Delete every selected box and group, clear both selections and re-render.
 *
 * References to the deleted items are also scrubbed from the boxes that remain:
 * - deleted groups are removed from `groupIds`;
 * - deleted boxes are removed from `childBoxIds`.
 * This means no surviving box points at something that no longer exists.
 * Deleting a box does not delete its child boxes or its groups.
 */
function deleteSelected() {
  const deletedBoxIds = new Set(selectedBoxIds);
  const deletedGroupIds = new Set(selectedGroupIds);
  deletedBoxIds.forEach(bid => {
    const idx = boxes.findIndex(b => b.id === bid);
    if (idx >= 0) boxes.splice(idx, 1);
  });
  deletedGroupIds.forEach(gid => {
    const idx = groups.findIndex(g => g.id === gid);
    if (idx >= 0) groups.splice(idx, 1);
  });
  // Remove every occurrence of a deleted id. Arrays are edited in place rather
  // than reassigned, in case other code holds a reference to them.
  const scrub = (ids, removed) => {
    if (!Array.isArray(ids)) return;
    for (let i = ids.length - 1; i >= 0; i--) {
      if (removed.has(ids[i])) ids.splice(i, 1);
    }
  };
  boxes.forEach(box => {
    scrub(box.groupIds, deletedGroupIds);
    // Previously missing: parents kept ids of deleted child boxes.
    scrub(box.childBoxIds, deletedBoxIds);
  });
  selectedBoxIds.clear();
  selectedGroupIds.clear();
  render();
}
// ==================== JSON ====================
/**
 * Download the current session as `tinytorch-YYYY-MM-DD.json`.
 *
 * The file contains:
 * - the editor code;
 * - groups, with positions;
 * - boxes;
 * - tensors (id/shape/data/name; null entries are skipped);
 * - notes.
 * The date in the file name comes from toISOString(), so it is the UTC date.
 */
function exportJson() {
  const state = {
    code: document.getElementById('code').value,
    groups: groups.map(g => ({
      id: g.id,
      opType: g.opType,
      inputIds: g.inputIds,
      outputId: g.outputId,
      meta: g.meta,
      x: g.x,
      y: g.y
    })),
    boxes: boxes.map(b => ({
      id: b.id, label: b.label, scheme: b.scheme,
      groupIds: b.groupIds, childBoxIds: b.childBoxIds,
      x: b.x, y: b.y, w: b.w, h: b.h,
      fromCode: b.fromCode
    })),
    tensors: Object.fromEntries(
      Object.entries(tensors)
        .filter(([, t]) => t != null) // Skip null/undefined tensors
        .map(([id, t]) => [id, {
          id: t.id || id,
          shape: t.shape || [],
          data: t.data || [],
          name: t.name || null
        }])
    ),
    notes: canvasNotes.map(n => ({ id: n.id, x: n.x, y: n.y, text: n.text }))
  };
  const blob = new Blob([JSON.stringify(state, null, 2)], { type: 'application/json' });
  const url = URL.createObjectURL(blob);
  const a = document.createElement('a');
  a.href = url;
  a.download = 'tinytorch-' + new Date().toISOString().slice(0,10) + '.json';
  // Attach the anchor while clicking it: some browsers (older Firefox) ignore
  // click() on a detached anchor.
  document.body.appendChild(a);
  a.click();
  a.remove();
  // The object URL keeps the blob alive until revoked, so release it. Defer the
  // revoke instead of doing it synchronously so the download can start first.
  setTimeout(() => URL.revokeObjectURL(url), 1000);
}
// Import a layout saved by exportJson(): restores code, tensors, groups, boxes
// and notes, then re-renders without re-running the code.
document.getElementById('jsonFileInput').onchange = e => {
  const input = e.target;
  const file = input.files[0];
  if (!file) return;
  // Numeric part of ids like 'g12' / 'b3', used to keep newly created ids unique.
  const idNumber = (id, prefix) => parseInt(String(id || '').replace(prefix, ''), 10) || 0;
  const reader = new FileReader();
  reader.onload = ev => {
    try {
      const state = JSON.parse(ev.target.result);
      // JSON.parse also accepts null, numbers and arrays; reject them with a clear message.
      if (!state || typeof state !== 'object' || Array.isArray(state)) {
        throw new Error('expected a JSON object at the top level');
      }
      // Restore code
      if (state.code) document.getElementById('code').value = state.code;
      // Restore tensors (needed for rendering)
      if (state.tensors && typeof state.tensors === 'object') {
        Object.keys(tensors).forEach(k => delete tensors[k]);
        Object.entries(state.tensors).forEach(([id, t]) => {
          if (t != null) tensors[id] = t;
        });
      }
      // Restore groups (null entries skipped, as for tensors)
      if (Array.isArray(state.groups)) {
        groups.length = 0;
        state.groups.forEach(g => { if (g != null) groups.push(g); });
        const groupNums = groups.map(g => idNumber(g.id, 'g'));
        nextGroupId = groupNums.length > 0 ? Math.max(...groupNums) + 1 : 1;
      }
      // Restore boxes
      if (Array.isArray(state.boxes)) {
        boxes.length = 0;
        state.boxes.forEach(b => { if (b != null) boxes.push(b); });
        const boxNums = boxes.map(b => idNumber(b.id, 'b'));
        nextBoxId = boxNums.length > 0 ? Math.max(...boxNums) + 1 : 1;
      }
      // Restore notes
      if (Array.isArray(state.notes)) {
        canvasNotes.length = 0;
        state.notes.forEach(n => { if (n != null) canvasNotes.push(n); });
      }
      // Any previous selection refers to the replaced groups/boxes; drop it.
      selectedGroupIds.clear();
      selectedBoxIds.clear();
      // Set flag to preserve loaded positions
      skipLayout = true;
      // If tensors weren't in the JSON (old format), warn user
      if (!state.tensors || Object.keys(state.tensors).length === 0) {
        alert('This JSON was saved in an old format without tensor data.\n\n' +
          'The layout will be restored, but matrices will be empty.\n' +
          'To get full data: Run code, then Export JSON again.');
      }
      // Reset zoom to 100% for proper positioning (set before render so toolbar picks it up)
      zoomLevel = 1;
      // Render without running code (data already loaded or using placeholders)
      render();
      // After render, ensure zoom select shows correct value and scroll to top
      setTimeout(() => {
        const zoomSelect = document.querySelector('#toolbar select');
        if (zoomSelect) zoomSelect.value = '100%';
        applyZoom();
        const canvas = document.getElementById('canvas');
        if (canvas) {
          canvas.scrollTop = 0;
          canvas.scrollLeft = 0;
        }
      }, 50);
    } catch(err) {
      console.error('JSON load error:', err);
      alert('Invalid JSON: ' + err.message);
    }
    // Reset file input so same file can be loaded again
    input.value = '';
  };
  // Without this, a failed read was silent and left the input unreset.
  reader.onerror = () => {
    console.error('JSON load error:', reader.error);
    alert('Could not read file: ' + (reader.error ? reader.error.message : 'unknown error'));
    input.value = '';
  };
  reader.readAsText(file);
};
// ==================== Splitters ====================
// Vertical splitters. Dragging a .v-resizer sets a fixed flex-basis (minimum
// 260px) on the element whose id is in the handle's data-left attribute.
// Handles whose target element does not exist are skipped.
for (const handle of document.querySelectorAll('.v-resizer')) {
  const leftEl = document.getElementById(handle.dataset.left);
  if (!leftEl) continue;
  handle.onmousedown = downEvent => {
    downEvent.preventDefault();
    const originX = downEvent.clientX;
    const originWidth = leftEl.getBoundingClientRect().width;
    const resize = moveEvent => {
      const width = Math.max(260, originWidth + moveEvent.clientX - originX);
      leftEl.style.flex = `0 0 ${width}px`;
    };
    const stop = () => {
      window.removeEventListener('mousemove', resize);
      window.removeEventListener('mouseup', stop);
    };
    window.addEventListener('mousemove', resize);
    window.addEventListener('mouseup', stop);
  };
}
// ==================== Run Code ====================
/**
 * Send the editor's code to the backend as `{action: 'run', code}` over the WebSocket.
 * When not connected, writes 'Not connected' into #error and sends nothing.
 */
function runCode() {
  if (wsConnected) {
    const code = document.getElementById('code').value;
    ws.send(JSON.stringify({ action: 'run', code }));
    return;
  }
  document.getElementById('error').textContent = 'Not connected';
}
document.getElementById('run').onclick = runCode;
// Ctrl+Enter (Cmd+Enter on macOS) in the editor runs the code.
document.getElementById('code').onkeydown = e => {
  const isRunShortcut = e.key === 'Enter' && (e.ctrlKey || e.metaKey);
  if (!isRunShortcut) return;
  e.preventDefault();
  runCode();
};
document.getElementById('console-clear').onclick = clearConsole;
// ==================== Console Resizer ====================
// Console resizer. Dragging #console-resizer vertically sets the height of
// #console-container (dragging up makes it taller), clamped to 60-400px. While
// dragging, the body shows an ns-resize cursor and text selection is disabled.
(() => {
  const handle = document.getElementById('console-resizer');
  const panel = document.getElementById('console-container');
  const drag = { originY: 0, originHeight: 0 };
  const follow = e => {
    // Pointer moving up (smaller clientY) grows the panel.
    const grown = drag.originHeight + (drag.originY - e.clientY);
    panel.style.height = Math.min(400, Math.max(60, grown)) + 'px';
  };
  const release = () => {
    document.removeEventListener('mousemove', follow);
    document.removeEventListener('mouseup', release);
    document.body.style.cursor = '';
    document.body.style.userSelect = '';
  };
  handle.addEventListener('mousedown', e => {
    e.preventDefault();
    drag.originY = e.clientY;
    drag.originHeight = panel.offsetHeight;
    document.addEventListener('mousemove', follow);
    document.addEventListener('mouseup', release);
    document.body.style.cursor = 'ns-resize';
    document.body.style.userSelect = 'none';
  });
})();
// ==================== Init ====================
// Restore the retro-mode preference. Reading localStorage can throw (e.g. a
// SecurityError when site storage is blocked). That must not abort the rest of
// this inline script, which still has to disable Run, open the WebSocket and
// start the PDF autoload below.
try {
  if (localStorage.getItem('retroMode') === '1') {
    document.body.classList.add('retro-mode');
  }
} catch (err) {
  console.warn('Could not read retroMode preference:', err);
}
// Run stays disabled until the WebSocket connection is established.
document.getElementById('run').disabled = true;
connectWebSocket();
// ==================== Automatic PDF Loading ====================
/**
 * Load math.pdf into the PDF panel on startup.
 *
 * Polls every 200 ms until PDF.js has defined `window.pdfjsLib`. It then
 * fetches `static/math.pdf`, falling back to `math.pdf`, and shows page 1 at
 * zoom 1.0 via renderPdfPage(). Any failure is logged and explained in
 * #pdfPlaceholder.
 */
async function autoLoadMathPDF() {
  if (!window.pdfjsLib) {
    // Not loaded yet: retry later and stop here. Falling through used to hit an
    // undefined pdfjsLib, which showed a bogus error, while the retry then loaded
    // the PDF a second time.
    setTimeout(autoLoadMathPDF, 200);
    return;
  }
  // Network-level fetch failures (e.g. fetch blocked when the page is opened via
  // file://) resolve to null, so the fallback path still gets its chance.
  const tryFetch = path => fetch(path).catch(err => {
    console.warn(`Fetching ${path} failed:`, err);
    return null;
  });
  try {
    // Paths are relative to the page URL. 'static/math.pdf' is where the server
    // mounts the file; 'math.pdf' is a fallback for other setups.
    let response = await tryFetch('static/math.pdf');
    if (!response || !response.ok) {
      console.log("...trying fallback path...");
      response = await tryFetch('math.pdf');
    }
    if (!response) throw new Error('math.pdf could not be fetched');
    if (!response.ok) throw new Error(`HTTP error! status: ${response.status}`);
    // Hand the bytes to PDF.js directly, so there is no object URL to leak.
    // Await its promise inside the try so that an invalid PDF is reported below
    // instead of becoming an unhandled rejection.
    const data = new Uint8Array(await response.arrayBuffer());
    const doc = await pdfjsLib.getDocument({ data }).promise;
    pdfDoc = doc;
    pdfPageNum = 1;
    pdfZoomFactor = 1.0;
    renderPdfPage(1);
    console.log('math.pdf loaded successfully!');
  } catch (err) {
    console.warn('PDF Load Error:', err);
    const placeholder = document.getElementById('pdfPlaceholder');
    if (placeholder) {
      placeholder.innerHTML =
        "Could not load PDF. Ensure <b>math.pdf</b> is in the <b>static</b> folder<br>" +
        "and you are running <b>http://localhost:8000</b>";
    }
  }
}
autoLoadMathPDF();
</script>
</body>
</html>