priyadip
Fix: use <details>/<summary> for native toggle, no JS needed for cards
4873c62
"""
app.py
Gradio Space: Interactive Transformer Visualizer — English → Bengali
"""
import html
import json
import os
import re
from pathlib import Path

import gradio as gr
import numpy as np
import torch

from transformer import Transformer, CalcLog
from training import build_model, run_training, visualize_training_step, collate_batch
from inference import visualize_inference
from vocab import get_vocabs, PARALLEL_DATA, PAD_IDX
# ─────────────────────────────────────────────
# Global state
# ─────────────────────────────────────────────
# Device string handed to build_model / run_training / visualize_* calls.
DEVICE = "cpu"
# Source and target vocabularies from the project's vocab module; their
# lengths are passed to build_model() by get_or_init_model().
src_v, tgt_v = get_vocabs()
# Shared model instance. None until get_or_init_model() builds one or
# do_train() replaces it with a freshly trained model.
MODEL: "Transformer | None" = None
# Loss values returned by the most recent run_training() call (see do_train).
LOSS_HISTORY: list = []
# Set to True by do_train() once a training run has completed.
IS_TRAINED = False
def get_or_init_model():
    """Return the shared model, building an untrained one on first use.

    The model is sized from the global vocabularies and cached in the
    module-level ``MODEL`` so later calls reuse the same instance.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    MODEL = build_model(len(src_v), len(tgt_v), DEVICE)
    return MODEL
# ─────────────────────────────────────────────
# HTML renderer for calc log
# ─────────────────────────────────────────────
def render_matrix_html(val, max_rows=6, max_cols=8):
    """Render a calc-log value as an HTML snippet.

    Handles a scalar number, a flat dict (one key/value row per item), a 1-D
    list (clipped to ``2 * max_cols`` entries), or a 2-D nested list (clipped
    to ``max_rows`` x ``max_cols``). Anything else is shown as truncated
    text inside ``<code>``.

    Every non-numeric value is HTML-escaped. Token strings such as ``<pad>``
    and text containing ``<``/``&`` are therefore displayed literally instead
    of being parsed as markup.

    Args:
        val: value to render.
        max_rows: maximum number of rows shown for a 2-D list.
        max_cols: maximum number of cells shown per 2-D row; 1-D lists show up
            to twice this many entries.

    Returns:
        An HTML string.
    """
    esc = html.escape
    if isinstance(val, (int, float)):
        return f'<span class="scalar-val">{val:.5f}</span>'
    if isinstance(val, dict):
        rows = "".join(
            f'<tr><td class="dict-key">{esc(str(k))}</td><td class="dict-val">{esc(str(v))}</td></tr>'
            for k, v in val.items()
        )
        return f'<table class="dict-table">{rows}</table>'
    if isinstance(val, list):
        if not val:
            return "<em>empty</em>"
        if not isinstance(val[0], list):
            # 1-D: a single row of cells, with a "+N" marker for the overflow.
            clipped = val[:max_cols * 2]
            cells = "".join(
                f'<td class="mat-cell">{v:.4f}</td>' if isinstance(v, float)
                else f'<td class="mat-cell">{esc(str(v))}</td>'
                for v in clipped
            )
            hidden = len(val) - len(clipped)
            suffix = f'<td class="mat-more">…+{hidden}</td>' if hidden > 0 else ""
            return f'<table class="matrix-1d"><tr>{cells}{suffix}</tr></table>'
        # 2-D: numeric cells carry their value clamped to [-1, 1] in the CSS
        # variable --v, which the stylesheet uses for background shading.
        rows_html = []
        for row in val[:max_rows]:
            if not isinstance(row, list):
                row = [row]  # tolerate ragged input (a scalar among the rows)
            cells = "".join(
                f'<td class="mat-cell" style="--v:{min(max(float(c), -1), 1):.3f}">{float(c):.3f}</td>'
                if isinstance(c, (int, float))
                else f'<td class="mat-cell">{esc(str(c))}</td>'
                for c in row[:max_cols]
            )
            suffix = '<td class="mat-more">…</td>' if len(row) > max_cols else ""
            rows_html.append(f"<tr>{cells}{suffix}</tr>")
        if len(val) > max_rows:
            rows_html.append(
                f'<tr><td colspan="{max_cols + 1}" class="mat-more">…{len(val) - max_rows} more rows</td></tr>'
            )
        body = "".join(rows_html)
        return f'<table class="matrix-2d">{body}</table>'
    # Escape after truncating so an entity is never cut in half.
    return f"<code>{esc(str(val)[:200])}</code>"
def calc_log_to_html(steps):
"""Turn CalcLog steps into rich HTML accordion."""
if not steps:
return "<p style='color:#888'>No calculation log yet.</p>"
cards = []
for i, step in enumerate(steps):
name = step.get("name", f"step_{i}")
formula = step.get("formula", "")
note = step.get("note", "")
shape = step.get("shape")
val = step.get("value")
shape_badge = f'<span class="shape-badge">{shape}</span>' if shape else ""
formula_html = f'<div class="formula">⟨ {formula} ⟩</div>' if formula else ""
note_html = f'<div class="step-note">ℹ {note}</div>' if note else ""
matrix_html = render_matrix_html(val) if val is not None else ""
# Color category by name prefix
cat = "default"
n = name.upper()
if "EMBED" in n or "TOKEN" in n: cat = "embed"
elif "PE" in n or "POSITIONAL" in n: cat = "pe"
elif "SOFTMAX" in n or "ATTN" in n or "_Q" in n or "_K" in n or "_V" in n: cat = "attn"
elif "FFN" in n or "LINEAR" in n or "RELU" in n: cat = "ffn"
elif "NORM" in n or "RESIDUAL" in n: cat = "norm"
elif "LOSS" in n or "GRAD" in n or "OPTIM" in n: cat = "loss"
elif "INFERENCE" in n or "GREEDY" in n or "BEAM" in n: cat = "infer"
elif "CROSS" in n: cat = "cross"
elif "MASK" in n: cat = "mask"
cards.append(f"""
<details class="calc-card cat-{cat}" data-idx="{i}">
<summary class="calc-header">
<span class="step-num">#{i+1}</span>
<span class="step-name cat-label-{cat}">{name.replace('_',' ')}</span>
{shape_badge}
<span class="toggle-arrow">▶</span>
</summary>
<div class="calc-body">
{formula_html}
{note_html}
<div class="matrix-wrap">{matrix_html}</div>
</div>
</details>""")
return "\n".join(cards)
# ─────────────────────────────────────────────
# Attention heatmap HTML
# ─────────────────────────────────────────────
def attention_heatmap_html(weights, row_labels, col_labels, title="Attention"):
    """Render an attention-weight matrix as a shaded HTML table.

    Args:
        weights: 2-D list ``[rows][cols]`` of weights. Values above 1 are
            clipped to 1 before being used for shading and display.
        row_labels: labels for each row; a missing label falls back to the
            row index.
        col_labels: labels for the header row; a missing label in a tooltip
            falls back to the column index.
        title: caption shown above the table.

    Returns:
        HTML string, or ``""`` when ``weights`` is empty. Labels and title
        are HTML-escaped. Each cell's tooltip reads ``row → col: weight``.
    """
    if not weights:
        return ""

    def _label(labels, idx):
        # Fall back to the numeric index when no label is provided.
        return str(labels[idx]) if idx < len(labels) else str(idx)

    esc = html.escape
    rows_html = []
    for i, row in enumerate(weights):
        row_lbl = _label(row_labels, i)
        cells = []
        for j, w in enumerate(row):
            alpha = min(float(w), 1.0)
            tip = esc(f"{row_lbl} → {_label(col_labels, j)}: {alpha:.3f}")
            cells.append(
                f'<td class="heat-cell" style="--a:{alpha:.3f}" title="{tip}">{alpha:.2f}</td>'
            )
        row_cells = "".join(cells)
        rows_html.append(f'<tr><td class="heat-label">{esc(row_lbl)}</td>{row_cells}</tr>')
    header_cells = "".join(f'<td class="heat-col-label">{esc(str(c))}</td>' for c in col_labels)
    header = f"<tr><td></td>{header_cells}</tr>"
    body = "".join(rows_html)
    safe_title = esc(str(title))
    return f"""
<div class="heatmap-container">
<div class="heatmap-title">{safe_title}</div>
<table class="heatmap">{header}{body}</table>
</div>"""
# ─────────────────────────────────────────────
# Decoding steps HTML
# ─────────────────────────────────────────────
def decode_steps_html(step_logs, src_tokens):
    """Render the per-step log of auto-regressive decoding as HTML.

    Each entry of ``step_logs`` is a dict read for these optional keys:
    ``step`` (0-based index), ``tokens_so_far``, ``top5`` (list of dicts with
    ``token`` and ``prob``), ``chosen_token``, ``chosen_prob`` and
    ``cross_attn``. When ``cross_attn`` and ``src_tokens`` are both
    non-empty, a one-row heatmap is appended. It uses ``cross_attn[0][-1]``,
    the last query row of the first entry.

    Token text is HTML-escaped, so tokens containing ``<``/``&`` are shown
    literally instead of being swallowed as markup.

    Returns:
        HTML string, or ``""`` when ``step_logs`` is empty.
    """
    if not step_logs:
        return ""
    esc = html.escape
    parts = ['<div class="decode-steps"><div class="decode-title">🔁 Auto-regressive Decoding Steps</div>']
    for s in step_logs:
        step = s.get("step", 0)
        tokens_so_far = s.get("tokens_so_far", [])
        top5 = s.get("top5", [])
        chosen = s.get("chosen_token", "?")
        prob = s.get("chosen_prob", 0)

        # Probability bars, scaled so the most likely candidate is 100% wide.
        bars = ""
        if top5:
            max_p = max(t["prob"] for t in top5) or 1
            for t in top5:
                pct = t["prob"] / max_p * 100
                is_chosen = "chosen" if t["token"] == chosen else ""
                bars += f"""<div class="bar-row {is_chosen}">
<span class="bar-label">{esc(str(t['token']))}</span>
<div class="bar" style="width:{pct:.1f}%"></div>
<span class="bar-prob">{t['prob']:.3f}</span>
</div>"""

        cross_heat = ""
        # NOTE(review): assumed layout is [heads][T_q][T_src]; confirm with
        # visualize_inference.
        attn_mat = s.get("cross_attn")
        if attn_mat and src_tokens and attn_mat[0]:
            last_pos_attn = attn_mat[0][-1]  # last decoded position, [T_src]
            # Raw token text is passed on; the heatmap does its own escaping.
            cross_heat = attention_heatmap_html(
                [last_pos_attn], [chosen], src_tokens,
                title=f"Cross-Attn: '{chosen}' → English",
            )

        context = esc(" ".join(str(tok) for tok in tokens_so_far))
        parts.append(f"""
<div class="decode-step">
<div class="decode-step-header">
<span class="step-badge">Step {step+1}</span>
<span class="step-ctx">Context: {context}</span>
<span class="step-arrow">→</span>
<span class="step-chosen">'{esc(str(chosen))}'</span>
<span class="step-prob">{prob:.3f}</span>
</div>
<div class="step-bars">{bars}</div>
{cross_heat}
</div>""")
    parts.append("</div>")
    return "".join(parts)
# ─────────────────────────────────────────────
# Architecture SVG
# ─────────────────────────────────────────────
# Static SVG diagram for the Architecture tab: encoder stack on the left,
# decoder stack on the right, a dashed K/V arrow into the decoder's
# cross-attention box, and a colour legend along the bottom. The markup is
# fixed (no data bindings) and drawn in an 820x900 viewBox.
# NOTE(review): the "Bengali Output" line runs from y=840 up to y=660 with the
# arrowhead at its end. It therefore points *into* the Linear+Softmax box
# (y 600-640) and reads like an input arrow; confirm the intended direction.
ARCH_SVG = """
<div id="arch-diagram">
<svg viewBox="0 0 820 900" xmlns="http://www.w3.org/2000/svg" style="width:100%;max-width:820px;margin:auto;display:block">
<defs>
<marker id="arr" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
<path d="M0,0 L0,6 L8,3 z" fill="#64ffda"/>
</marker>
<filter id="glow">
<feGaussianBlur stdDeviation="2" result="blur"/>
<feMerge><feMergeNode in="blur"/><feMergeNode in="SourceGraphic"/></feMerge>
</filter>
</defs>
<!-- Background -->
<rect width="820" height="900" fill="#0a0f1e" rx="12"/>
<!-- Title -->
<text x="410" y="35" text-anchor="middle" fill="#64ffda" font-size="16" font-family="monospace" font-weight="bold">Transformer Architecture — English → Bengali</text>
<!-- ── ENCODER (left) ── -->
<rect x="40" y="60" width="330" height="720" rx="10" fill="#0d1b2a" stroke="#1e4d6b" stroke-width="1.5"/>
<text x="205" y="90" text-anchor="middle" fill="#4fc3f7" font-size="13" font-weight="bold">ENCODER</text>
<!-- Input Embedding -->
<rect x="70" y="110" width="270" height="40" rx="6" fill="#1a3a5c" stroke="#4fc3f7" stroke-width="1.5"/>
<text x="205" y="135" text-anchor="middle" fill="#e0f7fa" font-size="11">Input Embedding + Positional Encoding</text>
<!-- Encoder Layer Box -->
<rect x="60" y="175" width="290" height="340" rx="8" fill="#112233" stroke="#1e4d6b" stroke-width="1" stroke-dasharray="4"/>
<text x="100" y="198" fill="#607d8b" font-size="10">Encoder Layer × N</text>
<!-- Multi-Head Self-Attention -->
<rect x="80" y="210" width="250" height="50" rx="6" fill="#1b3a4b" stroke="#26c6da" stroke-width="1.5"/>
<text x="205" y="232" text-anchor="middle" fill="#e0f7fa" font-size="11" font-weight="bold">Multi-Head Self-Attention</text>
<text x="205" y="248" text-anchor="middle" fill="#80deea" font-size="9">Q = K = V = encoder input</text>
<!-- Add & Norm 1 -->
<rect x="80" y="278" width="250" height="30" rx="5" fill="#1a2a3a" stroke="#607d8b" stroke-width="1"/>
<text x="205" y="298" text-anchor="middle" fill="#b0bec5" font-size="10">Add &amp; Norm</text>
<!-- FFN -->
<rect x="80" y="328" width="250" height="50" rx="6" fill="#1b3a4b" stroke="#26c6da" stroke-width="1.5"/>
<text x="205" y="350" text-anchor="middle" fill="#e0f7fa" font-size="11" font-weight="bold">Feed-Forward Network</text>
<text x="205" y="366" text-anchor="middle" fill="#80deea" font-size="9">FFN(x) = max(0, xW₁+b₁)W₂+b₂</text>
<!-- Add & Norm 2 -->
<rect x="80" y="396" width="250" height="30" rx="5" fill="#1a2a3a" stroke="#607d8b" stroke-width="1"/>
<text x="205" y="416" text-anchor="middle" fill="#b0bec5" font-size="10">Add &amp; Norm</text>
<!-- Encoder output arrow down -->
<line x1="205" y1="455" x2="205" y2="550" stroke="#64ffda" stroke-width="1.5" marker-end="url(#arr)"/>
<text x="215" y="510" fill="#64ffda" font-size="9">K, V to</text>
<text x="215" y="522" fill="#64ffda" font-size="9">decoder</text>
<!-- Encoder output box -->
<rect x="70" y="555" width="270" height="40" rx="6" fill="#0d2b1a" stroke="#00e676" stroke-width="1.5"/>
<text x="205" y="580" text-anchor="middle" fill="#a5d6a7" font-size="11">Encoder Output (K, V)</text>
<!-- ── DECODER (right) ── -->
<rect x="450" y="60" width="330" height="720" rx="10" fill="#1a0d2a" stroke="#4a1b6b" stroke-width="1.5"/>
<text x="615" y="90" text-anchor="middle" fill="#ce93d8" font-size="13" font-weight="bold">DECODER</text>
<!-- Target Embedding -->
<rect x="480" y="110" width="270" height="40" rx="6" fill="#3a1a5c" stroke="#ce93d8" stroke-width="1.5"/>
<text x="615" y="135" text-anchor="middle" fill="#f3e5f5" font-size="11">Target Embedding + Positional Encoding</text>
<!-- Decoder Layer Box -->
<rect x="470" y="175" width="290" height="460" rx="8" fill="#1a1133" stroke="#4a1b6b" stroke-width="1" stroke-dasharray="4"/>
<text x="510" y="198" fill="#607d8b" font-size="10">Decoder Layer × N</text>
<!-- Masked MHA -->
<rect x="490" y="210" width="250" height="50" rx="6" fill="#2b1b3a" stroke="#ab47bc" stroke-width="1.5"/>
<text x="615" y="232" text-anchor="middle" fill="#f3e5f5" font-size="11" font-weight="bold">Masked Multi-Head Self-Attention</text>
<text x="615" y="248" text-anchor="middle" fill="#ce93d8" font-size="9">Q = K = V = decoder input (causal mask)</text>
<!-- Add & Norm D1 -->
<rect x="490" y="278" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
<text x="615" y="298" text-anchor="middle" fill="#b0bec5" font-size="10">Add &amp; Norm</text>
<!-- Cross-Attention -->
<rect x="490" y="328" width="250" height="60" rx="6" fill="#1b2b4b" stroke="#29b6f6" stroke-width="2" filter="url(#glow)"/>
<text x="615" y="350" text-anchor="middle" fill="#e1f5fe" font-size="11" font-weight="bold">Cross-Attention</text>
<text x="615" y="366" text-anchor="middle" fill="#81d4fa" font-size="9">Q = decoder | K, V = encoder</text>
<text x="615" y="380" text-anchor="middle" fill="#29b6f6" font-size="9" font-weight="bold">← KEY CONNECTION</text>
<!-- Add & Norm D2 -->
<rect x="490" y="408" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
<text x="615" y="428" text-anchor="middle" fill="#b0bec5" font-size="10">Add &amp; Norm</text>
<!-- FFN Decoder -->
<rect x="490" y="458" width="250" height="50" rx="6" fill="#2b1b3a" stroke="#ab47bc" stroke-width="1.5"/>
<text x="615" y="480" text-anchor="middle" fill="#f3e5f5" font-size="11" font-weight="bold">Feed-Forward Network</text>
<text x="615" y="496" text-anchor="middle" fill="#ce93d8" font-size="9">FFN(x) = max(0, xW₁+b₁)W₂+b₂</text>
<!-- Add & Norm D3 -->
<rect x="490" y="526" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
<text x="615" y="546" text-anchor="middle" fill="#b0bec5" font-size="10">Add &amp; Norm</text>
<!-- Output Linear + Softmax -->
<rect x="480" y="600" width="270" height="40" rx="6" fill="#2b1b0a" stroke="#ffb300" stroke-width="1.5"/>
<text x="615" y="625" text-anchor="middle" fill="#fff8e1" font-size="11">Linear + Softmax → Bengali Token</text>
<!-- Cross-attention arrow from encoder to decoder -->
<path d="M340,590 Q410,480 490,368" stroke="#29b6f6" stroke-width="2" fill="none"
stroke-dasharray="6,3" marker-end="url(#arr)"/>
<text x="390" y="500" fill="#29b6f6" font-size="9" transform="rotate(-50,390,500)">K, V flow</text>
<!-- Input arrow -->
<line x1="205" y1="840" x2="205" y2="780" stroke="#4fc3f7" stroke-width="1.5" marker-end="url(#arr)"/>
<text x="205" y="858" text-anchor="middle" fill="#4fc3f7" font-size="11">English Input</text>
<line x1="615" y1="840" x2="615" y2="660" stroke="#ce93d8" stroke-width="1.5" marker-end="url(#arr)"/>
<text x="615" y="858" text-anchor="middle" fill="#ce93d8" font-size="11">Bengali Output</text>
<!-- Legend -->
<rect x="60" y="870" width="700" height="20" rx="4" fill="#0a1520" stroke="#1e2d3d" stroke-width="1"/>
<circle cx="80" cy="880" r="4" fill="#26c6da"/><text x="88" y="884" fill="#80deea" font-size="8">Self-Attention</text>
<circle cx="160" cy="880" r="4" fill="#29b6f6"/><text x="168" y="884" fill="#81d4fa" font-size="8">Cross-Attention</text>
<circle cx="250" cy="880" r="4" fill="#ab47bc"/><text x="258" y="884" fill="#ce93d8" font-size="8">Masked Attn</text>
<circle cx="350" cy="880" r="4" fill="#00e676"/><text x="358" y="884" fill="#a5d6a7" font-size="8">Enc→Dec K,V</text>
<circle cx="450" cy="880" r="4" fill="#ffb300"/><text x="458" y="884" fill="#fff8e1" font-size="8">Output Layer</text>
</svg>
</div>
"""
# ─────────────────────────────────────────────
# CSS + JS
# ─────────────────────────────────────────────
# Page-wide stylesheet passed to gr.Blocks(css=...). It declares the palette
# as CSS custom properties, including one per calc-card category (--embed,
# --pe, --attn, ...). It also styles the calc cards, matrix tables,
# attention heatmaps, decoding bars and result banners emitted by the HTML
# helpers in this module.
# The default <summary> disclosure triangle is hidden (list-style on the
# summary, ::-webkit-details-marker for WebKit, and an empty ::marker content).
# The ::marker fix matters because `display` is not a property ::marker
# accepts. The card's own .toggle-arrow is shown instead.
CUSTOM_CSS = """
/* ── fonts ── */
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;600&family=Syne:wght@400;700;800&display=swap');
:root {
--bg: #07090f;
--bg2: #0d1120;
--bg3: #111827;
--card: #141c2e;
--border: #1e2d45;
--accent: #64ffda;
--accent2: #29b6f6;
--accent3: #ce93d8;
--accent4: #ffb300;
--text: #e2e8f0;
--muted: #64748b;
--embed: #4fc3f7;
--pe: #26c6da;
--attn: #f06292;
--ffn: #aed581;
--norm: #90a4ae;
--loss: #ef9a9a;
--infer: #80cbc4;
--cross: #29b6f6;
--mask: #ffb300;
}
body, .gradio-container { background: var(--bg) !important; color: var(--text) !important; font-family: 'JetBrains Mono', monospace !important; }
h1, h2, h3 { font-family: 'Syne', sans-serif !important; }
/* ── tabs ── */
.tab-nav button { background: var(--bg3) !important; color: var(--muted) !important; border: 1px solid var(--border) !important; font-family: 'JetBrains Mono', monospace !important; letter-spacing: 1px; }
.tab-nav button.selected { background: var(--card) !important; color: var(--accent) !important; border-color: var(--accent) !important; box-shadow: 0 0 8px rgba(100,255,218,0.2); }
/* ── inputs ── */
input[type=text], textarea { background: var(--bg3) !important; color: var(--text) !important; border: 1px solid var(--border) !important; border-radius: 6px !important; font-family: 'JetBrains Mono', monospace !important; }
input[type=text]:focus, textarea:focus { border-color: var(--accent) !important; box-shadow: 0 0 6px rgba(100,255,218,0.2) !important; }
button.primary { background: linear-gradient(135deg, #0d3d30, #0d3d4d) !important; color: var(--accent) !important; border: 1px solid var(--accent) !important; font-family: 'JetBrains Mono', monospace !important; font-weight: 600 !important; letter-spacing: 1px; transition: all 0.2s; }
button.primary:hover { background: linear-gradient(135deg, #1a5c4a, #1a5c6d) !important; box-shadow: 0 0 12px rgba(100,255,218,0.3) !important; }
/* ── calc cards ── */
details.calc-card { border-radius: 8px; margin: 4px 0; border: 1px solid var(--border); background: var(--card); overflow: hidden; }
details.calc-card > summary { display: flex; align-items: center; gap: 8px; padding: 8px 12px; cursor: pointer; user-select: none; list-style: none; transition: background 0.15s; }
details.calc-card > summary::-webkit-details-marker { display: none; }
details.calc-card > summary::marker { content: ""; }
details.calc-card > summary:hover { background: rgba(255,255,255,0.03); }
details.calc-card > .calc-body { padding: 10px 14px; background: var(--bg2); border-top: 1px solid var(--border); }
.step-num { color: var(--muted); font-size: 11px; min-width: 28px; }
.step-name { font-weight: 600; font-size: 12px; flex: 1; }
.toggle-arrow { color: var(--muted); font-size: 10px; transition: transform 0.2s; }
details[open] .toggle-arrow { transform: rotate(90deg); }
.shape-badge { background: var(--bg3); color: var(--muted); font-size: 10px; padding: 1px 6px; border-radius: 4px; border: 1px solid var(--border); }
.formula { color: var(--accent); font-size: 11px; font-style: italic; margin-bottom: 4px; background: rgba(100,255,218,0.05); padding: 4px 8px; border-radius: 4px; border-left: 2px solid var(--accent); }
.step-note { color: var(--muted); font-size: 11px; margin-bottom: 6px; }
/* category colors */
.cat-label-embed { color: var(--embed); }
.cat-label-pe { color: var(--pe); }
.cat-label-attn { color: var(--attn); }
.cat-label-ffn { color: var(--ffn); }
.cat-label-norm { color: var(--norm); }
.cat-label-loss { color: var(--loss); }
.cat-label-infer { color: var(--infer); }
.cat-label-cross { color: var(--cross); }
.cat-label-mask { color: var(--mask); }
.cat-label-default{ color: var(--text); }
.cat-embed { border-left: 3px solid var(--embed); }
.cat-pe { border-left: 3px solid var(--pe); }
.cat-attn { border-left: 3px solid var(--attn); }
.cat-ffn { border-left: 3px solid var(--ffn); }
.cat-norm { border-left: 3px solid var(--norm); }
.cat-loss { border-left: 3px solid var(--loss); }
.cat-infer { border-left: 3px solid var(--infer); }
.cat-cross { border-left: 3px solid var(--cross); }
.cat-mask { border-left: 3px solid var(--mask); }
.cat-default{ border-left: 3px solid var(--border); }
/* ── matrix tables ── */
.matrix-wrap { overflow-x: auto; }
.matrix-2d, .matrix-1d { border-collapse: collapse; font-size: 10px; font-family: 'JetBrains Mono', monospace; }
.mat-cell {
padding: 2px 5px; text-align: right; min-width: 48px;
background: color-mix(in srgb, #29b6f6 calc((var(--v,0) + 1) * 30%), #0d1120 calc(100% - (var(--v,0) + 1) * 30%));
color: #e2e8f0; border: 1px solid rgba(255,255,255,0.05);
}
.mat-more { color: var(--muted); font-style: italic; font-size: 9px; padding: 2px 6px; }
.dict-table { font-size: 11px; width: 100%; }
.dict-key { color: var(--accent); padding: 2px 8px 2px 0; }
.dict-val { color: var(--text); padding: 2px; }
.scalar-val { color: var(--accent4); font-size: 13px; font-weight: 600; }
/* ── heatmap ── */
.heatmap-container { margin: 8px 0; }
.heatmap-title { color: var(--accent2); font-size: 11px; margin-bottom: 4px; font-weight: 600; }
.heatmap { border-collapse: collapse; font-size: 10px; }
.heat-cell {
width: 36px; height: 24px; text-align: center;
background: rgba(41, 182, 246, calc(var(--a, 0)));
border: 1px solid rgba(255,255,255,0.04);
color: color-mix(in srgb, #fff calc(var(--a,0)*100%), #4a5568 calc(100% - var(--a,0)*100%));
font-size: 9px; cursor: default;
}
.heat-cell:hover { outline: 1px solid var(--accent); }
.heat-label { color: var(--accent3); font-size: 10px; padding-right: 6px; white-space: nowrap; }
.heat-col-label { color: var(--embed); font-size: 9px; text-align: center; padding-bottom: 2px; }
/* ── decode steps ── */
.decode-steps { margin-top: 12px; }
.decode-title { color: var(--accent); font-size: 13px; font-weight: 700; margin-bottom: 10px; padding-bottom: 4px; border-bottom: 1px solid var(--border); }
.decode-step { border: 1px solid var(--border); border-radius: 8px; margin: 6px 0; padding: 10px; background: var(--card); }
.decode-step-header { display: flex; align-items: center; gap: 8px; flex-wrap: wrap; margin-bottom: 8px; }
.step-badge { background: var(--accent); color: var(--bg); font-size: 10px; font-weight: 700; padding: 2px 8px; border-radius: 20px; }
.step-ctx { color: var(--muted); font-size: 11px; }
.step-arrow { color: var(--accent4); }
.step-chosen { color: var(--accent3); font-size: 13px; font-weight: 700; }
.step-prob { color: var(--accent4); font-size: 11px; }
.step-bars { margin: 4px 0; }
.bar-row { display: flex; align-items: center; gap: 6px; margin: 2px 0; }
.bar-row.chosen .bar-label { color: var(--accent3); font-weight: 700; }
.bar-row.chosen .bar { background: var(--accent3) !important; }
.bar-label { width: 100px; text-align: right; font-size: 11px; color: var(--text); white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
.bar { height: 14px; background: var(--accent2); border-radius: 2px; transition: width 0.4s; min-width: 2px; }
.bar-prob { font-size: 10px; color: var(--muted); }
/* ── loss chart ── */
#loss-chart-container { background: var(--bg2); border: 1px solid var(--border); border-radius: 8px; padding: 12px; margin-top: 8px; }
/* ── arch diagram ── */
#arch-diagram { background: var(--bg2); border: 1px solid var(--border); border-radius: 10px; padding: 12px; margin: 8px 0; }
/* ── result banner ── */
.result-banner { background: linear-gradient(135deg, #0d3d30, #1a1a3d); border: 1px solid var(--accent); border-radius: 10px; padding: 16px 20px; margin: 10px 0; }
.result-en { color: var(--embed); font-size: 14px; margin-bottom: 4px; }
.result-bn { color: var(--accent3); font-size: 22px; font-weight: 700; letter-spacing: 1px; }
.result-label { color: var(--muted); font-size: 10px; text-transform: uppercase; letter-spacing: 1px; }
/* ── misc ── */
.gradio-html { background: transparent !important; }
.panel { background: var(--card) !important; border: 1px solid var(--border) !important; border-radius: 10px !important; }
.log-container { max-height: 600px; overflow-y: auto; padding: 8px; scrollbar-width: thin; scrollbar-color: var(--border) transparent; }
"""
# Client-side helpers. build_ui passes this text to gr.Blocks(js=...) wrapped
# in "() => { ... }". It defines three window-level functions:
# _expandAll / _collapseAll toggle every <details class="calc-card">, and
# _filterCards(cat) hides cards lacking the "cat-<cat>" class; an empty cat
# shows all of them. It also registers one capturing, document-level click
# listener that dispatches on the data-ga attribute of toolbar buttons:
# "expand", "collapse" or "filter:<cat>".
# NOTE(review): the helpers query the whole document, so a toolbar in one tab
# also expands/collapses/filters calc cards rendered in the other tab.
CUSTOM_JS = """
// <details> elements toggle natively — JS only needed for expand-all/collapse-all/filter
window._expandAll = function() {
document.querySelectorAll('details.calc-card').forEach(d => d.open = true);
};
window._collapseAll = function() {
document.querySelectorAll('details.calc-card').forEach(d => d.open = false);
};
window._filterCards = function(cat) {
document.querySelectorAll('details.calc-card').forEach(d => {
d.style.display = (!cat || d.classList.contains('cat-'+cat)) ? '' : 'none';
});
};
// Toolbar button delegation
document.addEventListener('click', function(e) {
const btn = e.target.closest('[data-ga]');
if (!btn) return;
const a = btn.dataset.ga;
if (a === 'expand') window._expandAll();
else if (a === 'collapse') window._collapseAll();
else if (a.startsWith('filter:')) window._filterCards(a.slice(7));
}, true);
"""
# ─────────────────────────────────────────────
# Pure-SVG loss curve (no JS/canvas needed)
# ─────────────────────────────────────────────
def _loss_svg(losses):
if not losses:
return ""
W, H = 580, 200
pl, pr, pt, pb = 52, 16, 16, 36
pw, ph = W - pl - pr, H - pt - pb
mn, mx = min(losses), max(losses)
rng = mx - mn or 1
n = len(losses)
def px(i): return pl + (i / max(n - 1, 1)) * pw
def py(v): return pt + ph - ((v - mn) / rng) * ph
# Grid + Y labels
grid = ""
for k in range(5):
v = mn + (k / 4) * rng
y = py(v)
grid += f'<line x1="{pl}" y1="{y:.1f}" x2="{pl+pw}" y2="{y:.1f}" stroke="#1e2d45" stroke-width="0.5"/>'
grid += f'<text x="{pl-4}" y="{y+4:.1f}" text-anchor="end" fill="#64748b" font-size="9" font-family="monospace">{v:.3f}</text>'
# Polyline points
pts = " ".join(f"{px(i):.1f},{py(v):.1f}" for i, v in enumerate(losses))
fill_pts = f"{pl:.1f},{pt+ph:.1f} {pts} {pl+pw:.1f},{pt+ph:.1f}"
# X labels
xlabels = ""
for idx in ([0, n//4, n//2, 3*n//4, n-1] if n > 4 else range(n)):
xlabels += f'<text x="{px(idx):.1f}" y="{H-4}" text-anchor="middle" fill="#64748b" font-size="9" font-family="monospace">E{idx+1}</text>'
return f"""
<div style="background:#0d1120;border:1px solid #1e2d45;border-radius:8px;padding:12px;margin-top:8px">
<div style="color:#64ffda;font-size:13px;font-weight:700;margin-bottom:8px">📉 Training Loss Curve</div>
<svg width="{W}" height="{H}" style="display:block;max-width:100%">
<defs>
<linearGradient id="lcg" x1="0" y1="0" x2="{W}" y2="0" gradientUnits="userSpaceOnUse">
<stop offset="0%" stop-color="#64ffda"/><stop offset="100%" stop-color="#29b6f6"/>
</linearGradient>
</defs>
<rect width="{W}" height="{H}" fill="#0d1120"/>
{grid}
<polygon points="{fill_pts}" fill="rgba(100,255,218,0.08)"/>
<polyline points="{pts}" fill="none" stroke="url(#lcg)" stroke-width="2.5" stroke-linejoin="round"/>
{xlabels}
</svg>
</div>"""
# ─────────────────────────────────────────────
# Gradio callbacks
# ─────────────────────────────────────────────
def do_train(epochs_str, progress=gr.Progress()):
    """Gradio callback: train a fresh model and report its loss curve.

    Args:
        epochs_str: epoch count as typed in the textbox. Input ``int()``
            cannot parse falls back to 30. Counts below 1 are rejected
            without training.
        progress: Gradio progress tracker, updated from run_training's
            per-epoch callback.

    Returns:
        ``(status_html, loss_chart_html)`` for the status and chart components.

    Side effects:
        Replaces the global ``MODEL`` and ``LOSS_HISTORY`` and sets
        ``IS_TRAINED`` when training runs.
    """
    global MODEL, LOSS_HISTORY, IS_TRAINED
    try:
        epochs = int(epochs_str)
    except (TypeError, ValueError):
        epochs = 30  # unparsable input: fall back to the default run length
    if epochs < 1:
        # A non-positive count would presumably yield an empty loss history,
        # and indexing it below would raise IndexError. Keep the current chart.
        return "⚠️ Epochs must be a positive integer.", _loss_svg(LOSS_HISTORY)

    def cb(ep, total, loss):
        progress(ep / total, desc=f"Epoch {ep}/{total} — loss {loss:.4f}")

    MODEL, LOSS_HISTORY = run_training(epochs=epochs, device=DEVICE, progress_cb=cb)
    IS_TRAINED = True
    if not LOSS_HISTORY:
        return f"⚠️ Trained {epochs} epochs but no loss values were reported.", ""
    return (
        f"✅ Trained {epochs} epochs. Final loss: {LOSS_HISTORY[-1]:.4f}",
        _loss_svg(LOSS_HISTORY),
    )
def do_training_viz(en_sentence, bn_sentence):
    """Gradio callback: run one teacher-forced training step and render it.

    Args:
        en_sentence: English source sentence from the textbox (required).
        bn_sentence: Bengali ground-truth sentence. When blank, a built-in
            default sentence is used.

    Returns:
        ``(result_banner_html, calc_log_html, attention_html)``. The banner
        shows both sentences and the step loss. The calc-log panel has
        expand/collapse/filter toolbar buttons driven by CUSTOM_JS.
        ``attention_html`` is always ``""``; no heatmap is built for this tab.
    """
    model = get_or_init_model()
    en_sentence = en_sentence.strip()
    bn_sentence = bn_sentence.strip()
    if not en_sentence:
        return "<p style='color:red'>Please enter an English sentence.</p>", "", ""
    if not bn_sentence:
        bn_sentence = "আমি তোমাকে ভালোবাসি"
    result = visualize_training_step(model, en_sentence, bn_sentence, DEVICE)

    # User-provided text is escaped before being embedded in HTML.
    result_banner = f"""
<div class="result-banner">
<div class="result-label">English Input</div>
<div class="result-en">"{html.escape(en_sentence)}"</div>
<div class="result-label" style="margin-top:8px">Bengali (Teacher-forced)</div>
<div class="result-bn">{html.escape(bn_sentence)}</div>
<div style="margin-top:8px;color:var(--loss);font-size:13px">
📉 Loss: <strong>{result['loss']:.4f}</strong>
</div>
</div>"""

    # Filter buttons are tinted with the category's palette variable
    # (--embed, --pe, ... as declared in CUSTOM_CSS), falling back to --text.
    filter_buttons = "".join(
        f'<button data-ga="filter:{cat}" style="background:var(--card);color:var(--{cat},var(--text));border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:10px">{cat}</button>'
        for cat in ("embed", "pe", "attn", "ffn", "norm", "loss", "cross", "mask")
    )
    calc_log = calc_log_to_html(result.get("calc_log", []))
    calc_html = f"""
<div style="margin-bottom:8px;display:flex;gap:6px;flex-wrap:wrap">
<button data-ga="expand" style="background:var(--card);color:var(--accent);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Expand All</button>
<button data-ga="collapse" style="background:var(--card);color:var(--muted);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Collapse All</button>
{filter_buttons}
<button data-ga="filter:" style="background:var(--card);color:var(--muted);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:10px">show all</button>
</div>
<div class="log-container">
{calc_log}
</div>"""
    return result_banner, calc_html, ""
def do_inference_viz(en_sentence, decode_method):
    """Gradio callback: translate a sentence and show every decoding step.

    Args:
        en_sentence: English source sentence from the textbox (required).
        decode_method: decoding strategy chosen in the UI radio
            ("greedy" or "beam"), forwarded to visualize_inference.

    Returns:
        ``(banner_and_steps_html, calc_log_html, "")``. The first value holds
        the translation banner followed by the per-step decoding panel. The
        third output (attention panel) is left empty.
    """
    model = get_or_init_model()
    if not en_sentence.strip():
        return "<p style='color:red'>Please enter an English sentence.</p>", "", ""
    result = visualize_inference(model, en_sentence.strip(), DEVICE, decode_method)

    # Escape user input and model output: tokens may contain '<' or '&'.
    esc = html.escape
    translation = result["translation"] or "(no output)"
    token_trail = " → ".join(str(t) for t in result["output_tokens"])
    result_banner = f"""
<div class="result-banner">
<div class="result-label">English Input</div>
<div class="result-en">"{esc(en_sentence)}"</div>
<div class="result-label" style="margin-top:8px">Bengali Translation ({esc(str(decode_method))})</div>
<div class="result-bn">{esc(str(translation))}</div>
<div style="margin-top:6px;color:var(--muted);font-size:11px">
Tokens: {esc(token_trail)}
</div>
</div>"""

    decode_html = decode_steps_html(result.get("step_logs", []), result.get("src_tokens", []))
    calc_log = calc_log_to_html(result.get("calc_log", []))
    calc_html = f"""
<div style="margin-bottom:8px;display:flex;gap:6px;flex-wrap:wrap">
<button data-ga="expand" style="background:var(--card);color:var(--accent);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Expand All</button>
<button data-ga="collapse" style="background:var(--card);color:var(--muted);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Collapse All</button>
</div>
<div class="log-container">
{calc_log}
</div>"""
    return result_banner + decode_html, calc_html, ""
# ─────────────────────────────────────────────
# Build UI
# ─────────────────────────────────────────────
def build_ui():
    """Assemble the Gradio app and wire each button to its callback.

    Tabs: Architecture (static diagram), Train Model (do_train), Training
    Step (do_training_viz), Inference (do_inference_viz), and Examples
    (the first 12 PARALLEL_DATA pairs).

    Returns:
        The un-launched ``gr.Blocks`` instance.
    """
    _theme = gr.themes.Base(primary_hue="teal", secondary_hue="purple", neutral_hue="slate")
    with gr.Blocks(
        title="Transformer Visualizer — EN→BN",
        css=CUSTOM_CSS,
        # The helper script is wrapped in an arrow function for Blocks(js=...).
        js=f"() => {{ {CUSTOM_JS} }}",
        theme=_theme,
    ) as demo:
        gr.HTML("""
<div style="text-align:center;padding:24px 0 12px;border-bottom:1px solid #1e2d45;margin-bottom:16px">
<div style="font-family:'Syne',sans-serif;font-size:28px;font-weight:800;
background:linear-gradient(135deg,#64ffda,#29b6f6,#ce93d8);
-webkit-background-clip:text;-webkit-text-fill-color:transparent;letter-spacing:2px">
TRANSFORMER VISUALIZER
</div>
<div style="color:#64748b;font-size:12px;letter-spacing:3px;margin-top:4px;font-family:'JetBrains Mono',monospace">
ENGLISH → BENGALI · EVERY CALCULATION EXPOSED
</div>
</div>
""")
        with gr.Tabs():
            # ── TAB 0: Architecture ──────────────────
            with gr.Tab("🏗️ Architecture"):
                gr.HTML(ARCH_SVG)
                gr.HTML("""
<div style="display:grid;grid-template-columns:1fr 1fr;gap:12px;margin-top:12px">
<div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:14px">
<div style="color:#4fc3f7;font-weight:700;margin-bottom:8px">📌 Encoder Flow</div>
<div style="color:#94a3b8;font-size:12px;line-height:1.8">
1. English tokens → Embedding (d_model=64)<br>
2. + Positional Encoding (sin/cos)<br>
3. Multi-Head Self-Attention (4 heads)<br>
4. Add &amp; LayerNorm<br>
5. Feed-Forward (64→128→64)<br>
6. Add &amp; LayerNorm<br>
7. Repeat × 2 layers<br>
8. Output K, V for decoder
</div>
</div>
<div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:14px">
<div style="color:#ce93d8;font-weight:700;margin-bottom:8px">📌 Decoder Flow</div>
<div style="color:#94a3b8;font-size:12px;line-height:1.8">
1. Bengali tokens → Embedding<br>
2. + Positional Encoding<br>
3. Masked MHA (future tokens blocked)<br>
4. Add &amp; LayerNorm<br>
5. Cross-Attention: Q←decoder, K,V←encoder<br>
6. Add &amp; LayerNorm<br>
7. Feed-Forward<br>
8. Linear → Softmax → Bengali token
</div>
</div>
</div>
""")
            # ── TAB 1: Train ─────────────────────────
            with gr.Tab("🏋️ Train Model"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.HTML('<div style="color:#64ffda;font-size:13px;font-weight:700;margin-bottom:8px">Quick Train</div>')
                        epochs_in = gr.Textbox(value="50", label="Epochs", max_lines=1)
                        # Derive the count from the data so the label can't drift.
                        train_btn = gr.Button(
                            f"▶ Train on {len(PARALLEL_DATA)} parallel sentences",
                            variant="primary",
                        )
                        train_status = gr.HTML()
                    with gr.Column(scale=2):
                        loss_chart = gr.HTML()
                train_btn.click(do_train, inputs=[epochs_in], outputs=[train_status, loss_chart])
            # ── TAB 2: Training Step Viz ──────────────
            with gr.Tab("🔬 Training Step"):
                gr.HTML('<div style="color:#ef9a9a;font-size:12px;margin-bottom:12px">📚 Shows <strong>teacher-forcing</strong>: ground-truth Bengali tokens are fed to decoder, loss + gradients computed.</div>')
                with gr.Row():
                    en_in_t = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you")
                    bn_in_t = gr.Textbox(label="Bengali (ground truth)", placeholder="আমি তোমাকে ভালোবাসি", value="আমি তোমাকে ভালোবাসি")
                run_train_viz = gr.Button("🔬 Run Training Step & Show All Calculations", variant="primary")
                result_html_t = gr.HTML()
                with gr.Row():
                    with gr.Column(scale=2):
                        calc_html_t = gr.HTML()
                    with gr.Column(scale=1):
                        attn_html_t = gr.HTML()
                run_train_viz.click(
                    do_training_viz,
                    inputs=[en_in_t, bn_in_t],
                    outputs=[result_html_t, calc_html_t, attn_html_t],
                )
            # ── TAB 3: Inference Viz ──────────────────
            with gr.Tab("⚡ Inference"):
                gr.HTML('<div style="color:#80cbc4;font-size:12px;margin-bottom:12px">🤖 Shows <strong>auto-regressive decoding</strong>: model generates Bengali token by token, no ground truth needed.</div>')
                with gr.Row():
                    en_in_i = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you")
                    decode_radio = gr.Radio(["greedy", "beam"], value="greedy", label="Decode Method")
                run_infer = gr.Button("⚡ Translate & Show All Calculations", variant="primary")
                result_html_i = gr.HTML()
                with gr.Row():
                    with gr.Column(scale=2):
                        calc_html_i = gr.HTML()
                    with gr.Column(scale=1):
                        attn_html_i = gr.HTML()
                run_infer.click(
                    do_inference_viz,
                    inputs=[en_in_i, decode_radio],
                    outputs=[result_html_i, calc_html_i, attn_html_i],
                )
            # ── TAB 4: Examples ──────────────────────
            with gr.Tab("📖 Examples"):
                example_cards = "".join(
                    f'<div style="background:#0d1120;border:1px solid #1e2d45;border-radius:6px;padding:8px">'
                    f'<div style="color:#4fc3f7;font-size:12px">{en}</div>'
                    f'<div style="color:#ce93d8;font-size:13px;font-weight:600">{bn}</div>'
                    f'</div>'
                    for en, bn in PARALLEL_DATA[:12]
                )
                gr.HTML("""
<div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:16px">
<div style="color:#64ffda;font-weight:700;margin-bottom:12px">Try these sentences:</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:8px">
""" + example_cards + "</div></div>")
    return demo
# Build the app at import time so tools that look up a module-level `demo`
# (e.g. Gradio's reload mode) can find it; only start the server when run as
# a script.
demo = build_ui()

if __name__ == "__main__":
    # Listen on all interfaces so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0")