""" app.py Gradio Space: Interactive Transformer Visualizer — English → Bengali """ import gradio as gr import torch import json import os import numpy as np from pathlib import Path from transformer import Transformer, CalcLog from training import build_model, run_training, visualize_training_step, collate_batch from inference import visualize_inference from vocab import get_vocabs, PARALLEL_DATA, PAD_IDX # ───────────────────────────────────────────── # Global state # ───────────────────────────────────────────── DEVICE = "cpu" src_v, tgt_v = get_vocabs() MODEL: Transformer = None LOSS_HISTORY = [] IS_TRAINED = False def get_or_init_model(): global MODEL if MODEL is None: MODEL = build_model(len(src_v), len(tgt_v), DEVICE) return MODEL # ───────────────────────────────────────────── # HTML renderer for calc log # ───────────────────────────────────────────── def render_matrix_html(val, max_rows=6, max_cols=8): """Convert a nested list / scalar to an HTML matrix table.""" if isinstance(val, (int, float)): return f'{val:.5f}' if isinstance(val, dict): rows = "".join( f'{k}{v}' for k, v in val.items() ) return f'{rows}
' if isinstance(val, list): # 0-D or scalar list if len(val) == 0: return "empty" # 1-D if not isinstance(val[0], list): clipped = val[:max_cols*2] cells = "".join( f'{v:.4f}' if isinstance(v, float) else f'{v}' for v in clipped ) suffix = f'…+{len(val)-len(clipped)}' if len(val) > len(clipped) else "" return f'{cells}{suffix}
' # 2-D rows_html = "" display_rows = val[:max_rows] for row in display_rows: display_cols = row[:max_cols] cells = "".join( f'' f'{float(c):.3f}' if isinstance(c, (int, float)) else f'{c}' for c in display_cols ) suffix = f'…' if len(row) > max_cols else "" rows_html += f"{cells}{suffix}" if len(val) > max_rows: rows_html += f'…{len(val)-max_rows} more rows' return f'{rows_html}
def calc_log_to_html(steps):
    """Turn CalcLog steps into a rich HTML accordion of <details> cards.

    Each step dict may carry: name, formula, note, shape, value.
    Cards get a category class (cat-embed, cat-attn, …) used by the CSS
    border colors and by the JS filter toolbar.
    """
    if not steps:
        return '<div class="step-note">No calculation log yet.</div>'
    cards = []
    for i, step in enumerate(steps):
        name = step.get("name", f"step_{i}")
        formula = step.get("formula", "")
        note = step.get("note", "")
        shape = step.get("shape")
        val = step.get("value")
        shape_badge = f'<span class="shape-badge">{shape}</span>' if shape else ""
        formula_html = f'<div class="formula">⟨ {formula} ⟩</div>' if formula else ""
        note_html = f'<div class="step-note">ℹ {note}</div>' if note else ""
        matrix_html = render_matrix_html(val) if val is not None else ""
        # Category by name prefix. FIX: check the specific CROSS/MASK
        # categories BEFORE the generic ATTN/SOFTMAX check — previously a
        # name like "CROSS_ATTN_SOFTMAX" always fell into "attn", so the
        # cross/mask filter buttons never matched anything.
        cat = "default"
        n = name.upper()
        if "EMBED" in n or "TOKEN" in n:
            cat = "embed"
        elif "PE" in n or "POSITIONAL" in n:
            cat = "pe"
        elif "CROSS" in n:
            cat = "cross"
        elif "MASK" in n:
            cat = "mask"
        elif "SOFTMAX" in n or "ATTN" in n or "_Q" in n or "_K" in n or "_V" in n:
            cat = "attn"
        elif "FFN" in n or "LINEAR" in n or "RELU" in n:
            cat = "ffn"
        elif "NORM" in n or "RESIDUAL" in n:
            cat = "norm"
        elif "LOSS" in n or "GRAD" in n or "OPTIM" in n:
            cat = "loss"
        elif "INFERENCE" in n or "GREEDY" in n or "BEAM" in n:
            cat = "infer"
        cards.append(f"""<details class="calc-card cat-{cat}">
  <summary>
    <span class="step-num">#{i + 1}</span>
    <span class="step-name cat-label-{cat}">{name.replace('_', ' ')}</span>
    {shape_badge}
    <span class="toggle-arrow">▶</span>
  </summary>
  <div class="calc-body">
    {formula_html}
    {note_html}
    {matrix_html}
  </div>
</details>""")
    return "\n".join(cards)
""") return "\n".join(cards) # ───────────────────────────────────────────── # Attention heatmap HTML # ───────────────────────────────────────────── def attention_heatmap_html(weights, row_labels, col_labels, title="Attention"): """weights: 2D list [tgt, src]""" if not weights: return "" rows_html = "" for i, row in enumerate(weights): cells = "" for j, w in enumerate(row): alpha = min(float(w), 1.0) cells += f'{alpha:.2f}' lbl = row_labels[i] if i < len(row_labels) else str(i) rows_html += f'{lbl}{cells}' header = '' + "".join(f'{c}' for c in col_labels) + '' return f"""
{title}
{header}{rows_html}
""" # ───────────────────────────────────────────── # Decoding steps HTML # ───────────────────────────────────────────── def decode_steps_html(step_logs, src_tokens): if not step_logs: return "" html = '
🔁 Auto-regressive Decoding Steps
' for s in step_logs: step = s.get("step", 0) tokens_so_far = s.get("tokens_so_far", []) top5 = s.get("top5", []) chosen = s.get("chosen_token", "?") prob = s.get("chosen_prob", 0) bars = "" if top5: max_p = max(t["prob"] for t in top5) or 1 for t in top5: pct = t["prob"] / max_p * 100 is_chosen = "chosen" if t["token"] == chosen else "" bars += f"""
{t['token']}
{t['prob']:.3f}
""" cross_heat = "" if s.get("cross_attn") and src_tokens: attn_mat = s["cross_attn"] # [num_heads][T_q][T_src] if attn_mat and attn_mat[0]: # Take head-0, last decoded position → [T_src] floats last_pos_attn = attn_mat[0][-1] # [T_src] last_row = [last_pos_attn] # [[T_src]] — 2D for heatmap cross_heat = attention_heatmap_html( last_row, [chosen], src_tokens, title=f"Cross-Attn: '{chosen}' → English" ) html += f"""
Step {step+1} Context: {' '.join(tokens_so_far)} '{chosen}' {prob:.3f}
{bars}
{cross_heat}
""" html += "
" return html # ───────────────────────────────────────────── # Architecture SVG # ───────────────────────────────────────────── ARCH_SVG = """
Transformer Architecture — English → Bengali ENCODER Input Embedding + Positional Encoding Encoder Layer × N Multi-Head Self-Attention Q = K = V = encoder input Add & Norm Feed-Forward Network FFN(x) = max(0, xW₁+b₁)W₂+b₂ Add & Norm K, V to decoder Encoder Output (K, V) DECODER Target Embedding + Positional Encoding Decoder Layer × N Masked Multi-Head Self-Attention Q = K = V = decoder input (causal mask) Add & Norm Cross-Attention Q = decoder | K, V = encoder ← KEY CONNECTION Add & Norm Feed-Forward Network FFN(x) = max(0, xW₁+b₁)W₂+b₂ Add & Norm Linear + Softmax → Bengali Token K, V flow English Input Bengali Output Self-Attention Cross-Attention Masked Attn Enc→Dec K,V Output Layer
""" # ───────────────────────────────────────────── # CSS + JS # ───────────────────────────────────────────── CUSTOM_CSS = """ /* ── fonts ── */ @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;600&family=Syne:wght@400;700;800&display=swap'); :root { --bg: #07090f; --bg2: #0d1120; --bg3: #111827; --card: #141c2e; --border: #1e2d45; --accent: #64ffda; --accent2: #29b6f6; --accent3: #ce93d8; --accent4: #ffb300; --text: #e2e8f0; --muted: #64748b; --embed: #4fc3f7; --pe: #26c6da; --attn: #f06292; --ffn: #aed581; --norm: #90a4ae; --loss: #ef9a9a; --infer: #80cbc4; --cross: #29b6f6; --mask: #ffb300; } body, .gradio-container { background: var(--bg) !important; color: var(--text) !important; font-family: 'JetBrains Mono', monospace !important; } h1, h2, h3 { font-family: 'Syne', sans-serif !important; } /* ── tabs ── */ .tab-nav button { background: var(--bg3) !important; color: var(--muted) !important; border: 1px solid var(--border) !important; font-family: 'JetBrains Mono', monospace !important; letter-spacing: 1px; } .tab-nav button.selected { background: var(--card) !important; color: var(--accent) !important; border-color: var(--accent) !important; box-shadow: 0 0 8px rgba(100,255,218,0.2); } /* ── inputs ── */ input[type=text], textarea { background: var(--bg3) !important; color: var(--text) !important; border: 1px solid var(--border) !important; border-radius: 6px !important; font-family: 'JetBrains Mono', monospace !important; } input[type=text]:focus, textarea:focus { border-color: var(--accent) !important; box-shadow: 0 0 6px rgba(100,255,218,0.2) !important; } button.primary { background: linear-gradient(135deg, #0d3d30, #0d3d4d) !important; color: var(--accent) !important; border: 1px solid var(--accent) !important; font-family: 'JetBrains Mono', monospace !important; font-weight: 600 !important; letter-spacing: 1px; transition: all 0.2s; } button.primary:hover { background: linear-gradient(135deg, #1a5c4a, #1a5c6d) 
!important; box-shadow: 0 0 12px rgba(100,255,218,0.3) !important; } /* ── calc cards ── */ details.calc-card { border-radius: 8px; margin: 4px 0; border: 1px solid var(--border); background: var(--card); overflow: hidden; } details.calc-card > summary { display: flex; align-items: center; gap: 8px; padding: 8px 12px; cursor: pointer; user-select: none; list-style: none; transition: background 0.15s; } details.calc-card > summary::-webkit-details-marker { display: none; } details.calc-card > summary::marker { display: none; } details.calc-card > summary:hover { background: rgba(255,255,255,0.03); } details.calc-card > .calc-body { padding: 10px 14px; background: var(--bg2); border-top: 1px solid var(--border); } .step-num { color: var(--muted); font-size: 11px; min-width: 28px; } .step-name { font-weight: 600; font-size: 12px; flex: 1; } .toggle-arrow { color: var(--muted); font-size: 10px; transition: transform 0.2s; } details[open] .toggle-arrow { transform: rotate(90deg); } .shape-badge { background: var(--bg3); color: var(--muted); font-size: 10px; padding: 1px 6px; border-radius: 4px; border: 1px solid var(--border); } .formula { color: var(--accent); font-size: 11px; font-style: italic; margin-bottom: 4px; background: rgba(100,255,218,0.05); padding: 4px 8px; border-radius: 4px; border-left: 2px solid var(--accent); } .step-note { color: var(--muted); font-size: 11px; margin-bottom: 6px; } /* category colors */ .cat-label-embed { color: var(--embed); } .cat-label-pe { color: var(--pe); } .cat-label-attn { color: var(--attn); } .cat-label-ffn { color: var(--ffn); } .cat-label-norm { color: var(--norm); } .cat-label-loss { color: var(--loss); } .cat-label-infer { color: var(--infer); } .cat-label-cross { color: var(--cross); } .cat-label-mask { color: var(--mask); } .cat-label-default{ color: var(--text); } .cat-embed { border-left: 3px solid var(--embed); } .cat-pe { border-left: 3px solid var(--pe); } .cat-attn { border-left: 3px solid var(--attn); } .cat-ffn 
{ border-left: 3px solid var(--ffn); } .cat-norm { border-left: 3px solid var(--norm); } .cat-loss { border-left: 3px solid var(--loss); } .cat-infer { border-left: 3px solid var(--infer); } .cat-cross { border-left: 3px solid var(--cross); } .cat-mask { border-left: 3px solid var(--mask); } .cat-default{ border-left: 3px solid var(--border); } /* ── matrix tables ── */ .matrix-wrap { overflow-x: auto; } .matrix-2d, .matrix-1d { border-collapse: collapse; font-size: 10px; font-family: 'JetBrains Mono', monospace; } .mat-cell { padding: 2px 5px; text-align: right; min-width: 48px; background: color-mix(in srgb, #29b6f6 calc((var(--v,0) + 1) * 30%), #0d1120 calc(100% - (var(--v,0) + 1) * 30%)); color: #e2e8f0; border: 1px solid rgba(255,255,255,0.05); } .mat-more { color: var(--muted); font-style: italic; font-size: 9px; padding: 2px 6px; } .dict-table { font-size: 11px; width: 100%; } .dict-key { color: var(--accent); padding: 2px 8px 2px 0; } .dict-val { color: var(--text); padding: 2px; } .scalar-val { color: var(--accent4); font-size: 13px; font-weight: 600; } /* ── heatmap ── */ .heatmap-container { margin: 8px 0; } .heatmap-title { color: var(--accent2); font-size: 11px; margin-bottom: 4px; font-weight: 600; } .heatmap { border-collapse: collapse; font-size: 10px; } .heat-cell { width: 36px; height: 24px; text-align: center; background: rgba(41, 182, 246, calc(var(--a, 0))); border: 1px solid rgba(255,255,255,0.04); color: color-mix(in srgb, #fff calc(var(--a,0)*100%), #4a5568 calc(100% - var(--a,0)*100%)); font-size: 9px; cursor: default; } .heat-cell:hover { outline: 1px solid var(--accent); } .heat-label { color: var(--accent3); font-size: 10px; padding-right: 6px; white-space: nowrap; } .heat-col-label { color: var(--embed); font-size: 9px; text-align: center; padding-bottom: 2px; } /* ── decode steps ── */ .decode-steps { margin-top: 12px; } .decode-title { color: var(--accent); font-size: 13px; font-weight: 700; margin-bottom: 10px; padding-bottom: 4px; 
border-bottom: 1px solid var(--border); } .decode-step { border: 1px solid var(--border); border-radius: 8px; margin: 6px 0; padding: 10px; background: var(--card); } .decode-step-header { display: flex; align-items: center; gap: 8px; flex-wrap: wrap; margin-bottom: 8px; } .step-badge { background: var(--accent); color: var(--bg); font-size: 10px; font-weight: 700; padding: 2px 8px; border-radius: 20px; } .step-ctx { color: var(--muted); font-size: 11px; } .step-arrow { color: var(--accent4); } .step-chosen { color: var(--accent3); font-size: 13px; font-weight: 700; } .step-prob { color: var(--accent4); font-size: 11px; } .step-bars { margin: 4px 0; } .bar-row { display: flex; align-items: center; gap: 6px; margin: 2px 0; } .bar-row.chosen .bar-label { color: var(--accent3); font-weight: 700; } .bar-row.chosen .bar { background: var(--accent3) !important; } .bar-label { width: 100px; text-align: right; font-size: 11px; color: var(--text); white-space: nowrap; overflow: hidden; text-overflow: ellipsis; } .bar { height: 14px; background: var(--accent2); border-radius: 2px; transition: width 0.4s; min-width: 2px; } .bar-prob { font-size: 10px; color: var(--muted); } /* ── loss chart ── */ #loss-chart-container { background: var(--bg2); border: 1px solid var(--border); border-radius: 8px; padding: 12px; margin-top: 8px; } /* ── arch diagram ── */ #arch-diagram { background: var(--bg2); border: 1px solid var(--border); border-radius: 10px; padding: 12px; margin: 8px 0; } /* ── result banner ── */ .result-banner { background: linear-gradient(135deg, #0d3d30, #1a1a3d); border: 1px solid var(--accent); border-radius: 10px; padding: 16px 20px; margin: 10px 0; } .result-en { color: var(--embed); font-size: 14px; margin-bottom: 4px; } .result-bn { color: var(--accent3); font-size: 22px; font-weight: 700; letter-spacing: 1px; } .result-label { color: var(--muted); font-size: 10px; text-transform: uppercase; letter-spacing: 1px; } /* ── misc ── */ .gradio-html { background: 
transparent !important; } .panel { background: var(--card) !important; border: 1px solid var(--border) !important; border-radius: 10px !important; } .log-container { max-height: 600px; overflow-y: auto; padding: 8px; scrollbar-width: thin; scrollbar-color: var(--border) transparent; } """ CUSTOM_JS = """ //
elements toggle natively — JS only needed for expand-all/collapse-all/filter window._expandAll = function() { document.querySelectorAll('details.calc-card').forEach(d => d.open = true); }; window._collapseAll = function() { document.querySelectorAll('details.calc-card').forEach(d => d.open = false); }; window._filterCards = function(cat) { document.querySelectorAll('details.calc-card').forEach(d => { d.style.display = (!cat || d.classList.contains('cat-'+cat)) ? '' : 'none'; }); }; // Toolbar button delegation document.addEventListener('click', function(e) { const btn = e.target.closest('[data-ga]'); if (!btn) return; const a = btn.dataset.ga; if (a === 'expand') window._expandAll(); else if (a === 'collapse') window._collapseAll(); else if (a.startsWith('filter:')) window._filterCards(a.slice(7)); }, true); """ # ───────────────────────────────────────────── # Pure-SVG loss curve (no JS/canvas needed) # ───────────────────────────────────────────── def _loss_svg(losses): if not losses: return "" W, H = 580, 200 pl, pr, pt, pb = 52, 16, 16, 36 pw, ph = W - pl - pr, H - pt - pb mn, mx = min(losses), max(losses) rng = mx - mn or 1 n = len(losses) def px(i): return pl + (i / max(n - 1, 1)) * pw def py(v): return pt + ph - ((v - mn) / rng) * ph # Grid + Y labels grid = "" for k in range(5): v = mn + (k / 4) * rng y = py(v) grid += f'' grid += f'{v:.3f}' # Polyline points pts = " ".join(f"{px(i):.1f},{py(v):.1f}" for i, v in enumerate(losses)) fill_pts = f"{pl:.1f},{pt+ph:.1f} {pts} {pl+pw:.1f},{pt+ph:.1f}" # X labels xlabels = "" for idx in ([0, n//4, n//2, 3*n//4, n-1] if n > 4 else range(n)): xlabels += f'E{idx+1}' return f"""
📉 Training Loss Curve
{grid} {xlabels}
""" # ───────────────────────────────────────────── # Gradio callbacks # ───────────────────────────────────────────── def do_train(epochs_str, progress=gr.Progress()): global MODEL, LOSS_HISTORY, IS_TRAINED try: epochs = int(epochs_str) except: epochs = 30 losses = [] def cb(ep, total, loss): losses.append(loss) progress((ep/total), desc=f"Epoch {ep}/{total} — loss {loss:.4f}") MODEL, LOSS_HISTORY = run_training(epochs=epochs, device=DEVICE, progress_cb=cb) IS_TRAINED = True chart_html = _loss_svg(LOSS_HISTORY) return ( f"✅ Trained {epochs} epochs. Final loss: {LOSS_HISTORY[-1]:.4f}", chart_html ) def do_training_viz(en_sentence, bn_sentence): model = get_or_init_model() if not en_sentence.strip(): return "

Please enter an English sentence.

", "", "" if not bn_sentence.strip(): bn_sentence = "আমি তোমাকে ভালোবাসি" result = visualize_training_step(model, en_sentence.strip(), bn_sentence.strip(), DEVICE) # Attention heatmap (cross-attn layer 0, head 0) meta = result.get("meta", {}) attn_html = "" src_tokens = result.get("src_tokens", []) tgt_tokens = result.get("tgt_tokens", []) result_banner = f"""
English Input
"{en_sentence}"
Bengali (Teacher-forced)
{bn_sentence}
📉 Loss: {result['loss']:.4f}
""" calc_html = f"""
{"".join(f'' for cat in ['embed','pe','attn','ffn','norm','loss','cross','mask'])}
{calc_log_to_html(result.get('calc_log', []))}
""" return result_banner, calc_html, attn_html def do_inference_viz(en_sentence, decode_method): model = get_or_init_model() if not en_sentence.strip(): return "

Please enter an English sentence.

", "", "" result = visualize_inference(model, en_sentence.strip(), DEVICE, decode_method) result_banner = f"""
English Input
"{en_sentence}"
Bengali Translation ({decode_method})
{result['translation'] or '(no output)'}
Tokens: {' → '.join(result['output_tokens'])}
""" decode_html = decode_steps_html(result.get("step_logs", []), result.get("src_tokens", [])) calc_html = f"""
{calc_log_to_html(result.get('calc_log', []))}
""" return result_banner + decode_html, calc_html, "" # ───────────────────────────────────────────── # Build UI # ───────────────────────────────────────────── def build_ui(): _theme = gr.themes.Base(primary_hue="teal", secondary_hue="purple", neutral_hue="slate") with gr.Blocks( title="Transformer Visualizer — EN→BN", css=CUSTOM_CSS, js=f"() => {{ {CUSTOM_JS} }}", theme=_theme, ) as demo: gr.HTML("""
TRANSFORMER VISUALIZER
ENGLISH → BENGALI · EVERY CALCULATION EXPOSED
""") with gr.Tabs(): # ── TAB 0: Architecture ────────────────── with gr.Tab("🏗️ Architecture"): gr.HTML(ARCH_SVG) gr.HTML("""
📌 Encoder Flow
1. English tokens → Embedding (d_model=64)
2. + Positional Encoding (sin/cos)
3. Multi-Head Self-Attention (4 heads)
4. Add & LayerNorm
5. Feed-Forward (64→128→64)
6. Add & LayerNorm
7. Repeat × 2 layers
8. Output K, V for decoder
📌 Decoder Flow
1. Bengali tokens → Embedding
2. + Positional Encoding
3. Masked MHA (future tokens blocked)
4. Add & LayerNorm
5. Cross-Attention: Q←decoder, K,V←encoder
6. Add & LayerNorm
7. Feed-Forward
8. Linear → Softmax → Bengali token
""") # ── TAB 1: Train ───────────────────────── with gr.Tab("🏋️ Train Model"): with gr.Row(): with gr.Column(scale=1): gr.HTML('
Quick Train
') epochs_in = gr.Textbox(value="50", label="Epochs", max_lines=1) train_btn = gr.Button("▶ Train on 30 parallel sentences", variant="primary") train_status = gr.HTML() with gr.Column(scale=2): loss_chart = gr.HTML() train_btn.click(do_train, inputs=[epochs_in], outputs=[train_status, loss_chart]) # ── TAB 2: Training Step Viz ────────────── with gr.Tab("🔬 Training Step"): gr.HTML('
📚 Shows teacher-forcing: ground-truth Bengali tokens are fed to decoder, loss + gradients computed.
') with gr.Row(): en_in_t = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you") bn_in_t = gr.Textbox(label="Bengali (ground truth)", placeholder="আমি তোমাকে ভালোবাসি", value="আমি তোমাকে ভালোবাসি") run_train_viz = gr.Button("🔬 Run Training Step & Show All Calculations", variant="primary") result_html_t = gr.HTML() with gr.Row(): with gr.Column(scale=2): calc_html_t = gr.HTML() with gr.Column(scale=1): attn_html_t = gr.HTML() run_train_viz.click(do_training_viz, inputs=[en_in_t, bn_in_t], outputs=[result_html_t, calc_html_t, attn_html_t]) # ── TAB 3: Inference Viz ────────────────── with gr.Tab("⚡ Inference"): gr.HTML('
🤖 Shows auto-regressive decoding: model generates Bengali token by token, no ground truth needed.
') with gr.Row(): en_in_i = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you") decode_radio = gr.Radio(["greedy", "beam"], value="greedy", label="Decode Method") run_infer = gr.Button("⚡ Translate & Show All Calculations", variant="primary") result_html_i = gr.HTML() with gr.Row(): with gr.Column(scale=2): calc_html_i = gr.HTML() with gr.Column(scale=1): attn_html_i = gr.HTML() run_infer.click(do_inference_viz, inputs=[en_in_i, decode_radio], outputs=[result_html_i, calc_html_i, attn_html_i]) # ── TAB 4: Examples ────────────────────── with gr.Tab("📖 Examples"): gr.HTML("""
Try these sentences:
""" + "".join( f'
' f'
{en}
' f'
{bn}
' f'
' for en, bn in PARALLEL_DATA[:12] ) + "
") return demo demo = build_ui() demo.launch(server_name="0.0.0.0")