"""
app.py
Gradio Space: Interactive Transformer Visualizer — English → Bengali
"""
import gradio as gr
import torch
import json
import os
import numpy as np
from pathlib import Path
from transformer import Transformer, CalcLog
from training import build_model, run_training, visualize_training_step, collate_batch
from inference import visualize_inference
from vocab import get_vocabs, PARALLEL_DATA, PAD_IDX
# ─────────────────────────────────────────────
# Global state
# ─────────────────────────────────────────────
DEVICE = "cpu"  # CPU-only Space; no CUDA assumed anywhere in this file
src_v, tgt_v = get_vocabs()  # source (English) / target (Bengali) vocabularies
# NOTE(review): annotated as Transformer but initialized to None — effectively
# Optional[Transformer]; populated lazily by get_or_init_model() or do_train().
MODEL: Transformer = None
LOSS_HISTORY = []  # per-epoch losses from the most recent run_training() call
IS_TRAINED = False  # flipped by do_train(); not read elsewhere in this view
def get_or_init_model():
    """Return the global MODEL, building a fresh untrained one on first use."""
    global MODEL
    if MODEL is None:
        MODEL = build_model(len(src_v), len(tgt_v), DEVICE)
    return MODEL
# ─────────────────────────────────────────────
# HTML renderer for calc log
# ─────────────────────────────────────────────
def render_matrix_html(val, max_rows=6, max_cols=8):
    """Convert a scalar / dict / 1-D / 2-D list into an HTML snippet.

    Args:
        val: number, dict, flat list, or list-of-lists (matrix rows).
        max_rows: row clip for 2-D values.
        max_cols: column clip (1-D values show up to 2*max_cols entries).

    Returns:
        HTML string; large values are clipped with "…" placeholders.

    NOTE(review): the original markup was garbled in this file — tags were
    reconstructed from the selectors defined in CUSTOM_CSS (.scalar-val,
    .dict-table, .matrix-1d/.matrix-2d, .mat-cell, .mat-more, .matrix-wrap);
    confirm against the intended design.
    """
    if isinstance(val, (int, float)):
        return f'<span class="scalar-val">{val:.5f}</span>'
    if isinstance(val, dict):
        rows = "".join(
            f'<tr><td class="dict-key">{k}</td><td class="dict-val">{v}</td></tr>'
            for k, v in val.items()
        )
        return f'<table class="dict-table">{rows}</table>'
    if isinstance(val, list):
        # Empty list — nothing to draw.
        if len(val) == 0:
            return '<span class="mat-more">empty</span>'
        # 1-D: single row of cells, clipped to 2*max_cols.
        if not isinstance(val[0], list):
            clipped = val[:max_cols * 2]
            cells = "".join(
                f'<td class="mat-cell">{v:.4f}</td>'
                if isinstance(v, float) else f'<td class="mat-cell">{v}</td>'
                for v in clipped
            )
            suffix = (f'<td class="mat-more">…+{len(val) - len(clipped)}</td>'
                      if len(val) > len(clipped) else "")
            return (f'<div class="matrix-wrap"><table class="matrix-1d">'
                    f'<tr>{cells}{suffix}</tr></table></div>')
        # 2-D: clipped grid; --v feeds the color-mix() shading in CUSTOM_CSS.
        rows_html = ""
        display_rows = val[:max_rows]
        for row in display_rows:
            display_cols = row[:max_cols]
            cells = "".join(
                f'<td class="mat-cell" style="--v:{float(c):.3f}">{float(c):.3f}</td>'
                if isinstance(c, (int, float)) else f'<td class="mat-cell">{c}</td>'
                for c in display_cols
            )
            suffix = '<td class="mat-more">…</td>' if len(row) > max_cols else ""
            rows_html += f"<tr>{cells}{suffix}</tr>"
        if len(val) > max_rows:
            rows_html += (f'<tr><td class="mat-more" colspan="{max_cols}">'
                          f'…{len(val) - max_rows} more rows</td></tr>')
        return f'<div class="matrix-wrap"><table class="matrix-2d">{rows_html}</table></div>'
    # Fallback for unexpected types: truncated repr.
    return f'<span>{str(val)[:200]}</span>'
def calc_log_to_html(steps):
    """Render CalcLog steps as a stack of collapsible <details> cards.

    Each card is color-tagged by a category inferred from the step name
    (embed / pe / attn / ffn / norm / loss / infer / cross / mask) so the
    data-ga toolbar in CUSTOM_JS can filter on the `cat-*` class.

    NOTE(review): the original markup was garbled; reconstructed from the
    CUSTOM_CSS selectors (details.calc-card, .step-num, .step-name,
    .shape-badge, .toggle-arrow, .calc-body, .formula, .step-note).
    """
    if not steps:
        return '<div class="step-note">No calculation log yet.</div>'
    cards = []
    for i, step in enumerate(steps):
        name = step.get("name", f"step_{i}")
        formula = step.get("formula", "")
        note = step.get("note", "")
        shape = step.get("shape")
        val = step.get("value")
        shape_badge = f'<span class="shape-badge">{shape}</span>' if shape else ""
        formula_html = f'<div class="formula">⟨ {formula} ⟩</div>' if formula else ""
        note_html = f'<div class="step-note">ℹ {note}</div>' if note else ""
        matrix_html = render_matrix_html(val) if val is not None else ""
        # Category from name substrings — first match wins, so order matters.
        cat = "default"
        n = name.upper()
        if "EMBED" in n or "TOKEN" in n: cat = "embed"
        elif "PE" in n or "POSITIONAL" in n: cat = "pe"
        elif "SOFTMAX" in n or "ATTN" in n or "_Q" in n or "_K" in n or "_V" in n: cat = "attn"
        elif "FFN" in n or "LINEAR" in n or "RELU" in n: cat = "ffn"
        elif "NORM" in n or "RESIDUAL" in n: cat = "norm"
        elif "LOSS" in n or "GRAD" in n or "OPTIM" in n: cat = "loss"
        elif "INFERENCE" in n or "GREEDY" in n or "BEAM" in n: cat = "infer"
        elif "CROSS" in n: cat = "cross"
        elif "MASK" in n: cat = "mask"
        cards.append(
            f'<details class="calc-card cat-{cat}">'
            f'<summary><span class="step-num">{i + 1:02d}</span>'
            f'<span class="step-name cat-label-{cat}">{name}</span>'
            f'{shape_badge}<span class="toggle-arrow">▶</span></summary>'
            f'<div class="calc-body">{formula_html}{note_html}{matrix_html}</div>'
            f'</details>'
        )
    return "\n".join(cards)
# ─────────────────────────────────────────────
# Attention heatmap HTML
# ─────────────────────────────────────────────
def attention_heatmap_html(weights, row_labels, col_labels, title="Attention"):
    """Render a 2-D attention matrix as an HTML heat table.

    Args:
        weights: 2-D list [tgt, src] of attention weights in roughly [0, 1].
        row_labels: labels for target positions (rows); index used as fallback.
        col_labels: labels for source positions (columns).
        title: heading shown above the table.

    Returns:
        HTML string, or "" for empty weights. Cell alpha is clamped to 1.0
        and drives the rgba() background via the --a CSS variable.

    NOTE(review): markup reconstructed from CUSTOM_CSS (.heatmap-container,
    .heatmap-title, .heatmap, .heat-cell, .heat-label, .heat-col-label).
    """
    if not weights:
        return ""
    rows_html = ""
    for i, row in enumerate(weights):
        cells = ""
        for j, w in enumerate(row):
            alpha = min(float(w), 1.0)
            cells += f'<td class="heat-cell" style="--a:{alpha:.2f}">{alpha:.2f}</td>'
        lbl = row_labels[i] if i < len(row_labels) else str(i)
        rows_html += f'<tr><td class="heat-label">{lbl}</td>{cells}</tr>'
    header = ('<tr><td></td>'
              + "".join(f'<td class="heat-col-label">{c}</td>' for c in col_labels)
              + '</tr>')
    return (f'<div class="heatmap-container"><div class="heatmap-title">{title}</div>'
            f'<table class="heatmap">{header}{rows_html}</table></div>')
# ─────────────────────────────────────────────
# Decoding steps HTML
# ─────────────────────────────────────────────
def decode_steps_html(step_logs, src_tokens):
    """Render per-step auto-regressive decoding logs as HTML cards.

    Args:
        step_logs: list of dicts with keys step / tokens_so_far / top5
            (each {"token", "prob"}) / chosen_token / chosen_prob, and
            optionally cross_attn as [num_heads][T_q][T_src].
        src_tokens: English source tokens (column labels for cross-attn).

    Returns:
        HTML string, or "" when there are no steps.

    NOTE(review): markup reconstructed from CUSTOM_CSS (.decode-steps,
    .decode-step, .step-badge, .bar-row, .bar, …); confirm visually.
    """
    if not step_logs:
        return ""
    html = ('<div class="decode-steps">'
            '<div class="decode-title">🔁 Auto-regressive Decoding Steps</div>')
    for s in step_logs:
        step = s.get("step", 0)
        tokens_so_far = s.get("tokens_so_far", [])
        top5 = s.get("top5", [])
        chosen = s.get("chosen_token", "?")
        prob = s.get("chosen_prob", 0)
        bars = ""
        if top5:
            # Normalize bar widths to the best candidate (guard a 0 max).
            max_p = max(t["prob"] for t in top5) or 1
            for t in top5:
                pct = t["prob"] / max_p * 100
                is_chosen = "chosen" if t["token"] == chosen else ""
                bars += (f'<div class="bar-row {is_chosen}">'
                         f'<span class="bar-label">{t["token"]}</span>'
                         f'<div class="bar" style="width:{pct:.1f}%"></div>'
                         f'<span class="bar-prob">{t["prob"]:.3f}</span></div>')
        cross_heat = ""
        if s.get("cross_attn") and src_tokens:
            attn_mat = s["cross_attn"]  # [num_heads][T_q][T_src]
            if attn_mat and attn_mat[0]:
                # Head 0, last decoded position → one row over source tokens.
                last_pos_attn = attn_mat[0][-1]  # [T_src]
                last_row = [last_pos_attn]  # [[T_src]] — 2-D for the heatmap
                cross_heat = attention_heatmap_html(
                    last_row, [chosen], src_tokens,
                    title=f"Cross-Attn: '{chosen}' → English"
                )
        html += (f'<div class="decode-step">'
                 f'<div class="decode-step-header">'
                 f'<span class="step-badge">step {step}</span>'
                 f'<span class="step-ctx">{" ".join(tokens_so_far)}</span>'
                 f'<span class="step-arrow">→</span>'
                 f'<span class="step-chosen">{chosen}</span>'
                 f'<span class="step-prob">p={prob:.3f}</span></div>'
                 f'<div class="step-bars">{bars}</div>{cross_heat}</div>')
    html += "</div>"
    return html
# ─────────────────────────────────────────────
# Architecture SVG
# ─────────────────────────────────────────────
# NOTE(review): this string appears to have lost its SVG/HTML tags (only the
# text content survives), so rendering it yields plain text rather than the
# encoder/decoder diagram the labels describe — restore the original markup
# from version control.
ARCH_SVG = """
Transformer Architecture — English → Bengali
ENCODER
Input Embedding + Positional Encoding
Encoder Layer × N
Multi-Head Self-Attention
Q = K = V = encoder input
Add & Norm
Feed-Forward Network
FFN(x) = max(0, xW₁+b₁)W₂+b₂
Add & Norm
K, V to
decoder
Encoder Output (K, V)
DECODER
Target Embedding + Positional Encoding
Decoder Layer × N
Masked Multi-Head Self-Attention
Q = K = V = decoder input (causal mask)
Add & Norm
Cross-Attention
Q = decoder | K, V = encoder
← KEY CONNECTION
Add & Norm
Feed-Forward Network
FFN(x) = max(0, xW₁+b₁)W₂+b₂
Add & Norm
Linear + Softmax → Bengali Token
K, V flow
English Input
Bengali Output
Self-Attention
Cross-Attention
Masked Attn
Enc→Dec K,V
Output Layer
"""
# ─────────────────────────────────────────────
# CSS + JS
# ─────────────────────────────────────────────
# Stylesheet injected into gr.Blocks(css=...). The class names below are the
# contract consumed by the HTML renderers above (calc cards, matrix tables,
# heatmaps, decode steps) and by the data-ga toolbar logic in CUSTOM_JS.
CUSTOM_CSS = """
/* ── fonts ── */
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;600&family=Syne:wght@400;700;800&display=swap');
:root {
--bg: #07090f;
--bg2: #0d1120;
--bg3: #111827;
--card: #141c2e;
--border: #1e2d45;
--accent: #64ffda;
--accent2: #29b6f6;
--accent3: #ce93d8;
--accent4: #ffb300;
--text: #e2e8f0;
--muted: #64748b;
--embed: #4fc3f7;
--pe: #26c6da;
--attn: #f06292;
--ffn: #aed581;
--norm: #90a4ae;
--loss: #ef9a9a;
--infer: #80cbc4;
--cross: #29b6f6;
--mask: #ffb300;
}
body, .gradio-container { background: var(--bg) !important; color: var(--text) !important; font-family: 'JetBrains Mono', monospace !important; }
h1, h2, h3 { font-family: 'Syne', sans-serif !important; }
/* ── tabs ── */
.tab-nav button { background: var(--bg3) !important; color: var(--muted) !important; border: 1px solid var(--border) !important; font-family: 'JetBrains Mono', monospace !important; letter-spacing: 1px; }
.tab-nav button.selected { background: var(--card) !important; color: var(--accent) !important; border-color: var(--accent) !important; box-shadow: 0 0 8px rgba(100,255,218,0.2); }
/* ── inputs ── */
input[type=text], textarea { background: var(--bg3) !important; color: var(--text) !important; border: 1px solid var(--border) !important; border-radius: 6px !important; font-family: 'JetBrains Mono', monospace !important; }
input[type=text]:focus, textarea:focus { border-color: var(--accent) !important; box-shadow: 0 0 6px rgba(100,255,218,0.2) !important; }
button.primary { background: linear-gradient(135deg, #0d3d30, #0d3d4d) !important; color: var(--accent) !important; border: 1px solid var(--accent) !important; font-family: 'JetBrains Mono', monospace !important; font-weight: 600 !important; letter-spacing: 1px; transition: all 0.2s; }
button.primary:hover { background: linear-gradient(135deg, #1a5c4a, #1a5c6d) !important; box-shadow: 0 0 12px rgba(100,255,218,0.3) !important; }
/* ── calc cards ── */
details.calc-card { border-radius: 8px; margin: 4px 0; border: 1px solid var(--border); background: var(--card); overflow: hidden; }
details.calc-card > summary { display: flex; align-items: center; gap: 8px; padding: 8px 12px; cursor: pointer; user-select: none; list-style: none; transition: background 0.15s; }
details.calc-card > summary::-webkit-details-marker { display: none; }
details.calc-card > summary::marker { display: none; }
details.calc-card > summary:hover { background: rgba(255,255,255,0.03); }
details.calc-card > .calc-body { padding: 10px 14px; background: var(--bg2); border-top: 1px solid var(--border); }
.step-num { color: var(--muted); font-size: 11px; min-width: 28px; }
.step-name { font-weight: 600; font-size: 12px; flex: 1; }
.toggle-arrow { color: var(--muted); font-size: 10px; transition: transform 0.2s; }
details[open] .toggle-arrow { transform: rotate(90deg); }
.shape-badge { background: var(--bg3); color: var(--muted); font-size: 10px; padding: 1px 6px; border-radius: 4px; border: 1px solid var(--border); }
.formula { color: var(--accent); font-size: 11px; font-style: italic; margin-bottom: 4px; background: rgba(100,255,218,0.05); padding: 4px 8px; border-radius: 4px; border-left: 2px solid var(--accent); }
.step-note { color: var(--muted); font-size: 11px; margin-bottom: 6px; }
/* category colors */
.cat-label-embed { color: var(--embed); }
.cat-label-pe { color: var(--pe); }
.cat-label-attn { color: var(--attn); }
.cat-label-ffn { color: var(--ffn); }
.cat-label-norm { color: var(--norm); }
.cat-label-loss { color: var(--loss); }
.cat-label-infer { color: var(--infer); }
.cat-label-cross { color: var(--cross); }
.cat-label-mask { color: var(--mask); }
.cat-label-default{ color: var(--text); }
.cat-embed { border-left: 3px solid var(--embed); }
.cat-pe { border-left: 3px solid var(--pe); }
.cat-attn { border-left: 3px solid var(--attn); }
.cat-ffn { border-left: 3px solid var(--ffn); }
.cat-norm { border-left: 3px solid var(--norm); }
.cat-loss { border-left: 3px solid var(--loss); }
.cat-infer { border-left: 3px solid var(--infer); }
.cat-cross { border-left: 3px solid var(--cross); }
.cat-mask { border-left: 3px solid var(--mask); }
.cat-default{ border-left: 3px solid var(--border); }
/* ── matrix tables ── */
.matrix-wrap { overflow-x: auto; }
.matrix-2d, .matrix-1d { border-collapse: collapse; font-size: 10px; font-family: 'JetBrains Mono', monospace; }
.mat-cell {
padding: 2px 5px; text-align: right; min-width: 48px;
background: color-mix(in srgb, #29b6f6 calc((var(--v,0) + 1) * 30%), #0d1120 calc(100% - (var(--v,0) + 1) * 30%));
color: #e2e8f0; border: 1px solid rgba(255,255,255,0.05);
}
.mat-more { color: var(--muted); font-style: italic; font-size: 9px; padding: 2px 6px; }
.dict-table { font-size: 11px; width: 100%; }
.dict-key { color: var(--accent); padding: 2px 8px 2px 0; }
.dict-val { color: var(--text); padding: 2px; }
.scalar-val { color: var(--accent4); font-size: 13px; font-weight: 600; }
/* ── heatmap ── */
.heatmap-container { margin: 8px 0; }
.heatmap-title { color: var(--accent2); font-size: 11px; margin-bottom: 4px; font-weight: 600; }
.heatmap { border-collapse: collapse; font-size: 10px; }
.heat-cell {
width: 36px; height: 24px; text-align: center;
background: rgba(41, 182, 246, calc(var(--a, 0)));
border: 1px solid rgba(255,255,255,0.04);
color: color-mix(in srgb, #fff calc(var(--a,0)*100%), #4a5568 calc(100% - var(--a,0)*100%));
font-size: 9px; cursor: default;
}
.heat-cell:hover { outline: 1px solid var(--accent); }
.heat-label { color: var(--accent3); font-size: 10px; padding-right: 6px; white-space: nowrap; }
.heat-col-label { color: var(--embed); font-size: 9px; text-align: center; padding-bottom: 2px; }
/* ── decode steps ── */
.decode-steps { margin-top: 12px; }
.decode-title { color: var(--accent); font-size: 13px; font-weight: 700; margin-bottom: 10px; padding-bottom: 4px; border-bottom: 1px solid var(--border); }
.decode-step { border: 1px solid var(--border); border-radius: 8px; margin: 6px 0; padding: 10px; background: var(--card); }
.decode-step-header { display: flex; align-items: center; gap: 8px; flex-wrap: wrap; margin-bottom: 8px; }
.step-badge { background: var(--accent); color: var(--bg); font-size: 10px; font-weight: 700; padding: 2px 8px; border-radius: 20px; }
.step-ctx { color: var(--muted); font-size: 11px; }
.step-arrow { color: var(--accent4); }
.step-chosen { color: var(--accent3); font-size: 13px; font-weight: 700; }
.step-prob { color: var(--accent4); font-size: 11px; }
.step-bars { margin: 4px 0; }
.bar-row { display: flex; align-items: center; gap: 6px; margin: 2px 0; }
.bar-row.chosen .bar-label { color: var(--accent3); font-weight: 700; }
.bar-row.chosen .bar { background: var(--accent3) !important; }
.bar-label { width: 100px; text-align: right; font-size: 11px; color: var(--text); white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
.bar { height: 14px; background: var(--accent2); border-radius: 2px; transition: width 0.4s; min-width: 2px; }
.bar-prob { font-size: 10px; color: var(--muted); }
/* ── loss chart ── */
#loss-chart-container { background: var(--bg2); border: 1px solid var(--border); border-radius: 8px; padding: 12px; margin-top: 8px; }
/* ── arch diagram ── */
#arch-diagram { background: var(--bg2); border: 1px solid var(--border); border-radius: 10px; padding: 12px; margin: 8px 0; }
/* ── result banner ── */
.result-banner { background: linear-gradient(135deg, #0d3d30, #1a1a3d); border: 1px solid var(--accent); border-radius: 10px; padding: 16px 20px; margin: 10px 0; }
.result-en { color: var(--embed); font-size: 14px; margin-bottom: 4px; }
.result-bn { color: var(--accent3); font-size: 22px; font-weight: 700; letter-spacing: 1px; }
.result-label { color: var(--muted); font-size: 10px; text-transform: uppercase; letter-spacing: 1px; }
/* ── misc ── */
.gradio-html { background: transparent !important; }
.panel { background: var(--card) !important; border: 1px solid var(--border) !important; border-radius: 10px !important; }
.log-container { max-height: 600px; overflow-y: auto; padding: 8px; scrollbar-width: thin; scrollbar-color: var(--border) transparent; }
"""
# Client-side helpers injected via gr.Blocks(js=...): <details> cards toggle
# natively; this script only adds expand-all / collapse-all / category-filter,
# driven by delegated clicks on any element carrying a data-ga attribute.
CUSTOM_JS = """
// elements toggle natively — JS only needed for expand-all/collapse-all/filter
window._expandAll = function() {
document.querySelectorAll('details.calc-card').forEach(d => d.open = true);
};
window._collapseAll = function() {
document.querySelectorAll('details.calc-card').forEach(d => d.open = false);
};
window._filterCards = function(cat) {
document.querySelectorAll('details.calc-card').forEach(d => {
d.style.display = (!cat || d.classList.contains('cat-'+cat)) ? '' : 'none';
});
};
// Toolbar button delegation
document.addEventListener('click', function(e) {
const btn = e.target.closest('[data-ga]');
if (!btn) return;
const a = btn.dataset.ga;
if (a === 'expand') window._expandAll();
else if (a === 'collapse') window._collapseAll();
else if (a.startsWith('filter:')) window._filterCards(a.slice(7));
}, true);
"""
# ─────────────────────────────────────────────
# Pure-SVG loss curve (no JS/canvas needed)
# ─────────────────────────────────────────────
def _loss_svg(losses):
if not losses:
return ""
W, H = 580, 200
pl, pr, pt, pb = 52, 16, 16, 36
pw, ph = W - pl - pr, H - pt - pb
mn, mx = min(losses), max(losses)
rng = mx - mn or 1
n = len(losses)
def px(i): return pl + (i / max(n - 1, 1)) * pw
def py(v): return pt + ph - ((v - mn) / rng) * ph
# Grid + Y labels
grid = ""
for k in range(5):
v = mn + (k / 4) * rng
y = py(v)
grid += f' '
grid += f'{v:.3f} '
# Polyline points
pts = " ".join(f"{px(i):.1f},{py(v):.1f}" for i, v in enumerate(losses))
fill_pts = f"{pl:.1f},{pt+ph:.1f} {pts} {pl+pw:.1f},{pt+ph:.1f}"
# X labels
xlabels = ""
for idx in ([0, n//4, n//2, 3*n//4, n-1] if n > 4 else range(n)):
xlabels += f'E{idx+1} '
return f"""
📉 Training Loss Curve
{grid}
{xlabels}
"""
# ─────────────────────────────────────────────
# Gradio callbacks
# ─────────────────────────────────────────────
def do_train(epochs_str, progress=gr.Progress()):
    """Train on the built-in parallel corpus and render a loss-curve SVG.

    Args:
        epochs_str: epoch count as text from the UI. Falls back to 30 when
            it is not a valid integer, and is clamped to >= 1 so the
            final-loss report cannot index an empty loss history.
        progress: Gradio progress tracker (injected by the framework).

    Returns:
        (status_html, loss_chart_html) for the two HTML outputs.
    """
    global MODEL, LOSS_HISTORY, IS_TRAINED
    try:
        epochs = int(epochs_str)
    except (TypeError, ValueError):  # was a bare `except:` — catch only parse failures
        epochs = 30
    epochs = max(1, epochs)  # 0/negative epochs would leave LOSS_HISTORY empty

    def cb(ep, total, loss):
        # Forward per-epoch progress to the Gradio progress bar.
        progress(ep / total, desc=f"Epoch {ep}/{total} — loss {loss:.4f}")

    MODEL, LOSS_HISTORY = run_training(epochs=epochs, device=DEVICE, progress_cb=cb)
    IS_TRAINED = True
    chart_html = _loss_svg(LOSS_HISTORY)
    return (
        f"✅ Trained {epochs} epochs. Final loss: {LOSS_HISTORY[-1]:.4f}",
        chart_html,
    )
def do_training_viz(en_sentence, bn_sentence):
    """Run one teacher-forced training step and render its visualization.

    Args:
        en_sentence: English input sentence.
        bn_sentence: Bengali ground truth; defaults to "আমি তোমাকে ভালোবাসি"
            when blank.

    Returns:
        (result_banner_html, calc_log_html, attn_html) — attn_html is
        currently always "" (the heatmap panel is left empty here).

    NOTE(review): the HTML f-strings were garbled in the source; the banner
    and data-ga toolbar markup were reconstructed from CUSTOM_CSS/CUSTOM_JS —
    confirm against the intended design.
    """
    model = get_or_init_model()
    if not en_sentence.strip():
        return '<div class="step-note">Please enter an English sentence.</div>', "", ""
    if not bn_sentence.strip():
        bn_sentence = "আমি তোমাকে ভালোবাসি"
    result = visualize_training_step(model, en_sentence.strip(), bn_sentence.strip(), DEVICE)
    attn_html = ""
    result_banner = f'''<div class="result-banner">
<div class="result-label">English Input</div>
<div class="result-en">"{en_sentence}"</div>
<div class="result-label">Bengali (Teacher-forced)</div>
<div class="result-bn">{bn_sentence}</div>
<div class="step-prob">📉 Loss: {result['loss']:.4f}</div>
</div>'''
    # data-ga buttons are handled by the click-delegation in CUSTOM_JS.
    filter_btns = "".join(
        f'<button data-ga="filter:{cat}" class="cat-label-{cat}">{cat}</button> '
        for cat in ['embed', 'pe', 'attn', 'ffn', 'norm', 'loss', 'cross', 'mask']
    )
    calc_html = f'''<div>
<div><button data-ga="expand">Expand All</button>
<button data-ga="collapse">Collapse All</button>
{filter_btns}
<button data-ga="filter:">show all</button></div>
<div class="log-container">{calc_log_to_html(result.get('calc_log', []))}</div>
</div>'''
    return result_banner, calc_html, attn_html
def do_inference_viz(en_sentence, decode_method):
    """Translate an English sentence and render the decoding visualization.

    Args:
        en_sentence: English input sentence.
        decode_method: "greedy" or "beam" (from the UI radio).

    Returns:
        (banner_plus_decode_html, calc_log_html, "") for the three outputs.

    NOTE(review): the HTML f-strings were garbled in the source; markup
    reconstructed from CUSTOM_CSS classes — confirm against the design.
    """
    model = get_or_init_model()
    if not en_sentence.strip():
        return '<div class="step-note">Please enter an English sentence.</div>', "", ""
    result = visualize_inference(model, en_sentence.strip(), DEVICE, decode_method)
    result_banner = f'''<div class="result-banner">
<div class="result-label">English Input</div>
<div class="result-en">"{en_sentence}"</div>
<div class="result-label">Bengali Translation ({decode_method})</div>
<div class="result-bn">{result['translation'] or '(no output)'}</div>
<div class="step-ctx">Tokens: {' → '.join(result['output_tokens'])}</div>
</div>'''
    decode_html = decode_steps_html(result.get("step_logs", []), result.get("src_tokens", []))
    calc_html = f'''<div>
<div><button data-ga="expand">Expand All</button>
<button data-ga="collapse">Collapse All</button></div>
<div class="log-container">{calc_log_to_html(result.get('calc_log', []))}</div>
</div>'''
    return result_banner + decode_html, calc_html, ""
# ─────────────────────────────────────────────
# Build UI
# ─────────────────────────────────────────────
def build_ui():
    """Assemble the Gradio Blocks app with five tabs: architecture diagram,
    training, training-step visualization, inference visualization, examples.

    Returns:
        The gr.Blocks demo (not yet launched).

    NOTE(review): the gr.HTML() literals were garbled in the source (several
    string constants were left unterminated); the markup below is a
    reconstruction from the surviving text and CUSTOM_CSS — confirm visually.
    """
    _theme = gr.themes.Base(primary_hue="teal", secondary_hue="purple", neutral_hue="slate")
    with gr.Blocks(
        title="Transformer Visualizer — EN→BN",
        css=CUSTOM_CSS,
        js=f"() => {{ {CUSTOM_JS} }}",
        theme=_theme,
    ) as demo:
        gr.HTML("""
        <div style="text-align:center;padding:8px 0">
          <h1 style="letter-spacing:3px;margin:0">TRANSFORMER VISUALIZER</h1>
          <div class="result-label">ENGLISH → BENGALI · EVERY CALCULATION EXPOSED</div>
        </div>
        """)
        with gr.Tabs():
            # ── TAB 0: Architecture ──────────────────
            with gr.Tab("🏗️ Architecture"):
                gr.HTML(ARCH_SVG)
                gr.HTML("""
                <div style="display:flex;gap:24px;flex-wrap:wrap">
                  <div>
                    <h3>📌 Encoder Flow</h3>
                    <ol class="step-note">
                      <li>English tokens → Embedding (d_model=64)</li>
                      <li>+ Positional Encoding (sin/cos)</li>
                      <li>Multi-Head Self-Attention (4 heads)</li>
                      <li>Add &amp; LayerNorm</li>
                      <li>Feed-Forward (64→128→64)</li>
                      <li>Add &amp; LayerNorm</li>
                      <li>Repeat × 2 layers</li>
                      <li>Output K, V for decoder</li>
                    </ol>
                  </div>
                  <div>
                    <h3>📌 Decoder Flow</h3>
                    <ol class="step-note">
                      <li>Bengali tokens → Embedding</li>
                      <li>+ Positional Encoding</li>
                      <li>Masked MHA (future tokens blocked)</li>
                      <li>Add &amp; LayerNorm</li>
                      <li>Cross-Attention: Q←decoder, K,V←encoder</li>
                      <li>Add &amp; LayerNorm</li>
                      <li>Feed-Forward</li>
                      <li>Linear → Softmax → Bengali token</li>
                    </ol>
                  </div>
                </div>
                """)
            # ── TAB 1: Train ─────────────────────────
            with gr.Tab("🏋️ Train Model"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.HTML('<h3>Quick Train</h3>')
                        epochs_in = gr.Textbox(value="50", label="Epochs", max_lines=1)
                        train_btn = gr.Button("▶ Train on 30 parallel sentences", variant="primary")
                        train_status = gr.HTML()
                    with gr.Column(scale=2):
                        loss_chart = gr.HTML()
                train_btn.click(do_train, inputs=[epochs_in], outputs=[train_status, loss_chart])
            # ── TAB 2: Training Step Viz ──────────────
            with gr.Tab("🔬 Training Step"):
                gr.HTML('<div class="step-note">📚 Shows <b>teacher-forcing</b>: ground-truth Bengali tokens are fed to decoder, loss + gradients computed.</div>')
                with gr.Row():
                    en_in_t = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you")
                    bn_in_t = gr.Textbox(label="Bengali (ground truth)", placeholder="আমি তোমাকে ভালোবাসি", value="আমি তোমাকে ভালোবাসি")
                run_train_viz = gr.Button("🔬 Run Training Step & Show All Calculations", variant="primary")
                result_html_t = gr.HTML()
                with gr.Row():
                    with gr.Column(scale=2):
                        calc_html_t = gr.HTML()
                    with gr.Column(scale=1):
                        attn_html_t = gr.HTML()
                run_train_viz.click(do_training_viz,
                                    inputs=[en_in_t, bn_in_t],
                                    outputs=[result_html_t, calc_html_t, attn_html_t])
            # ── TAB 3: Inference Viz ──────────────────
            with gr.Tab("⚡ Inference"):
                gr.HTML('<div class="step-note">🤖 Shows <b>auto-regressive decoding</b>: model generates Bengali token by token, no ground truth needed.</div>')
                with gr.Row():
                    en_in_i = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you")
                    decode_radio = gr.Radio(["greedy", "beam"], value="greedy", label="Decode Method")
                run_infer = gr.Button("⚡ Translate & Show All Calculations", variant="primary")
                result_html_i = gr.HTML()
                with gr.Row():
                    with gr.Column(scale=2):
                        calc_html_i = gr.HTML()
                    with gr.Column(scale=1):
                        attn_html_i = gr.HTML()
                run_infer.click(do_inference_viz,
                                inputs=[en_in_i, decode_radio],
                                outputs=[result_html_i, calc_html_i, attn_html_i])
            # ── TAB 4: Examples ──────────────────────
            with gr.Tab("📖 Examples"):
                gr.HTML("<h3>Try these sentences:</h3><ul>" + "".join(
                    f'<li><span class="result-en">{en}</span> → <span class="result-bn">{bn}</span></li>'
                    for en, bn in PARALLEL_DATA[:12]
                ) + "</ul>")
    return demo
# Keep `demo` at module level (Gradio Spaces auto-discover it), but only
# launch when run as a script so importing this module has no side effects.
demo = build_ui()

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside the container.
    demo.launch(server_name="0.0.0.0")