Spaces:
Sleeping
Sleeping
| """ | |
| app.py | |
| Gradio Space: Interactive Transformer Visualizer — English → Bengali | |
| """ | |
| import gradio as gr | |
| import torch | |
| import json | |
| import os | |
| import numpy as np | |
| from pathlib import Path | |
| from transformer import Transformer, CalcLog | |
| from training import build_model, run_training, visualize_training_step, collate_batch | |
| from inference import visualize_inference | |
| from vocab import get_vocabs, PARALLEL_DATA, PAD_IDX | |
| # ───────────────────────────────────────────── | |
| # Global state | |
| # ───────────────────────────────────────────── | |
# Device identifier handed to build_model(), run_training() and the
# visualize_* helpers further down in this file.
DEVICE = "cpu"
# Objects returned by get_vocabs(); in the visible part of this file only
# their lengths are used (to size the model in get_or_init_model()).
src_v, tgt_v = get_vocabs()
# Created on first use by get_or_init_model() and replaced by do_train().
MODEL: Transformer = None
# Overwritten by do_train() with the loss list returned by run_training().
LOSS_HISTORY = []
# Set to True by do_train() once run_training() has returned.
IS_TRAINED = False
def get_or_init_model():
    """Return the shared model, building it on first use.

    The instance is cached in the module-level ``MODEL`` so later calls
    reuse it; it is built with ``build_model`` from the two vocabulary
    sizes and ``DEVICE``.
    """
    global MODEL
    if MODEL is not None:
        return MODEL
    MODEL = build_model(len(src_v), len(tgt_v), DEVICE)
    return MODEL
| # ───────────────────────────────────────────── | |
| # HTML renderer for calc log | |
| # ───────────────────────────────────────────── | |
def render_matrix_html(val, max_rows=6, max_cols=8):
    """Convert a scalar, dict, or (nested) list into an HTML fragment.

    Args:
        val: The value to display.
            * int/float -> a single ``<span class="scalar-val">``.
            * dict -> a two-column key/value table.
            * flat list -> a one-row table clipped to ``2 * max_cols`` cells.
            * list of lists -> a grid clipped to ``max_rows`` x ``max_cols``.
            * anything else -> ``str(val)`` cut to 200 characters in ``<code>``.
        max_rows: Maximum number of grid rows rendered for a 2-D value.
        max_cols: Maximum number of grid columns rendered per row.

    Returns:
        An HTML string. Every non-numeric piece of content is HTML-escaped so
        text containing markup characters (for example a token spelled
        ``<sos>``) is displayed literally instead of being parsed as a tag.
    """
    from html import escape  # local import keeps this block self-contained

    if isinstance(val, (int, float)):
        return f'<span class="scalar-val">{val:.5f}</span>'

    if isinstance(val, dict):
        rows = "".join(
            f'<tr><td class="dict-key">{escape(str(k))}</td>'
            f'<td class="dict-val">{escape(str(v))}</td></tr>'
            for k, v in val.items()
        )
        return f'<table class="dict-table">{rows}</table>'

    if not isinstance(val, list):
        # Truncate first, then escape, so an entity is never cut in half.
        return f"<code>{escape(str(val)[:200])}</code>"

    if not val:
        return "<em>empty</em>"

    # 1-D list: a single row, clipped to twice the column budget.
    if not isinstance(val[0], list):
        clipped = val[:max_cols * 2]
        cells = "".join(
            f'<td class="mat-cell">{v:.4f}</td>'
            if isinstance(v, float)
            else f'<td class="mat-cell">{escape(str(v))}</td>'
            for v in clipped
        )
        hidden = len(val) - len(clipped)
        suffix = f'<td class="mat-more">…+{hidden}</td>' if hidden > 0 else ""
        return f'<table class="matrix-1d"><tr>{cells}{suffix}</tr></table>'

    def tint(x):
        # --v drives the cell background and must stay within [-1, 1].
        # NaN compares unequal to itself; map it to the neutral tint.
        return 0.0 if x != x else min(max(x, -1.0), 1.0)

    # 2-D list: a grid clipped to max_rows x max_cols.
    row_html = []
    for row in val[:max_rows]:
        cols = row if isinstance(row, list) else [row]  # tolerate ragged input
        cells = "".join(
            f'<td class="mat-cell" style="--v:{tint(float(c)):.3f}">'
            f'{float(c):.3f}</td>'
            if isinstance(c, (int, float))
            else f'<td class="mat-cell">{escape(str(c))}</td>'
            for c in cols[:max_cols]
        )
        more = '<td class="mat-more">…</td>' if len(cols) > max_cols else ""
        row_html.append(f"<tr>{cells}{more}</tr>")

    if len(val) > max_rows:
        row_html.append(
            f'<tr><td colspan="{max_cols + 1}" class="mat-more">'
            f'…{len(val) - max_rows} more rows</td></tr>'
        )
    return f'<table class="matrix-2d">{"".join(row_html)}</table>'
def _calc_category(name):
    """Map a step name to the CSS category suffix used for colouring/filtering."""
    n = name.upper()
    # Substring tests are evaluated in order; the first match wins.
    # NOTE(review): a name containing both "CROSS" and "ATTN" (or "MASK" and
    # "ATTN") is classified as "attn" because that test comes first — confirm
    # whether the "cross"/"mask" categories are meant to take precedence.
    rules = (
        ("embed", ("EMBED", "TOKEN")),
        ("pe", ("PE", "POSITIONAL")),
        ("attn", ("SOFTMAX", "ATTN", "_Q", "_K", "_V")),
        ("ffn", ("FFN", "LINEAR", "RELU")),
        ("norm", ("NORM", "RESIDUAL")),
        ("loss", ("LOSS", "GRAD", "OPTIM")),
        ("infer", ("INFERENCE", "GREEDY", "BEAM")),
        ("cross", ("CROSS",)),
        ("mask", ("MASK",)),
    )
    for category, needles in rules:
        if any(needle in n for needle in needles):
            return category
    return "default"


def calc_log_to_html(steps):
    """Turn calculation-log steps into a list of collapsible HTML cards.

    Args:
        steps: Iterable of dicts. The keys read are ``name``, ``formula``,
            ``note``, ``shape`` and ``value``; all are optional.

    Returns:
        One ``<details class="calc-card">`` element per step, joined by
        newlines, or a placeholder paragraph when ``steps`` is empty.
        Name, formula, note and shape are HTML-escaped so characters such as
        ``<`` in a formula or a ``<sos>`` token in a note render literally.
    """
    from html import escape  # local import keeps this block self-contained

    if not steps:
        return "<p style='color:#888'>No calculation log yet.</p>"

    cards = []
    for i, step in enumerate(steps):
        name = str(step.get("name", f"step_{i}"))
        formula = step.get("formula", "")
        note = step.get("note", "")
        shape = step.get("shape")
        val = step.get("value")

        shape_badge = (
            f'<span class="shape-badge">{escape(str(shape))}</span>' if shape else ""
        )
        formula_html = (
            f'<div class="formula">⟨ {escape(str(formula))} ⟩</div>' if formula else ""
        )
        note_html = (
            f'<div class="step-note">ℹ {escape(str(note))}</div>' if note else ""
        )
        # render_matrix_html returns ready-made markup; do not escape it again.
        matrix_html = render_matrix_html(val) if val is not None else ""

        cat = _calc_category(name)
        label = escape(name.replace("_", " "))
        cards.append(
            f'\n<details class="calc-card cat-{cat}" data-idx="{i}">\n'
            f'<summary class="calc-header">\n'
            f'<span class="step-num">#{i + 1}</span>\n'
            f'<span class="step-name cat-label-{cat}">{label}</span>\n'
            f'{shape_badge}\n'
            f'<span class="toggle-arrow">▶</span>\n'
            f'</summary>\n'
            f'<div class="calc-body">\n'
            f'{formula_html}\n'
            f'{note_html}\n'
            f'<div class="matrix-wrap">{matrix_html}</div>\n'
            f'</div>\n'
            f'</details>'
        )
    return "\n".join(cards)
| # ───────────────────────────────────────────── | |
| # Attention heatmap HTML | |
| # ───────────────────────────────────────────── | |
def attention_heatmap_html(weights, row_labels, col_labels, title="Attention"):
    """Render a 2-D list of attention weights as an HTML heatmap table.

    Args:
        weights: 2-D list of numbers, one inner list per table row.
        row_labels: Labels for the rows; a missing label falls back to the
            row index.
        col_labels: Labels for the header row and the per-cell tooltips; a
            missing label falls back to the column index.
        title: Caption shown above the table.

    Returns:
        An HTML string, or ``""`` when ``weights`` is empty. Labels and title
        are HTML-escaped (they also appear inside a ``title="..."``
        attribute), and each weight is clamped to ``[0, 1]`` because it is
        used as a CSS alpha value.
    """
    from html import escape  # local import keeps this block self-contained

    if not weights:
        return ""

    row_labels = row_labels or []
    col_labels = col_labels or []

    def label(labels, idx):
        # Fall back to the numeric index when no label exists for this slot.
        return escape(str(labels[idx] if idx < len(labels) else idx))

    body_rows = []
    for i, row in enumerate(weights):
        row_lbl = label(row_labels, i)
        cells = []
        for j, w in enumerate(row):
            # min() first, then max(): a NaN weight ends up as 0.0.
            alpha = max(0.0, min(float(w), 1.0))
            cells.append(
                f'<td class="heat-cell" style="--a:{alpha:.3f}" '
                f'title="{row_lbl}→{label(col_labels, j)}: {alpha:.3f}">'
                f'{alpha:.2f}</td>'
            )
        body_rows.append(
            f'<tr><td class="heat-label">{row_lbl}</td>{"".join(cells)}</tr>'
        )

    header_cells = "".join(
        f'<td class="heat-col-label">{escape(str(c))}</td>' for c in col_labels
    )
    header = f"<tr><td></td>{header_cells}</tr>"

    return (
        '\n<div class="heatmap-container">\n'
        f'<div class="heatmap-title">{escape(str(title))}</div>\n'
        f'<table class="heatmap">{header}{"".join(body_rows)}</table>\n'
        '</div>'
    )
| # ───────────────────────────────────────────── | |
| # Decoding steps HTML | |
| # ───────────────────────────────────────────── | |
def decode_steps_html(step_logs, src_tokens):
    """Render the per-step log of auto-regressive decoding as HTML.

    Args:
        step_logs: List of dicts, one per decoding step. Keys read:
            ``step``, ``tokens_so_far``, ``top5`` (list of dicts with
            ``token`` and ``prob``), ``chosen_token``, ``chosen_prob`` and
            optionally ``cross_attn``.
        src_tokens: Source-side tokens used as column labels for the
            cross-attention heatmap; when empty no heatmap is drawn.

    Returns:
        An HTML string, or ``""`` when ``step_logs`` is empty. Token text is
        HTML-escaped so tokens spelled with angle brackets stay visible.
    """
    from html import escape  # local import keeps this block self-contained

    if not step_logs:
        return ""

    parts = [
        '<div class="decode-steps"><div class="decode-title">'
        '🔁 Auto-regressive Decoding Steps</div>'
    ]
    for entry in step_logs:
        step = entry.get("step", 0)
        tokens_so_far = entry.get("tokens_so_far", [])
        top5 = entry.get("top5", [])
        chosen = entry.get("chosen_token", "?")
        prob = entry.get("chosen_prob", 0)

        bars = []
        if top5:
            # Bars are scaled relative to the best candidate; "or 1" avoids a
            # division by zero when every probability is 0.
            max_p = max(t["prob"] for t in top5) or 1
            for cand in top5:
                pct = cand["prob"] / max_p * 100
                marker = "chosen" if cand["token"] == chosen else ""
                bars.append(
                    f'<div class="bar-row {marker}">\n'
                    f'<span class="bar-label">{escape(str(cand["token"]))}</span>\n'
                    f'<div class="bar" style="width:{pct:.1f}%"></div>\n'
                    f'<span class="bar-prob">{cand["prob"]:.3f}</span>\n'
                    f'</div>'
                )

        cross_heat = ""
        attn_mat = entry.get("cross_attn")  # indexed as [head][query_pos][src_pos]
        if attn_mat and src_tokens and attn_mat[0]:
            # Head 0, last decoded position, wrapped in a list so the heatmap
            # helper receives a 2-D structure with a single row.
            last_row = [attn_mat[0][-1]]
            # Raw (unescaped) text is passed on: escaping is the heatmap
            # helper's responsibility, and doing it here too would double-escape.
            cross_heat = attention_heatmap_html(
                last_row, [chosen], src_tokens,
                title=f"Cross-Attn: '{chosen}' → English",
            )

        context = escape(" ".join(str(tok) for tok in tokens_so_far))
        parts.append(
            '\n<div class="decode-step">\n'
            '<div class="decode-step-header">\n'
            f'<span class="step-badge">Step {step + 1}</span>\n'
            f'<span class="step-ctx">Context: {context}</span>\n'
            '<span class="step-arrow">→</span>\n'
            f'<span class="step-chosen">\'{escape(str(chosen))}\'</span>\n'
            f'<span class="step-prob">{prob:.3f}</span>\n'
            '</div>\n'
            f'<div class="step-bars">{"".join(bars)}</div>\n'
            f'{cross_heat}\n'
            '</div>'
        )

    parts.append("</div>")
    return "".join(parts)
| # ───────────────────────────────────────────── | |
| # Architecture SVG | |
| # ───────────────────────────────────────────── | |
# Static inline-SVG diagram of the encoder/decoder stack. It contains no
# placeholders and is handed verbatim to gr.HTML(ARCH_SVG) on the
# Architecture tab in build_ui().
ARCH_SVG = """
<div id="arch-diagram">
<svg viewBox="0 0 820 900" xmlns="http://www.w3.org/2000/svg" style="width:100%;max-width:820px;margin:auto;display:block">
<defs>
<marker id="arr" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
<path d="M0,0 L0,6 L8,3 z" fill="#64ffda"/>
</marker>
<filter id="glow">
<feGaussianBlur stdDeviation="2" result="blur"/>
<feMerge><feMergeNode in="blur"/><feMergeNode in="SourceGraphic"/></feMerge>
</filter>
</defs>
<!-- Background -->
<rect width="820" height="900" fill="#0a0f1e" rx="12"/>
<!-- Title -->
<text x="410" y="35" text-anchor="middle" fill="#64ffda" font-size="16" font-family="monospace" font-weight="bold">Transformer Architecture — English → Bengali</text>
<!-- ── ENCODER (left) ── -->
<rect x="40" y="60" width="330" height="720" rx="10" fill="#0d1b2a" stroke="#1e4d6b" stroke-width="1.5"/>
<text x="205" y="90" text-anchor="middle" fill="#4fc3f7" font-size="13" font-weight="bold">ENCODER</text>
<!-- Input Embedding -->
<rect x="70" y="110" width="270" height="40" rx="6" fill="#1a3a5c" stroke="#4fc3f7" stroke-width="1.5"/>
<text x="205" y="135" text-anchor="middle" fill="#e0f7fa" font-size="11">Input Embedding + Positional Encoding</text>
<!-- Encoder Layer Box -->
<rect x="60" y="175" width="290" height="340" rx="8" fill="#112233" stroke="#1e4d6b" stroke-width="1" stroke-dasharray="4"/>
<text x="100" y="198" fill="#607d8b" font-size="10">Encoder Layer × N</text>
<!-- Multi-Head Self-Attention -->
<rect x="80" y="210" width="250" height="50" rx="6" fill="#1b3a4b" stroke="#26c6da" stroke-width="1.5"/>
<text x="205" y="232" text-anchor="middle" fill="#e0f7fa" font-size="11" font-weight="bold">Multi-Head Self-Attention</text>
<text x="205" y="248" text-anchor="middle" fill="#80deea" font-size="9">Q = K = V = encoder input</text>
<!-- Add & Norm 1 -->
<rect x="80" y="278" width="250" height="30" rx="5" fill="#1a2a3a" stroke="#607d8b" stroke-width="1"/>
<text x="205" y="298" text-anchor="middle" fill="#b0bec5" font-size="10">Add & Norm</text>
<!-- FFN -->
<rect x="80" y="328" width="250" height="50" rx="6" fill="#1b3a4b" stroke="#26c6da" stroke-width="1.5"/>
<text x="205" y="350" text-anchor="middle" fill="#e0f7fa" font-size="11" font-weight="bold">Feed-Forward Network</text>
<text x="205" y="366" text-anchor="middle" fill="#80deea" font-size="9">FFN(x) = max(0, xW₁+b₁)W₂+b₂</text>
<!-- Add & Norm 2 -->
<rect x="80" y="396" width="250" height="30" rx="5" fill="#1a2a3a" stroke="#607d8b" stroke-width="1"/>
<text x="205" y="416" text-anchor="middle" fill="#b0bec5" font-size="10">Add & Norm</text>
<!-- Encoder output arrow down -->
<line x1="205" y1="455" x2="205" y2="550" stroke="#64ffda" stroke-width="1.5" marker-end="url(#arr)"/>
<text x="215" y="510" fill="#64ffda" font-size="9">K, V to</text>
<text x="215" y="522" fill="#64ffda" font-size="9">decoder</text>
<!-- Encoder output box -->
<rect x="70" y="555" width="270" height="40" rx="6" fill="#0d2b1a" stroke="#00e676" stroke-width="1.5"/>
<text x="205" y="580" text-anchor="middle" fill="#a5d6a7" font-size="11">Encoder Output (K, V)</text>
<!-- ── DECODER (right) ── -->
<rect x="450" y="60" width="330" height="720" rx="10" fill="#1a0d2a" stroke="#4a1b6b" stroke-width="1.5"/>
<text x="615" y="90" text-anchor="middle" fill="#ce93d8" font-size="13" font-weight="bold">DECODER</text>
<!-- Target Embedding -->
<rect x="480" y="110" width="270" height="40" rx="6" fill="#3a1a5c" stroke="#ce93d8" stroke-width="1.5"/>
<text x="615" y="135" text-anchor="middle" fill="#f3e5f5" font-size="11">Target Embedding + Positional Encoding</text>
<!-- Decoder Layer Box -->
<rect x="470" y="175" width="290" height="460" rx="8" fill="#1a1133" stroke="#4a1b6b" stroke-width="1" stroke-dasharray="4"/>
<text x="510" y="198" fill="#607d8b" font-size="10">Decoder Layer × N</text>
<!-- Masked MHA -->
<rect x="490" y="210" width="250" height="50" rx="6" fill="#2b1b3a" stroke="#ab47bc" stroke-width="1.5"/>
<text x="615" y="232" text-anchor="middle" fill="#f3e5f5" font-size="11" font-weight="bold">Masked Multi-Head Self-Attention</text>
<text x="615" y="248" text-anchor="middle" fill="#ce93d8" font-size="9">Q = K = V = decoder input (causal mask)</text>
<!-- Add & Norm D1 -->
<rect x="490" y="278" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
<text x="615" y="298" text-anchor="middle" fill="#b0bec5" font-size="10">Add & Norm</text>
<!-- Cross-Attention -->
<rect x="490" y="328" width="250" height="60" rx="6" fill="#1b2b4b" stroke="#29b6f6" stroke-width="2" filter="url(#glow)"/>
<text x="615" y="350" text-anchor="middle" fill="#e1f5fe" font-size="11" font-weight="bold">Cross-Attention</text>
<text x="615" y="366" text-anchor="middle" fill="#81d4fa" font-size="9">Q = decoder | K, V = encoder</text>
<text x="615" y="380" text-anchor="middle" fill="#29b6f6" font-size="9" font-weight="bold">← KEY CONNECTION</text>
<!-- Add & Norm D2 -->
<rect x="490" y="408" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
<text x="615" y="428" text-anchor="middle" fill="#b0bec5" font-size="10">Add & Norm</text>
<!-- FFN Decoder -->
<rect x="490" y="458" width="250" height="50" rx="6" fill="#2b1b3a" stroke="#ab47bc" stroke-width="1.5"/>
<text x="615" y="480" text-anchor="middle" fill="#f3e5f5" font-size="11" font-weight="bold">Feed-Forward Network</text>
<text x="615" y="496" text-anchor="middle" fill="#ce93d8" font-size="9">FFN(x) = max(0, xW₁+b₁)W₂+b₂</text>
<!-- Add & Norm D3 -->
<rect x="490" y="526" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
<text x="615" y="546" text-anchor="middle" fill="#b0bec5" font-size="10">Add & Norm</text>
<!-- Output Linear + Softmax -->
<rect x="480" y="600" width="270" height="40" rx="6" fill="#2b1b0a" stroke="#ffb300" stroke-width="1.5"/>
<text x="615" y="625" text-anchor="middle" fill="#fff8e1" font-size="11">Linear + Softmax → Bengali Token</text>
<!-- Cross-attention arrow from encoder to decoder -->
<path d="M340,590 Q410,480 490,368" stroke="#29b6f6" stroke-width="2" fill="none"
stroke-dasharray="6,3" marker-end="url(#arr)"/>
<text x="390" y="500" fill="#29b6f6" font-size="9" transform="rotate(-50,390,500)">K, V flow</text>
<!-- Input arrow -->
<line x1="205" y1="840" x2="205" y2="780" stroke="#4fc3f7" stroke-width="1.5" marker-end="url(#arr)"/>
<text x="205" y="858" text-anchor="middle" fill="#4fc3f7" font-size="11">English Input</text>
<line x1="615" y1="840" x2="615" y2="660" stroke="#ce93d8" stroke-width="1.5" marker-end="url(#arr)"/>
<text x="615" y="858" text-anchor="middle" fill="#ce93d8" font-size="11">Bengali Output</text>
<!-- Legend -->
<rect x="60" y="870" width="700" height="20" rx="4" fill="#0a1520" stroke="#1e2d3d" stroke-width="1"/>
<circle cx="80" cy="880" r="4" fill="#26c6da"/><text x="88" y="884" fill="#80deea" font-size="8">Self-Attention</text>
<circle cx="160" cy="880" r="4" fill="#29b6f6"/><text x="168" y="884" fill="#81d4fa" font-size="8">Cross-Attention</text>
<circle cx="250" cy="880" r="4" fill="#ab47bc"/><text x="258" y="884" fill="#ce93d8" font-size="8">Masked Attn</text>
<circle cx="350" cy="880" r="4" fill="#00e676"/><text x="358" y="884" fill="#a5d6a7" font-size="8">Enc→Dec K,V</text>
<circle cx="450" cy="880" r="4" fill="#ffb300"/><text x="458" y="884" fill="#fff8e1" font-size="8">Output Layer</text>
</svg>
</div>
"""
| # ───────────────────────────────────────────── | |
| # CSS + JS | |
| # ───────────────────────────────────────────── | |
# Stylesheet passed to gr.Blocks(css=...) in build_ui(). The :root block
# defines the palette variables (--embed, --pe, --attn, ...); the remaining
# rules style the class names emitted by the HTML renderers in this file
# (calc cards, matrix tables, heatmap, decode steps, result banner).
CUSTOM_CSS = """
/* ── fonts ── */
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;600&family=Syne:wght@400;700;800&display=swap');
:root {
--bg: #07090f;
--bg2: #0d1120;
--bg3: #111827;
--card: #141c2e;
--border: #1e2d45;
--accent: #64ffda;
--accent2: #29b6f6;
--accent3: #ce93d8;
--accent4: #ffb300;
--text: #e2e8f0;
--muted: #64748b;
--embed: #4fc3f7;
--pe: #26c6da;
--attn: #f06292;
--ffn: #aed581;
--norm: #90a4ae;
--loss: #ef9a9a;
--infer: #80cbc4;
--cross: #29b6f6;
--mask: #ffb300;
}
body, .gradio-container { background: var(--bg) !important; color: var(--text) !important; font-family: 'JetBrains Mono', monospace !important; }
h1, h2, h3 { font-family: 'Syne', sans-serif !important; }
/* ── tabs ── */
.tab-nav button { background: var(--bg3) !important; color: var(--muted) !important; border: 1px solid var(--border) !important; font-family: 'JetBrains Mono', monospace !important; letter-spacing: 1px; }
.tab-nav button.selected { background: var(--card) !important; color: var(--accent) !important; border-color: var(--accent) !important; box-shadow: 0 0 8px rgba(100,255,218,0.2); }
/* ── inputs ── */
input[type=text], textarea { background: var(--bg3) !important; color: var(--text) !important; border: 1px solid var(--border) !important; border-radius: 6px !important; font-family: 'JetBrains Mono', monospace !important; }
input[type=text]:focus, textarea:focus { border-color: var(--accent) !important; box-shadow: 0 0 6px rgba(100,255,218,0.2) !important; }
button.primary { background: linear-gradient(135deg, #0d3d30, #0d3d4d) !important; color: var(--accent) !important; border: 1px solid var(--accent) !important; font-family: 'JetBrains Mono', monospace !important; font-weight: 600 !important; letter-spacing: 1px; transition: all 0.2s; }
button.primary:hover { background: linear-gradient(135deg, #1a5c4a, #1a5c6d) !important; box-shadow: 0 0 12px rgba(100,255,218,0.3) !important; }
/* ── calc cards ── */
details.calc-card { border-radius: 8px; margin: 4px 0; border: 1px solid var(--border); background: var(--card); overflow: hidden; }
details.calc-card > summary { display: flex; align-items: center; gap: 8px; padding: 8px 12px; cursor: pointer; user-select: none; list-style: none; transition: background 0.15s; }
details.calc-card > summary::-webkit-details-marker { display: none; }
details.calc-card > summary::marker { display: none; }
details.calc-card > summary:hover { background: rgba(255,255,255,0.03); }
details.calc-card > .calc-body { padding: 10px 14px; background: var(--bg2); border-top: 1px solid var(--border); }
.step-num { color: var(--muted); font-size: 11px; min-width: 28px; }
.step-name { font-weight: 600; font-size: 12px; flex: 1; }
.toggle-arrow { color: var(--muted); font-size: 10px; transition: transform 0.2s; }
details[open] .toggle-arrow { transform: rotate(90deg); }
.shape-badge { background: var(--bg3); color: var(--muted); font-size: 10px; padding: 1px 6px; border-radius: 4px; border: 1px solid var(--border); }
.formula { color: var(--accent); font-size: 11px; font-style: italic; margin-bottom: 4px; background: rgba(100,255,218,0.05); padding: 4px 8px; border-radius: 4px; border-left: 2px solid var(--accent); }
.step-note { color: var(--muted); font-size: 11px; margin-bottom: 6px; }
/* category colors */
.cat-label-embed { color: var(--embed); }
.cat-label-pe { color: var(--pe); }
.cat-label-attn { color: var(--attn); }
.cat-label-ffn { color: var(--ffn); }
.cat-label-norm { color: var(--norm); }
.cat-label-loss { color: var(--loss); }
.cat-label-infer { color: var(--infer); }
.cat-label-cross { color: var(--cross); }
.cat-label-mask { color: var(--mask); }
.cat-label-default{ color: var(--text); }
.cat-embed { border-left: 3px solid var(--embed); }
.cat-pe { border-left: 3px solid var(--pe); }
.cat-attn { border-left: 3px solid var(--attn); }
.cat-ffn { border-left: 3px solid var(--ffn); }
.cat-norm { border-left: 3px solid var(--norm); }
.cat-loss { border-left: 3px solid var(--loss); }
.cat-infer { border-left: 3px solid var(--infer); }
.cat-cross { border-left: 3px solid var(--cross); }
.cat-mask { border-left: 3px solid var(--mask); }
.cat-default{ border-left: 3px solid var(--border); }
/* ── matrix tables ── */
.matrix-wrap { overflow-x: auto; }
.matrix-2d, .matrix-1d { border-collapse: collapse; font-size: 10px; font-family: 'JetBrains Mono', monospace; }
.mat-cell {
padding: 2px 5px; text-align: right; min-width: 48px;
background: color-mix(in srgb, #29b6f6 calc((var(--v,0) + 1) * 30%), #0d1120 calc(100% - (var(--v,0) + 1) * 30%));
color: #e2e8f0; border: 1px solid rgba(255,255,255,0.05);
}
.mat-more { color: var(--muted); font-style: italic; font-size: 9px; padding: 2px 6px; }
.dict-table { font-size: 11px; width: 100%; }
.dict-key { color: var(--accent); padding: 2px 8px 2px 0; }
.dict-val { color: var(--text); padding: 2px; }
.scalar-val { color: var(--accent4); font-size: 13px; font-weight: 600; }
/* ── heatmap ── */
.heatmap-container { margin: 8px 0; }
.heatmap-title { color: var(--accent2); font-size: 11px; margin-bottom: 4px; font-weight: 600; }
.heatmap { border-collapse: collapse; font-size: 10px; }
.heat-cell {
width: 36px; height: 24px; text-align: center;
background: rgba(41, 182, 246, calc(var(--a, 0)));
border: 1px solid rgba(255,255,255,0.04);
color: color-mix(in srgb, #fff calc(var(--a,0)*100%), #4a5568 calc(100% - var(--a,0)*100%));
font-size: 9px; cursor: default;
}
.heat-cell:hover { outline: 1px solid var(--accent); }
.heat-label { color: var(--accent3); font-size: 10px; padding-right: 6px; white-space: nowrap; }
.heat-col-label { color: var(--embed); font-size: 9px; text-align: center; padding-bottom: 2px; }
/* ── decode steps ── */
.decode-steps { margin-top: 12px; }
.decode-title { color: var(--accent); font-size: 13px; font-weight: 700; margin-bottom: 10px; padding-bottom: 4px; border-bottom: 1px solid var(--border); }
.decode-step { border: 1px solid var(--border); border-radius: 8px; margin: 6px 0; padding: 10px; background: var(--card); }
.decode-step-header { display: flex; align-items: center; gap: 8px; flex-wrap: wrap; margin-bottom: 8px; }
.step-badge { background: var(--accent); color: var(--bg); font-size: 10px; font-weight: 700; padding: 2px 8px; border-radius: 20px; }
.step-ctx { color: var(--muted); font-size: 11px; }
.step-arrow { color: var(--accent4); }
.step-chosen { color: var(--accent3); font-size: 13px; font-weight: 700; }
.step-prob { color: var(--accent4); font-size: 11px; }
.step-bars { margin: 4px 0; }
.bar-row { display: flex; align-items: center; gap: 6px; margin: 2px 0; }
.bar-row.chosen .bar-label { color: var(--accent3); font-weight: 700; }
.bar-row.chosen .bar { background: var(--accent3) !important; }
.bar-label { width: 100px; text-align: right; font-size: 11px; color: var(--text); white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
.bar { height: 14px; background: var(--accent2); border-radius: 2px; transition: width 0.4s; min-width: 2px; }
.bar-prob { font-size: 10px; color: var(--muted); }
/* ── loss chart ── */
#loss-chart-container { background: var(--bg2); border: 1px solid var(--border); border-radius: 8px; padding: 12px; margin-top: 8px; }
/* ── arch diagram ── */
#arch-diagram { background: var(--bg2); border: 1px solid var(--border); border-radius: 10px; padding: 12px; margin: 8px 0; }
/* ── result banner ── */
.result-banner { background: linear-gradient(135deg, #0d3d30, #1a1a3d); border: 1px solid var(--accent); border-radius: 10px; padding: 16px 20px; margin: 10px 0; }
.result-en { color: var(--embed); font-size: 14px; margin-bottom: 4px; }
.result-bn { color: var(--accent3); font-size: 22px; font-weight: 700; letter-spacing: 1px; }
.result-label { color: var(--muted); font-size: 10px; text-transform: uppercase; letter-spacing: 1px; }
/* ── misc ── */
.gradio-html { background: transparent !important; }
.panel { background: var(--card) !important; border: 1px solid var(--border) !important; border-radius: 10px !important; }
.log-container { max-height: 600px; overflow-y: auto; padding: 8px; scrollbar-width: thin; scrollbar-color: var(--border) transparent; }
"""
# JavaScript snippet that build_ui() wraps in an arrow function
# ("() => { ... }") and passes to gr.Blocks(js=...). It defines
# expand-all / collapse-all / filter helpers for the calc-log cards and a
# single delegated click listener that dispatches on each toolbar
# button's data-ga attribute ("expand", "collapse", "filter:<category>").
CUSTOM_JS = """
// <details> elements toggle natively — JS only needed for expand-all/collapse-all/filter
window._expandAll = function() {
document.querySelectorAll('details.calc-card').forEach(d => d.open = true);
};
window._collapseAll = function() {
document.querySelectorAll('details.calc-card').forEach(d => d.open = false);
};
window._filterCards = function(cat) {
document.querySelectorAll('details.calc-card').forEach(d => {
d.style.display = (!cat || d.classList.contains('cat-'+cat)) ? '' : 'none';
});
};
// Toolbar button delegation
document.addEventListener('click', function(e) {
const btn = e.target.closest('[data-ga]');
if (!btn) return;
const a = btn.dataset.ga;
if (a === 'expand') window._expandAll();
else if (a === 'collapse') window._collapseAll();
else if (a.startsWith('filter:')) window._filterCards(a.slice(7));
}, true);
"""
| # ───────────────────────────────────────────── | |
| # Pure-SVG loss curve (no JS/canvas needed) | |
| # ───────────────────────────────────────────── | |
| def _loss_svg(losses): | |
| if not losses: | |
| return "" | |
| W, H = 580, 200 | |
| pl, pr, pt, pb = 52, 16, 16, 36 | |
| pw, ph = W - pl - pr, H - pt - pb | |
| mn, mx = min(losses), max(losses) | |
| rng = mx - mn or 1 | |
| n = len(losses) | |
| def px(i): return pl + (i / max(n - 1, 1)) * pw | |
| def py(v): return pt + ph - ((v - mn) / rng) * ph | |
| # Grid + Y labels | |
| grid = "" | |
| for k in range(5): | |
| v = mn + (k / 4) * rng | |
| y = py(v) | |
| grid += f'<line x1="{pl}" y1="{y:.1f}" x2="{pl+pw}" y2="{y:.1f}" stroke="#1e2d45" stroke-width="0.5"/>' | |
| grid += f'<text x="{pl-4}" y="{y+4:.1f}" text-anchor="end" fill="#64748b" font-size="9" font-family="monospace">{v:.3f}</text>' | |
| # Polyline points | |
| pts = " ".join(f"{px(i):.1f},{py(v):.1f}" for i, v in enumerate(losses)) | |
| fill_pts = f"{pl:.1f},{pt+ph:.1f} {pts} {pl+pw:.1f},{pt+ph:.1f}" | |
| # X labels | |
| xlabels = "" | |
| for idx in ([0, n//4, n//2, 3*n//4, n-1] if n > 4 else range(n)): | |
| xlabels += f'<text x="{px(idx):.1f}" y="{H-4}" text-anchor="middle" fill="#64748b" font-size="9" font-family="monospace">E{idx+1}</text>' | |
| return f""" | |
| <div style="background:#0d1120;border:1px solid #1e2d45;border-radius:8px;padding:12px;margin-top:8px"> | |
| <div style="color:#64ffda;font-size:13px;font-weight:700;margin-bottom:8px">📉 Training Loss Curve</div> | |
| <svg width="{W}" height="{H}" style="display:block;max-width:100%"> | |
| <defs> | |
| <linearGradient id="lcg" x1="0" y1="0" x2="{W}" y2="0" gradientUnits="userSpaceOnUse"> | |
| <stop offset="0%" stop-color="#64ffda"/><stop offset="100%" stop-color="#29b6f6"/> | |
| </linearGradient> | |
| </defs> | |
| <rect width="{W}" height="{H}" fill="#0d1120"/> | |
| {grid} | |
| <polygon points="{fill_pts}" fill="rgba(100,255,218,0.08)"/> | |
| <polyline points="{pts}" fill="none" stroke="url(#lcg)" stroke-width="2.5" stroke-linejoin="round"/> | |
| {xlabels} | |
| </svg> | |
| </div>""" | |
| # ───────────────────────────────────────────── | |
| # Gradio callbacks | |
| # ───────────────────────────────────────────── | |
def do_train(epochs_str, progress=gr.Progress()):
    """Train a fresh model and return a status line plus the loss chart.

    Args:
        epochs_str: Requested number of epochs; anything ``int()`` cannot
            convert falls back to 30.
        progress: Gradio progress tracker; it is updated once per callback
            invocation from ``run_training``.

    Returns:
        ``(status_message, chart_html)``. When training produced no loss
        values the chart is empty and the message says so instead of
        indexing into an empty history.

    Side effects:
        Replaces the module-level ``MODEL`` and ``LOSS_HISTORY`` and sets
        ``IS_TRAINED`` to True.
    """
    global MODEL, LOSS_HISTORY, IS_TRAINED

    try:
        epochs = int(epochs_str)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures select the default; a bare ``except``
        # would also swallow KeyboardInterrupt/SystemExit.
        epochs = 30

    def report(ep, total, loss):
        # max(total, 1) guards the fraction against a zero-epoch run.
        progress(ep / max(total, 1), desc=f"Epoch {ep}/{total} — loss {loss:.4f}")

    MODEL, LOSS_HISTORY = run_training(epochs=epochs, device=DEVICE, progress_cb=report)
    IS_TRAINED = True

    if not LOSS_HISTORY:
        return (f"⚠️ Training ran for {epochs} epochs but recorded no loss values.", "")

    return (
        f"✅ Trained {epochs} epochs. Final loss: {LOSS_HISTORY[-1]:.4f}",
        _loss_svg(LOSS_HISTORY),
    )
def do_training_viz(en_sentence, bn_sentence):
    """Run one teacher-forced training step and render it as HTML.

    Args:
        en_sentence: English source sentence; required.
        bn_sentence: Bengali target sentence; a built-in example is used when
            it is blank.

    Returns:
        ``(result_banner_html, calc_log_html, attention_html)``. The third
        element is always ``""`` in this implementation. For blank English
        input an error paragraph and two empty strings are returned, without
        building the model.
    """
    from html import escape  # local import keeps this block self-contained

    # Validate before touching the model so blank input stays cheap.
    en_text = (en_sentence or "").strip()
    if not en_text:
        return "<p style='color:red'>Please enter an English sentence.</p>", "", ""
    bn_text = (bn_sentence or "").strip() or "আমি তোমাকে ভালোবাসি"

    model = get_or_init_model()
    result = visualize_training_step(model, en_text, bn_text, DEVICE)

    # User-supplied text is escaped before being placed into markup.
    result_banner = (
        '\n<div class="result-banner">\n'
        '<div class="result-label">English Input</div>\n'
        f'<div class="result-en">"{escape(en_text)}"</div>\n'
        '<div class="result-label" style="margin-top:8px">Bengali (Teacher-forced)</div>\n'
        f'<div class="result-bn">{escape(bn_text)}</div>\n'
        '<div style="margin-top:8px;color:var(--loss);font-size:13px">\n'
        f'📉 Loss: <strong>{result["loss"]:.4f}</strong>\n'
        '</div>\n'
        '</div>'
    )

    def button(action, label, color, font_px):
        return (
            f'<button data-ga="{action}" style="background:var(--card);color:{color};'
            f'border:1px solid var(--border);padding:3px 10px;border-radius:4px;'
            f'cursor:pointer;font-size:{font_px}px">{label}</button>'
        )

    categories = ["embed", "pe", "attn", "ffn", "norm", "loss", "cross", "mask"]
    # The stylesheet defines the palette as --embed, --pe, ... (there is no
    # --cat-* variable), so reference those to colour each filter button.
    filter_buttons = "".join(
        button(f"filter:{cat}", cat, f"var(--{cat},var(--text))", 10)
        for cat in categories
    )

    calc_html = (
        '\n<div style="margin-bottom:8px;display:flex;gap:6px;flex-wrap:wrap">\n'
        f'{button("expand", "Expand All", "var(--accent)", 11)}\n'
        f'{button("collapse", "Collapse All", "var(--muted)", 11)}\n'
        f'{filter_buttons}\n'
        f'{button("filter:", "show all", "var(--muted)", 10)}\n'
        '</div>\n'
        '<div class="log-container">\n'
        f'{calc_log_to_html(result.get("calc_log", []))}\n'
        '</div>'
    )

    attn_html = ""  # no attention heatmap is produced for the training view
    return result_banner, calc_html, attn_html
def do_inference_viz(en_sentence, decode_method):
    """Translate a sentence and render the decoding trace as HTML.

    Args:
        en_sentence: English source sentence; required.
        decode_method: Decoding strategy label, forwarded to
            ``visualize_inference`` and shown in the banner.

    Returns:
        ``(banner_and_decode_steps_html, calc_log_html, "")``. For blank
        input an error paragraph and two empty strings are returned, without
        building the model.
    """
    from html import escape  # local import keeps this block self-contained

    # Validate before touching the model so blank input stays cheap.
    en_text = (en_sentence or "").strip()
    if not en_text:
        return "<p style='color:red'>Please enter an English sentence.</p>", "", ""

    model = get_or_init_model()
    result = visualize_inference(model, en_text, DEVICE, decode_method)

    # Everything interpolated below is text, not markup, so escape it;
    # generated tokens may be spelled with angle brackets.
    translation = escape(str(result["translation"] or "(no output)"))
    token_trail = escape(" → ".join(str(tok) for tok in result["output_tokens"]))

    result_banner = (
        '\n<div class="result-banner">\n'
        '<div class="result-label">English Input</div>\n'
        f'<div class="result-en">"{escape(en_text)}"</div>\n'
        '<div class="result-label" style="margin-top:8px">'
        f'Bengali Translation ({escape(str(decode_method))})</div>\n'
        f'<div class="result-bn">{translation}</div>\n'
        '<div style="margin-top:6px;color:var(--muted);font-size:11px">\n'
        f'Tokens: {token_trail}\n'
        '</div>\n'
        '</div>'
    )

    decode_html = decode_steps_html(
        result.get("step_logs", []), result.get("src_tokens", [])
    )

    calc_html = (
        '\n<div style="margin-bottom:8px;display:flex;gap:6px;flex-wrap:wrap">\n'
        '<button data-ga="expand" style="background:var(--card);color:var(--accent);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Expand All</button>\n'
        '<button data-ga="collapse" style="background:var(--card);color:var(--muted);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Collapse All</button>\n'
        '</div>\n'
        '<div class="log-container">\n'
        f'{calc_log_to_html(result.get("calc_log", []))}\n'
        '</div>'
    )
    return result_banner + decode_html, calc_html, ""
| # ───────────────────────────────────────────── | |
| # Build UI | |
| # ───────────────────────────────────────────── | |
def _architecture_tab():
    """Add the Architecture tab: the ARCH_SVG diagram and two flow cards."""
    flow_cards = """
<div style="display:grid;grid-template-columns:1fr 1fr;gap:12px;margin-top:12px">
<div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:14px">
<div style="color:#4fc3f7;font-weight:700;margin-bottom:8px">📌 Encoder Flow</div>
<div style="color:#94a3b8;font-size:12px;line-height:1.8">
1. English tokens → Embedding (d_model=64)<br>
2. + Positional Encoding (sin/cos)<br>
3. Multi-Head Self-Attention (4 heads)<br>
4. Add & LayerNorm<br>
5. Feed-Forward (64→128→64)<br>
6. Add & LayerNorm<br>
7. Repeat × 2 layers<br>
8. Output K, V for decoder
</div>
</div>
<div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:14px">
<div style="color:#ce93d8;font-weight:700;margin-bottom:8px">📌 Decoder Flow</div>
<div style="color:#94a3b8;font-size:12px;line-height:1.8">
1. Bengali tokens → Embedding<br>
2. + Positional Encoding<br>
3. Masked MHA (future tokens blocked)<br>
4. Add & LayerNorm<br>
5. Cross-Attention: Q←decoder, K,V←encoder<br>
6. Add & LayerNorm<br>
7. Feed-Forward<br>
8. Linear → Softmax → Bengali token
</div>
</div>
</div>
"""
    with gr.Tab("🏗️ Architecture"):
        gr.HTML(ARCH_SVG)
        gr.HTML(flow_cards)


def _train_tab():
    """Add the Train Model tab and wire its button to ``do_train``."""
    with gr.Tab("🏋️ Train Model"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML('<div style="color:#64ffda;font-size:13px;font-weight:700;margin-bottom:8px">Quick Train</div>')
                epochs_box = gr.Textbox(value="50", label="Epochs", max_lines=1)
                start_button = gr.Button("▶ Train on 30 parallel sentences", variant="primary")
                status_panel = gr.HTML()
            with gr.Column(scale=2):
                chart_panel = gr.HTML()
        start_button.click(
            fn=do_train,
            inputs=[epochs_box],
            outputs=[status_panel, chart_panel],
        )


def _training_step_tab():
    """Add the Training Step tab and wire its button to ``do_training_viz``."""
    with gr.Tab("🔬 Training Step"):
        gr.HTML('<div style="color:#ef9a9a;font-size:12px;margin-bottom:12px">📚 Shows <strong>teacher-forcing</strong>: ground-truth Bengali tokens are fed to decoder, loss + gradients computed.</div>')
        with gr.Row():
            english_box = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you")
            bengali_box = gr.Textbox(label="Bengali (ground truth)", placeholder="আমি তোমাকে ভালোবাসি", value="আমি তোমাকে ভালোবাসি")
        run_button = gr.Button("🔬 Run Training Step & Show All Calculations", variant="primary")
        summary_panel = gr.HTML()
        with gr.Row():
            with gr.Column(scale=2):
                calc_panel = gr.HTML()
            with gr.Column(scale=1):
                attention_panel = gr.HTML()
        run_button.click(
            fn=do_training_viz,
            inputs=[english_box, bengali_box],
            outputs=[summary_panel, calc_panel, attention_panel],
        )


def _inference_tab():
    """Add the Inference tab and wire its button to ``do_inference_viz``."""
    with gr.Tab("⚡ Inference"):
        gr.HTML('<div style="color:#80cbc4;font-size:12px;margin-bottom:12px">🤖 Shows <strong>auto-regressive decoding</strong>: model generates Bengali token by token, no ground truth needed.</div>')
        with gr.Row():
            english_box = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you")
            method_choice = gr.Radio(["greedy", "beam"], value="greedy", label="Decode Method")
        translate_button = gr.Button("⚡ Translate & Show All Calculations", variant="primary")
        summary_panel = gr.HTML()
        with gr.Row():
            with gr.Column(scale=2):
                calc_panel = gr.HTML()
            with gr.Column(scale=1):
                attention_panel = gr.HTML()
        translate_button.click(
            fn=do_inference_viz,
            inputs=[english_box, method_choice],
            outputs=[summary_panel, calc_panel, attention_panel],
        )


def _examples_tab():
    """Add the Examples tab: one card per pair from ``PARALLEL_DATA[:12]``."""
    grid_open = """
<div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:16px">
<div style="color:#64ffda;font-weight:700;margin-bottom:12px">Try these sentences:</div>
<div style="display:grid;grid-template-columns:1fr 1fr;gap:8px">
"""
    with gr.Tab("📖 Examples"):
        cards = []
        for english, bengali in PARALLEL_DATA[:12]:
            cards.append(
                '<div style="background:#0d1120;border:1px solid #1e2d45;border-radius:6px;padding:8px">'
                + f'<div style="color:#4fc3f7;font-size:12px">{english}</div>'
                + f'<div style="color:#ce93d8;font-size:13px;font-weight:600">{bengali}</div>'
                + '</div>'
            )
        gr.HTML(grid_open + "".join(cards) + "</div></div>")


def build_ui():
    """Assemble the Gradio Blocks app and return it without launching.

    The page is a title header followed by five tabs, created in this order:
    Architecture, Train Model, Training Step, Inference, Examples. Each tab is
    built by a private helper that must be called inside the ``gr.Tabs()``
    context opened here.

    NOTE(review): the nesting of rows/columns/tabs was reconstructed from a
    copy of this file whose indentation had been stripped. Statement order is
    unchanged, but confirm that each button belongs below (not inside) the
    preceding ``gr.Row``.

    Returns:
        The ``gr.Blocks`` instance.
    """
    header_html = """
<div style="text-align:center;padding:24px 0 12px;border-bottom:1px solid #1e2d45;margin-bottom:16px">
<div style="font-family:'Syne',sans-serif;font-size:28px;font-weight:800;
background:linear-gradient(135deg,#64ffda,#29b6f6,#ce93d8);
-webkit-background-clip:text;-webkit-text-fill-color:transparent;letter-spacing:2px">
TRANSFORMER VISUALIZER
</div>
<div style="color:#64748b;font-size:12px;letter-spacing:3px;margin-top:4px;font-family:'JetBrains Mono',monospace">
ENGLISH → BENGALI · EVERY CALCULATION EXPOSED
</div>
</div>
"""
    ui_theme = gr.themes.Base(primary_hue="teal", secondary_hue="purple", neutral_hue="slate")
    app = gr.Blocks(
        title="Transformer Visualizer — EN→BN",
        css=CUSTOM_CSS,
        js=f"() => {{ {CUSTOM_JS} }}",
        theme=ui_theme,
    )
    with app:
        gr.HTML(header_html)
        with gr.Tabs():
            # Call order fixes the left-to-right tab order.
            _architecture_tab()
            _train_tab()
            _training_step_tab()
            _inference_tab()
            _examples_tab()
    return app
# Build the interface at import time so the module keeps exposing `demo`
# as a top-level name.
demo = build_ui()

if __name__ == "__main__":
    # Start the server only when run as a script; importing this module must
    # not start a blocking web server.
    # server_name="0.0.0.0" listens on all network interfaces.
    demo.launch(server_name="0.0.0.0")