priyadip committed on
Commit dc138e1 · 0 Parent(s):

Fix: js in gr.Blocks(), event delegation for card clicks, SVG loss curve

Files changed (7):
  1. README.md +91 -0
  2. app.py +800 -0
  3. inference.py +250 -0
  4. requirements.txt +2 -0
  5. training.py +261 -0
  6. transformer.py +516 -0
  7. vocab.py +146 -0
README.md ADDED
@@ -0,0 +1,91 @@
+ ---
+ title: Transformer Visualizer EN→BN
+ emoji: 🔬
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 6.9.0
+ app_file: app.py
+ pinned: true
+ license: mit
+ ---
+
+ # 🔬 Transformer Visualizer — English → Bengali
+
+ **See every single calculation inside a Transformer, live.**
+
+ ## What this Space does
+
+ Type any English sentence and watch every number flow through the Transformer architecture step by step — from raw token IDs all the way to Bengali output.
+
+ ---
+
+ ## 🗂️ Tabs
+
+ ### 🏗️ Architecture
+ - Full SVG diagram of encoder + decoder
+ - Color-coded: self-attention / cross-attention / masked attention / FFN
+ - Explains the K, V flow from encoder to decoder
+
+ ### 🏋️ Train Model
+ - Trains a small Transformer on 30 English→Bengali sentence pairs
+ - Live loss curve rendered as pure SVG
+ - Configurable epochs
+
+ ### 🔬 Training Step
+ Shows a **single training forward pass** with teacher forcing:
+
+ 1. **Tokenization** — English + Bengali → token ID arrays
+ 2. **Embedding** — `token_id → vector × √d_model`
+ 3. **Positional Encoding** — `sin(pos/10000^(2i/d))` / `cos(...)` matrix shown
+ 4. **Encoder**:
+    - Q, K, V projection matrices shown
+    - `scores = Q·Kᵀ / √d_k` with actual numbers
+    - Softmax attention weights (heatmap)
+    - Residual + LayerNorm
+    - FFN: `max(0, xW₁+b₁)W₂+b₂`
+ 5. **Decoder**:
+    - Masked self-attention with causal mask matrix
+    - Cross-attention: Q from decoder, K/V from encoder
+ 6. **Loss** — label-smoothed cross-entropy, gradient norms, Adam update
+
+ ### ⚡ Inference
+ Shows **auto-regressive decoding**:
+
+ - No ground truth needed
+ - Tokens are generated one at a time
+ - Top-5 candidates + probabilities at every step
+ - Cross-attention heatmap: which Bengali token attends to which English word
+ - Greedy vs. beam search comparison
+
+ ---
+
+ ## 📁 File Structure
+
+ ```
+ app.py           — Gradio UI + HTML/CSS/JS rendering
+ transformer.py   — Full Transformer with CalcLog hooks
+ training.py      — Training loop + single-step visualization
+ inference.py     — Greedy & beam search with logging
+ vocab.py         — English/Bengali vocabularies + parallel corpus
+ requirements.txt
+ ```
+
+ ---
+
+ ## ⚙️ Model Config
+
+ | Parameter | Value |
+ |-----------|-------|
+ | d_model | 64 |
+ | num_heads | 4 |
+ | num_layers | 2 |
+ | d_ff | 128 |
+ | vocab (EN) | ~100 |
+ | vocab (BN) | ~90 |
+ | Optimizer | Adam |
+ | Loss | Label-smoothed CE |
+
+ ---
+
+ *Built for educational purposes — every matrix operation is logged and displayed.*
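The encoder attention math listed under Training Step (`scores = Q·Kᵀ / √d_k`, then softmax) can be sketched in a few lines of NumPy. This is a minimal single-head sketch for illustration; the Space's actual implementation lives in transformer.py, which is not shown in this diff.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """scores = Q·Kᵀ / √d_k, row-wise softmax, then weighted sum of V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # [T_q, T_k]
    scores -= scores.max(axis=-1, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # each row sums to 1
    return weights @ V, weights

rng = np.random.default_rng(0)
Q = rng.normal(size=(3, 8))   # 3 query positions, d_k = 8
K = rng.normal(size=(5, 8))   # 5 key positions
V = rng.normal(size=(5, 8))
out, w = scaled_dot_product_attention(Q, K, V)
# out is [3, 8]; each row of w is a probability distribution over the 5 source positions
```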
app.py ADDED
@@ -0,0 +1,800 @@
+ """
+ app.py
+ Gradio Space: Interactive Transformer Visualizer — English → Bengali
+ """
+
+ import gradio as gr
+ import torch
+ import json
+ import os
+ import numpy as np
+ from pathlib import Path
+
+ from transformer import Transformer, CalcLog
+ from training import build_model, run_training, visualize_training_step, collate_batch
+ from inference import visualize_inference
+ from vocab import get_vocabs, PARALLEL_DATA, PAD_IDX
+
+ # ─────────────────────────────────────────────
+ # Global state
+ # ─────────────────────────────────────────────
+ DEVICE = "cpu"
+ src_v, tgt_v = get_vocabs()
+ MODEL: Transformer | None = None  # built lazily on first use
+ LOSS_HISTORY = []
+ IS_TRAINED = False
+
+
+ def get_or_init_model():
+     global MODEL
+     if MODEL is None:
+         MODEL = build_model(len(src_v), len(tgt_v), DEVICE)
+     return MODEL
+
+
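The sinusoidal positional encoding the Space visualizes (`sin(pos/10000^(2i/d))` / `cos(...)`) can be generated as a matrix like this — a standalone NumPy sketch, independent of the transformer.py implementation imported above:

```python
import numpy as np

def positional_encoding(max_len: int, d_model: int) -> np.ndarray:
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)),  PE[pos, 2i+1] = cos(pos / 10000^(2i/d))."""
    pos = np.arange(max_len)[:, None]               # [max_len, 1]
    i = np.arange(0, d_model, 2)[None, :]           # even dimensions 0, 2, 4, ...
    angles = pos / np.power(10000.0, i / d_model)   # [max_len, d_model/2]
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

pe = positional_encoding(10, 64)   # matches the Space's d_model = 64
# row 0 alternates 0, 1, 0, 1, ... since sin(0) = 0 and cos(0) = 1
```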
+ # ─────────────────────────────────────────────
+ # HTML renderer for calc log
+ # ─────────────────────────────────────────────
+
+ def render_matrix_html(val, max_rows=6, max_cols=8):
+     """Convert a nested list / scalar to an HTML matrix table."""
+     if isinstance(val, (int, float)):
+         return f'<span class="scalar-val">{val:.5f}</span>'
+     if isinstance(val, dict):
+         rows = "".join(
+             f'<tr><td class="dict-key">{k}</td><td class="dict-val">{v}</td></tr>'
+             for k, v in val.items()
+         )
+         return f'<table class="dict-table">{rows}</table>'
+     if isinstance(val, list):
+         # 0-D or scalar list
+         if len(val) == 0:
+             return "<em>empty</em>"
+         # 1-D
+         if not isinstance(val[0], list):
+             clipped = val[:max_cols * 2]
+             cells = "".join(
+                 f'<td class="mat-cell">{v:.4f}</td>'
+                 if isinstance(v, float) else f'<td class="mat-cell">{v}</td>'
+                 for v in clipped
+             )
+             suffix = f'<td class="mat-more">…+{len(val)-len(clipped)}</td>' if len(val) > len(clipped) else ""
+             return f'<table class="matrix-1d"><tr>{cells}{suffix}</tr></table>'
+         # 2-D
+         rows_html = ""
+         display_rows = val[:max_rows]
+         for row in display_rows:
+             display_cols = row[:max_cols]
+             cells = "".join(
+                 f'<td class="mat-cell" style="--v:{min(max(float(c), -1), 1):.3f}">'
+                 f'{float(c):.3f}</td>'
+                 if isinstance(c, (int, float)) else f'<td class="mat-cell">{c}</td>'
+                 for c in display_cols
+             )
+             suffix = '<td class="mat-more">…</td>' if len(row) > max_cols else ""
+             rows_html += f"<tr>{cells}{suffix}</tr>"
+         if len(val) > max_rows:
+             rows_html += f'<tr><td colspan="{max_cols+1}" class="mat-more">…{len(val)-max_rows} more rows</td></tr>'
+         return f'<table class="matrix-2d">{rows_html}</table>'
+     return f'<code>{str(val)[:200]}</code>'
+
+
+ def calc_log_to_html(steps):
+     """Turn CalcLog steps into a rich HTML accordion."""
+     if not steps:
+         return "<p style='color:#888'>No calculation log yet.</p>"
+
+     cards = []
+     for i, step in enumerate(steps):
+         name = step.get("name", f"step_{i}")
+         formula = step.get("formula", "")
+         note = step.get("note", "")
+         shape = step.get("shape")
+         val = step.get("value")
+
+         shape_badge = f'<span class="shape-badge">{shape}</span>' if shape else ""
+         formula_html = f'<div class="formula">⟨ {formula} ⟩</div>' if formula else ""
+         note_html = f'<div class="step-note">ℹ {note}</div>' if note else ""
+         matrix_html = render_matrix_html(val) if val is not None else ""
+
+         # Color category by name prefix.
+         # CROSS/MASK are checked before the generic ATTN match so that
+         # cross-attention and mask steps keep their own colors.
+         cat = "default"
+         n = name.upper()
+         if "EMBED" in n or "TOKEN" in n: cat = "embed"
+         elif "PE" in n or "POSITIONAL" in n: cat = "pe"
+         elif "CROSS" in n: cat = "cross"
+         elif "MASK" in n: cat = "mask"
+         elif "SOFTMAX" in n or "ATTN" in n or "_Q" in n or "_K" in n or "_V" in n: cat = "attn"
+         elif "FFN" in n or "LINEAR" in n or "RELU" in n: cat = "ffn"
+         elif "NORM" in n or "RESIDUAL" in n: cat = "norm"
+         elif "LOSS" in n or "GRAD" in n or "OPTIM" in n: cat = "loss"
+         elif "INFERENCE" in n or "GREEDY" in n or "BEAM" in n: cat = "infer"
+
+         cards.append(f"""
+ <div class="calc-card cat-{cat}" data-idx="{i}">
+   <div class="calc-header">
+     <span class="step-num">#{i+1}</span>
+     <span class="step-name cat-label-{cat}">{name.replace('_', ' ')}</span>
+     {shape_badge}
+     <span class="toggle-arrow">▶</span>
+   </div>
+   <div class="calc-body" style="display:none">
+     {formula_html}
+     {note_html}
+     <div class="matrix-wrap">{matrix_html}</div>
+   </div>
+ </div>""")
+
+     return "\n".join(cards)
+
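transformer.py's CalcLog class is imported but not shown in this diff; for reference, a minimal stand-in that produces step dicts with the keys calc_log_to_html reads (name, formula, note, shape, value) might look like the sketch below. The class name `MiniCalcLog` and the `log` signature are assumptions for illustration, not the actual API.

```python
class MiniCalcLog:
    """Hypothetical stand-in: collects step dicts shaped like those the UI renders."""

    def __init__(self):
        self.steps = []

    def log(self, name, value=None, formula="", note="", shape=None):
        # One dict per logged operation, keyed the way calc_log_to_html expects.
        self.steps.append({
            "name": name,
            "value": value,
            "formula": formula,
            "note": note,
            "shape": shape,
        })

log = MiniCalcLog()
log.log("ENC_SCORES", value=[[0.1, 0.9]], formula="Q·Kᵀ/√d_k", shape="[1,2]")
```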
+
+ # ─────────────────────────────────────────────
+ # Attention heatmap HTML
+ # ─────────────────────────────────────────────
+
+ def attention_heatmap_html(weights, row_labels, col_labels, title="Attention"):
+     """weights: 2-D list [tgt, src] of attention weights in [0, 1]."""
+     if not weights:
+         return ""
+     rows_html = ""
+     for i, row in enumerate(weights):
+         cells = ""
+         for j, w in enumerate(row):
+             alpha = min(float(w), 1.0)
+             row_lbl = row_labels[i] if i < len(row_labels) else i
+             col_lbl = col_labels[j] if j < len(col_labels) else j
+             cells += f'<td class="heat-cell" style="--a:{alpha:.3f}" title="{row_lbl}→{col_lbl}: {alpha:.3f}">{alpha:.2f}</td>'
+         lbl = row_labels[i] if i < len(row_labels) else str(i)
+         rows_html += f'<tr><td class="heat-label">{lbl}</td>{cells}</tr>'
+     header = '<tr><td></td>' + "".join(f'<td class="heat-col-label">{c}</td>' for c in col_labels) + '</tr>'
+     return f"""
+ <div class="heatmap-container">
+   <div class="heatmap-title">{title}</div>
+   <table class="heatmap">{header}{rows_html}</table>
+ </div>"""
+
155
+ # ─────────────────────────────────────────────
156
+ # Decoding steps HTML
157
+ # ─────────────────────────────────────────────
158
+
159
+ def decode_steps_html(step_logs, src_tokens):
160
+ if not step_logs:
161
+ return ""
162
+ html = '<div class="decode-steps"><div class="decode-title">🔁 Auto-regressive Decoding Steps</div>'
163
+ for s in step_logs:
164
+ step = s.get("step", 0)
165
+ tokens_so_far = s.get("tokens_so_far", [])
166
+ top5 = s.get("top5", [])
167
+ chosen = s.get("chosen_token", "?")
168
+ prob = s.get("chosen_prob", 0)
169
+
170
+ bars = ""
171
+ if top5:
172
+ max_p = max(t["prob"] for t in top5) or 1
173
+ for t in top5:
174
+ pct = t["prob"] / max_p * 100
175
+ is_chosen = "chosen" if t["token"] == chosen else ""
176
+ bars += f"""<div class="bar-row {is_chosen}">
177
+ <span class="bar-label">{t['token']}</span>
178
+ <div class="bar" style="width:{pct:.1f}%"></div>
179
+ <span class="bar-prob">{t['prob']:.3f}</span>
180
+ </div>"""
181
+
182
+ cross_heat = ""
183
+ if s.get("cross_attn") and src_tokens:
184
+ attn_mat = s["cross_attn"] # [num_heads][T_q][T_src]
185
+ if attn_mat and attn_mat[0]:
186
+ # Take head-0, last decoded position → [T_src] floats
187
+ last_pos_attn = attn_mat[0][-1] # [T_src]
188
+ last_row = [last_pos_attn] # [[T_src]] — 2D for heatmap
189
+ cross_heat = attention_heatmap_html(
190
+ last_row, [chosen], src_tokens,
191
+ title=f"Cross-Attn: '{chosen}' → English"
192
+ )
193
+
194
+ html += f"""
195
+ <div class="decode-step">
196
+ <div class="decode-step-header">
197
+ <span class="step-badge">Step {step+1}</span>
198
+ <span class="step-ctx">Context: {' '.join(tokens_so_far)}</span>
199
+ <span class="step-arrow">→</span>
200
+ <span class="step-chosen">'{chosen}'</span>
201
+ <span class="step-prob">{prob:.3f}</span>
202
+ </div>
203
+ <div class="step-bars">{bars}</div>
204
+ {cross_heat}
205
+ </div>"""
206
+ html += "</div>"
207
+ return html
208
+
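decode_steps_html above consumes step logs with `step` / `tokens_so_far` / `top5` / `chosen_token` / `chosen_prob` keys. A toy greedy decoding loop producing logs of that shape can be sketched as follows — the vocabulary and the `next_logits` scoring function are invented purely for illustration and stand in for the real model in inference.py:

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

# Toy Bengali vocabulary, invented for this sketch.
VOCAB = ["<sos>", "আমি", "ভাত", "খাই", "<eos>"]

def next_logits(context):
    # Hypothetical scoring: simply prefer the token after the last one generated.
    last = VOCAB.index(context[-1])
    return [3.0 if i == (last + 1) % len(VOCAB) else 0.0 for i in range(len(VOCAB))]

def greedy_decode(max_steps=3):
    tokens, logs = ["<sos>"], []
    for step in range(max_steps):
        probs = softmax(next_logits(tokens))
        ranked = sorted(zip(VOCAB, probs), key=lambda t: -t[1])[:5]
        chosen, p = ranked[0]
        logs.append({
            "step": step,
            "tokens_so_far": list(tokens),
            "top5": [{"token": t, "prob": pr} for t, pr in ranked],
            "chosen_token": chosen,
            "chosen_prob": p,
        })
        tokens.append(chosen)
        if chosen == "<eos>":
            break
    return tokens, logs

tokens, logs = greedy_decode()
```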
+
+ # ─────────────────────────────────────────────
+ # Architecture SVG
+ # ─────────────────────────────────────────────
+
+ ARCH_SVG = """
+ <div id="arch-diagram">
+ <svg viewBox="0 0 820 900" xmlns="http://www.w3.org/2000/svg" style="width:100%;max-width:820px;margin:auto;display:block">
+   <defs>
+     <marker id="arr" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
+       <path d="M0,0 L0,6 L8,3 z" fill="#64ffda"/>
+     </marker>
+     <filter id="glow">
+       <feGaussianBlur stdDeviation="2" result="blur"/>
+       <feMerge><feMergeNode in="blur"/><feMergeNode in="SourceGraphic"/></feMerge>
+     </filter>
+   </defs>
+
+   <!-- Background -->
+   <rect width="820" height="900" fill="#0a0f1e" rx="12"/>
+
+   <!-- Title -->
+   <text x="410" y="35" text-anchor="middle" fill="#64ffda" font-size="16" font-family="monospace" font-weight="bold">Transformer Architecture — English → Bengali</text>
+
+   <!-- ── ENCODER (left) ── -->
+   <rect x="40" y="60" width="330" height="720" rx="10" fill="#0d1b2a" stroke="#1e4d6b" stroke-width="1.5"/>
+   <text x="205" y="90" text-anchor="middle" fill="#4fc3f7" font-size="13" font-weight="bold">ENCODER</text>
+
+   <!-- Input Embedding -->
+   <rect x="70" y="110" width="270" height="40" rx="6" fill="#1a3a5c" stroke="#4fc3f7" stroke-width="1.5"/>
+   <text x="205" y="135" text-anchor="middle" fill="#e0f7fa" font-size="11">Input Embedding + Positional Encoding</text>
+
+   <!-- Encoder Layer Box -->
+   <rect x="60" y="175" width="290" height="340" rx="8" fill="#112233" stroke="#1e4d6b" stroke-width="1" stroke-dasharray="4"/>
+   <text x="100" y="198" fill="#607d8b" font-size="10">Encoder Layer × N</text>
+
+   <!-- Multi-Head Self-Attention -->
+   <rect x="80" y="210" width="250" height="50" rx="6" fill="#1b3a4b" stroke="#26c6da" stroke-width="1.5"/>
+   <text x="205" y="232" text-anchor="middle" fill="#e0f7fa" font-size="11" font-weight="bold">Multi-Head Self-Attention</text>
+   <text x="205" y="248" text-anchor="middle" fill="#80deea" font-size="9">Q = K = V = encoder input</text>
+
+   <!-- Add & Norm 1 -->
+   <rect x="80" y="278" width="250" height="30" rx="5" fill="#1a2a3a" stroke="#607d8b" stroke-width="1"/>
+   <text x="205" y="298" text-anchor="middle" fill="#b0bec5" font-size="10">Add &amp; Norm</text>
+
+   <!-- FFN -->
+   <rect x="80" y="328" width="250" height="50" rx="6" fill="#1b3a4b" stroke="#26c6da" stroke-width="1.5"/>
+   <text x="205" y="350" text-anchor="middle" fill="#e0f7fa" font-size="11" font-weight="bold">Feed-Forward Network</text>
+   <text x="205" y="366" text-anchor="middle" fill="#80deea" font-size="9">FFN(x) = max(0, xW₁+b₁)W₂+b₂</text>
+
+   <!-- Add & Norm 2 -->
+   <rect x="80" y="396" width="250" height="30" rx="5" fill="#1a2a3a" stroke="#607d8b" stroke-width="1"/>
+   <text x="205" y="416" text-anchor="middle" fill="#b0bec5" font-size="10">Add &amp; Norm</text>
+
+   <!-- Encoder output arrow down -->
+   <line x1="205" y1="455" x2="205" y2="550" stroke="#64ffda" stroke-width="1.5" marker-end="url(#arr)"/>
+   <text x="215" y="510" fill="#64ffda" font-size="9">K, V to</text>
+   <text x="215" y="522" fill="#64ffda" font-size="9">decoder</text>
+
+   <!-- Encoder output box -->
+   <rect x="70" y="555" width="270" height="40" rx="6" fill="#0d2b1a" stroke="#00e676" stroke-width="1.5"/>
+   <text x="205" y="580" text-anchor="middle" fill="#a5d6a7" font-size="11">Encoder Output (K, V)</text>
+
+   <!-- ── DECODER (right) ── -->
+   <rect x="450" y="60" width="330" height="720" rx="10" fill="#1a0d2a" stroke="#4a1b6b" stroke-width="1.5"/>
+   <text x="615" y="90" text-anchor="middle" fill="#ce93d8" font-size="13" font-weight="bold">DECODER</text>
+
+   <!-- Target Embedding -->
+   <rect x="480" y="110" width="270" height="40" rx="6" fill="#3a1a5c" stroke="#ce93d8" stroke-width="1.5"/>
+   <text x="615" y="135" text-anchor="middle" fill="#f3e5f5" font-size="11">Target Embedding + Positional Encoding</text>
+
+   <!-- Decoder Layer Box -->
+   <rect x="470" y="175" width="290" height="460" rx="8" fill="#1a1133" stroke="#4a1b6b" stroke-width="1" stroke-dasharray="4"/>
+   <text x="510" y="198" fill="#607d8b" font-size="10">Decoder Layer × N</text>
+
+   <!-- Masked MHA -->
+   <rect x="490" y="210" width="250" height="50" rx="6" fill="#2b1b3a" stroke="#ab47bc" stroke-width="1.5"/>
+   <text x="615" y="232" text-anchor="middle" fill="#f3e5f5" font-size="11" font-weight="bold">Masked Multi-Head Self-Attention</text>
+   <text x="615" y="248" text-anchor="middle" fill="#ce93d8" font-size="9">Q = K = V = decoder input (causal mask)</text>
+
+   <!-- Add & Norm D1 -->
+   <rect x="490" y="278" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
+   <text x="615" y="298" text-anchor="middle" fill="#b0bec5" font-size="10">Add &amp; Norm</text>
+
+   <!-- Cross-Attention -->
+   <rect x="490" y="328" width="250" height="60" rx="6" fill="#1b2b4b" stroke="#29b6f6" stroke-width="2" filter="url(#glow)"/>
+   <text x="615" y="350" text-anchor="middle" fill="#e1f5fe" font-size="11" font-weight="bold">Cross-Attention</text>
+   <text x="615" y="366" text-anchor="middle" fill="#81d4fa" font-size="9">Q = decoder | K, V = encoder</text>
+   <text x="615" y="380" text-anchor="middle" fill="#29b6f6" font-size="9" font-weight="bold">← KEY CONNECTION</text>
+
+   <!-- Add & Norm D2 -->
+   <rect x="490" y="408" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
+   <text x="615" y="428" text-anchor="middle" fill="#b0bec5" font-size="10">Add &amp; Norm</text>
+
+   <!-- FFN Decoder -->
+   <rect x="490" y="458" width="250" height="50" rx="6" fill="#2b1b3a" stroke="#ab47bc" stroke-width="1.5"/>
+   <text x="615" y="480" text-anchor="middle" fill="#f3e5f5" font-size="11" font-weight="bold">Feed-Forward Network</text>
+   <text x="615" y="496" text-anchor="middle" fill="#ce93d8" font-size="9">FFN(x) = max(0, xW₁+b₁)W₂+b₂</text>
+
+   <!-- Add & Norm D3 -->
+   <rect x="490" y="526" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
+   <text x="615" y="546" text-anchor="middle" fill="#b0bec5" font-size="10">Add &amp; Norm</text>
+
+   <!-- Output Linear + Softmax -->
+   <rect x="480" y="600" width="270" height="40" rx="6" fill="#2b1b0a" stroke="#ffb300" stroke-width="1.5"/>
+   <text x="615" y="625" text-anchor="middle" fill="#fff8e1" font-size="11">Linear + Softmax → Bengali Token</text>
+
+   <!-- Cross-attention arrow from encoder to decoder -->
+   <path d="M340,590 Q410,480 490,368" stroke="#29b6f6" stroke-width="2" fill="none"
+         stroke-dasharray="6,3" marker-end="url(#arr)"/>
+   <text x="390" y="500" fill="#29b6f6" font-size="9" transform="rotate(-50,390,500)">K, V flow</text>
+
+   <!-- Input arrow -->
+   <line x1="205" y1="840" x2="205" y2="780" stroke="#4fc3f7" stroke-width="1.5" marker-end="url(#arr)"/>
+   <text x="205" y="858" text-anchor="middle" fill="#4fc3f7" font-size="11">English Input</text>
+
+   <line x1="615" y1="840" x2="615" y2="660" stroke="#ce93d8" stroke-width="1.5" marker-end="url(#arr)"/>
+   <text x="615" y="858" text-anchor="middle" fill="#ce93d8" font-size="11">Bengali Output</text>
+
+   <!-- Legend -->
+   <rect x="60" y="870" width="700" height="20" rx="4" fill="#0a1520" stroke="#1e2d3d" stroke-width="1"/>
+   <circle cx="80" cy="880" r="4" fill="#26c6da"/><text x="88" y="884" fill="#80deea" font-size="8">Self-Attention</text>
+   <circle cx="160" cy="880" r="4" fill="#29b6f6"/><text x="168" y="884" fill="#81d4fa" font-size="8">Cross-Attention</text>
+   <circle cx="250" cy="880" r="4" fill="#ab47bc"/><text x="258" y="884" fill="#ce93d8" font-size="8">Masked Attn</text>
+   <circle cx="350" cy="880" r="4" fill="#00e676"/><text x="358" y="884" fill="#a5d6a7" font-size="8">Enc→Dec K,V</text>
+   <circle cx="450" cy="880" r="4" fill="#ffb300"/><text x="458" y="884" fill="#fff8e1" font-size="8">Output Layer</text>
+ </svg>
+ </div>
+ """
+
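The causal mask referenced in the masked self-attention block above lets position t attend only to positions ≤ t. A standalone sketch of how such a mask is built (the real construction lives in transformer.py, not shown here):

```python
import numpy as np

def causal_mask(T: int) -> np.ndarray:
    """Lower-triangular [T, T] boolean mask: True where attention is allowed."""
    return np.tril(np.ones((T, T), dtype=bool))

mask = causal_mask(4)
# Before the softmax, scores are typically set to -inf where the mask is False,
# so those (future) positions receive exactly zero attention weight.
```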
+
+ # ─────────────────────────────────────────────
+ # CSS + JS
+ # ─────────────────────────────────────────────
+
+ CUSTOM_CSS = """
+ /* ── fonts ── */
+ @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;600&family=Syne:wght@400;700;800&display=swap');
+
+ :root {
+     --bg: #07090f;
+     --bg2: #0d1120;
+     --bg3: #111827;
+     --card: #141c2e;
+     --border: #1e2d45;
+     --accent: #64ffda;
+     --accent2: #29b6f6;
+     --accent3: #ce93d8;
+     --accent4: #ffb300;
+     --text: #e2e8f0;
+     --muted: #64748b;
+     --embed: #4fc3f7;
+     --pe: #26c6da;
+     --attn: #f06292;
+     --ffn: #aed581;
+     --norm: #90a4ae;
+     --loss: #ef9a9a;
+     --infer: #80cbc4;
+     --cross: #29b6f6;
+     --mask: #ffb300;
+ }
+
+ body, .gradio-container { background: var(--bg) !important; color: var(--text) !important; font-family: 'JetBrains Mono', monospace !important; }
+
+ h1, h2, h3 { font-family: 'Syne', sans-serif !important; }
+
+ /* ── tabs ── */
+ .tab-nav button { background: var(--bg3) !important; color: var(--muted) !important; border: 1px solid var(--border) !important; font-family: 'JetBrains Mono', monospace !important; letter-spacing: 1px; }
+ .tab-nav button.selected { background: var(--card) !important; color: var(--accent) !important; border-color: var(--accent) !important; box-shadow: 0 0 8px rgba(100,255,218,0.2); }
+
+ /* ── inputs ── */
+ input[type=text], textarea { background: var(--bg3) !important; color: var(--text) !important; border: 1px solid var(--border) !important; border-radius: 6px !important; font-family: 'JetBrains Mono', monospace !important; }
+ input[type=text]:focus, textarea:focus { border-color: var(--accent) !important; box-shadow: 0 0 6px rgba(100,255,218,0.2) !important; }
+
+ button.primary { background: linear-gradient(135deg, #0d3d30, #0d3d4d) !important; color: var(--accent) !important; border: 1px solid var(--accent) !important; font-family: 'JetBrains Mono', monospace !important; font-weight: 600 !important; letter-spacing: 1px; transition: all 0.2s; }
+ button.primary:hover { background: linear-gradient(135deg, #1a5c4a, #1a5c6d) !important; box-shadow: 0 0 12px rgba(100,255,218,0.3) !important; }
+
+ /* ── calc cards ── */
+ .calc-card { border-radius: 8px; margin: 4px 0; border: 1px solid var(--border); background: var(--card); overflow: hidden; }
+ .calc-header { display: flex; align-items: center; gap: 8px; padding: 8px 12px; cursor: pointer; user-select: none; transition: background 0.15s; }
+ .calc-header:hover { background: rgba(255,255,255,0.03); }
+ .calc-body { padding: 10px 14px; background: var(--bg2); border-top: 1px solid var(--border); }
+ .step-num { color: var(--muted); font-size: 11px; min-width: 28px; }
+ .step-name { font-weight: 600; font-size: 12px; flex: 1; }
+ .toggle-arrow { color: var(--muted); font-size: 10px; transition: transform 0.2s; }
+ .toggle-arrow.open { transform: rotate(90deg); }
+ .shape-badge { background: var(--bg3); color: var(--muted); font-size: 10px; padding: 1px 6px; border-radius: 4px; border: 1px solid var(--border); }
+ .formula { color: var(--accent); font-size: 11px; font-style: italic; margin-bottom: 4px; background: rgba(100,255,218,0.05); padding: 4px 8px; border-radius: 4px; border-left: 2px solid var(--accent); }
+ .step-note { color: var(--muted); font-size: 11px; margin-bottom: 6px; }
+
+ /* category colors */
+ .cat-label-embed { color: var(--embed); }
+ .cat-label-pe { color: var(--pe); }
+ .cat-label-attn { color: var(--attn); }
+ .cat-label-ffn { color: var(--ffn); }
+ .cat-label-norm { color: var(--norm); }
+ .cat-label-loss { color: var(--loss); }
+ .cat-label-infer { color: var(--infer); }
+ .cat-label-cross { color: var(--cross); }
+ .cat-label-mask { color: var(--mask); }
+ .cat-label-default{ color: var(--text); }
+
+ .cat-embed { border-left: 3px solid var(--embed); }
+ .cat-pe { border-left: 3px solid var(--pe); }
+ .cat-attn { border-left: 3px solid var(--attn); }
+ .cat-ffn { border-left: 3px solid var(--ffn); }
+ .cat-norm { border-left: 3px solid var(--norm); }
+ .cat-loss { border-left: 3px solid var(--loss); }
+ .cat-infer { border-left: 3px solid var(--infer); }
+ .cat-cross { border-left: 3px solid var(--cross); }
+ .cat-mask { border-left: 3px solid var(--mask); }
+ .cat-default{ border-left: 3px solid var(--border); }
+
+ /* ── matrix tables ── */
+ .matrix-wrap { overflow-x: auto; }
+ .matrix-2d, .matrix-1d { border-collapse: collapse; font-size: 10px; font-family: 'JetBrains Mono', monospace; }
+ .mat-cell {
+     padding: 2px 5px; text-align: right; min-width: 48px;
+     background: color-mix(in srgb, #29b6f6 calc((var(--v,0) + 1) * 30%), #0d1120 calc(100% - (var(--v,0) + 1) * 30%));
+     color: #e2e8f0; border: 1px solid rgba(255,255,255,0.05);
+ }
+ .mat-more { color: var(--muted); font-style: italic; font-size: 9px; padding: 2px 6px; }
+ .dict-table { font-size: 11px; width: 100%; }
+ .dict-key { color: var(--accent); padding: 2px 8px 2px 0; }
+ .dict-val { color: var(--text); padding: 2px; }
+ .scalar-val { color: var(--accent4); font-size: 13px; font-weight: 600; }
+
+ /* ── heatmap ── */
+ .heatmap-container { margin: 8px 0; }
+ .heatmap-title { color: var(--accent2); font-size: 11px; margin-bottom: 4px; font-weight: 600; }
+ .heatmap { border-collapse: collapse; font-size: 10px; }
+ .heat-cell {
+     width: 36px; height: 24px; text-align: center;
+     background: rgba(41, 182, 246, calc(var(--a, 0)));
+     border: 1px solid rgba(255,255,255,0.04);
+     color: color-mix(in srgb, #fff calc(var(--a,0)*100%), #4a5568 calc(100% - var(--a,0)*100%));
+     font-size: 9px; cursor: default;
+ }
+ .heat-cell:hover { outline: 1px solid var(--accent); }
+ .heat-label { color: var(--accent3); font-size: 10px; padding-right: 6px; white-space: nowrap; }
+ .heat-col-label { color: var(--embed); font-size: 9px; text-align: center; padding-bottom: 2px; }
+
+ /* ── decode steps ── */
+ .decode-steps { margin-top: 12px; }
+ .decode-title { color: var(--accent); font-size: 13px; font-weight: 700; margin-bottom: 10px; padding-bottom: 4px; border-bottom: 1px solid var(--border); }
+ .decode-step { border: 1px solid var(--border); border-radius: 8px; margin: 6px 0; padding: 10px; background: var(--card); }
+ .decode-step-header { display: flex; align-items: center; gap: 8px; flex-wrap: wrap; margin-bottom: 8px; }
+ .step-badge { background: var(--accent); color: var(--bg); font-size: 10px; font-weight: 700; padding: 2px 8px; border-radius: 20px; }
+ .step-ctx { color: var(--muted); font-size: 11px; }
+ .step-arrow { color: var(--accent4); }
+ .step-chosen { color: var(--accent3); font-size: 13px; font-weight: 700; }
+ .step-prob { color: var(--accent4); font-size: 11px; }
+ .step-bars { margin: 4px 0; }
+ .bar-row { display: flex; align-items: center; gap: 6px; margin: 2px 0; }
+ .bar-row.chosen .bar-label { color: var(--accent3); font-weight: 700; }
+ .bar-row.chosen .bar { background: var(--accent3) !important; }
+ .bar-label { width: 100px; text-align: right; font-size: 11px; color: var(--text); white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
+ .bar { height: 14px; background: var(--accent2); border-radius: 2px; transition: width 0.4s; min-width: 2px; }
+ .bar-prob { font-size: 10px; color: var(--muted); }
+
+ /* ── loss chart ── */
+ #loss-chart-container { background: var(--bg2); border: 1px solid var(--border); border-radius: 8px; padding: 12px; margin-top: 8px; }
+
+ /* ── arch diagram ── */
+ #arch-diagram { background: var(--bg2); border: 1px solid var(--border); border-radius: 10px; padding: 12px; margin: 8px 0; }
+
+ /* ── result banner ── */
+ .result-banner { background: linear-gradient(135deg, #0d3d30, #1a1a3d); border: 1px solid var(--accent); border-radius: 10px; padding: 16px 20px; margin: 10px 0; }
+ .result-en { color: var(--embed); font-size: 14px; margin-bottom: 4px; }
+ .result-bn { color: var(--accent3); font-size: 22px; font-weight: 700; letter-spacing: 1px; }
+ .result-label { color: var(--muted); font-size: 10px; text-transform: uppercase; letter-spacing: 1px; }
+
+ /* ── misc ── */
+ .gradio-html { background: transparent !important; }
+ .panel { background: var(--card) !important; border: 1px solid var(--border) !important; border-radius: 10px !important; }
+ .log-container { max-height: 600px; overflow-y: auto; padding: 8px; scrollbar-width: thin; scrollbar-color: var(--border) transparent; }
+ """
+
+ CUSTOM_JS = """
+ // Card toggle
+ window._toggleCard = function(header) {
+     const body = header.nextElementSibling;
+     const arrow = header.querySelector('.toggle-arrow');
+     if (!body) return;
+     const open = body.style.display === 'block';
+     body.style.display = open ? 'none' : 'block';
+     if (arrow) arrow.classList.toggle('open', !open);
+ };
+ window._expandAll = function() {
+     document.querySelectorAll('.calc-body').forEach(b => b.style.display = 'block');
+     document.querySelectorAll('.toggle-arrow').forEach(a => a.classList.add('open'));
+ };
+ window._collapseAll = function() {
+     document.querySelectorAll('.calc-body').forEach(b => b.style.display = 'none');
+     document.querySelectorAll('.toggle-arrow').forEach(a => a.classList.remove('open'));
+ };
+ window._filterCards = function(cat) {
+     document.querySelectorAll('.calc-card').forEach(c => {
+         c.style.display = (!cat || c.classList.contains('cat-' + cat)) ? '' : 'none';
+     });
+ };
+ // Event delegation — works even if Gradio strips onclick attrs
+ document.addEventListener('click', function(e) {
+     const header = e.target.closest('.calc-header');
+     if (header) { window._toggleCard(header); return; }
+     const btn = e.target.closest('[data-ga]');
+     if (btn) {
+         const a = btn.dataset.ga;
+         if (a === 'expand') window._expandAll();
+         else if (a === 'collapse') window._collapseAll();
+         else if (a.startsWith('filter:')) window._filterCards(a.slice(7));
+     }
+ }, true);
+ """
+
524
+ # ─────────────────────────────────────────────
525
+ # Pure-SVG loss curve (no JS/canvas needed)
526
+ # ─────────────────────────────────────────────
527
+
528
+ def _loss_svg(losses):
529
+ if not losses:
530
+ return ""
531
+ W, H = 580, 200
532
+ pl, pr, pt, pb = 52, 16, 16, 36
533
+ pw, ph = W - pl - pr, H - pt - pb
534
+ mn, mx = min(losses), max(losses)
535
+ rng = mx - mn or 1
536
+ n = len(losses)
537
+
538
+ def px(i): return pl + (i / max(n - 1, 1)) * pw
539
+ def py(v): return pt + ph - ((v - mn) / rng) * ph
540
+
541
+ # Grid + Y labels
542
+ grid = ""
543
+ for k in range(5):
544
+ v = mn + (k / 4) * rng
545
+ y = py(v)
546
+ grid += f'<line x1="{pl}" y1="{y:.1f}" x2="{pl+pw}" y2="{y:.1f}" stroke="#1e2d45" stroke-width="0.5"/>'
547
+ grid += f'<text x="{pl-4}" y="{y+4:.1f}" text-anchor="end" fill="#64748b" font-size="9" font-family="monospace">{v:.3f}</text>'
548
+
549
+ # Polyline points
550
+ pts = " ".join(f"{px(i):.1f},{py(v):.1f}" for i, v in enumerate(losses))
551
+ fill_pts = f"{pl:.1f},{pt+ph:.1f} {pts} {pl+pw:.1f},{pt+ph:.1f}"
552
+
553
+ # X labels
554
+ xlabels = ""
555
+ for idx in ([0, n//4, n//2, 3*n//4, n-1] if n > 4 else range(n)):
556
+ xlabels += f'<text x="{px(idx):.1f}" y="{H-4}" text-anchor="middle" fill="#64748b" font-size="9" font-family="monospace">E{idx+1}</text>'
557
+
558
+ return f"""
559
+ <div style="background:#0d1120;border:1px solid #1e2d45;border-radius:8px;padding:12px;margin-top:8px">
560
+ <div style="color:#64ffda;font-size:13px;font-weight:700;margin-bottom:8px">📉 Training Loss Curve</div>
561
+ <svg width="{W}" height="{H}" style="display:block;max-width:100%">
562
+ <defs>
563
+ <linearGradient id="lcg" x1="0" y1="0" x2="{W}" y2="0" gradientUnits="userSpaceOnUse">
564
+ <stop offset="0%" stop-color="#64ffda"/><stop offset="100%" stop-color="#29b6f6"/>
565
+ </linearGradient>
566
+ </defs>
567
+ <rect width="{W}" height="{H}" fill="#0d1120"/>
568
+ {grid}
569
+ <polygon points="{fill_pts}" fill="rgba(100,255,218,0.08)"/>
570
+ <polyline points="{pts}" fill="none" stroke="url(#lcg)" stroke-width="2.5" stroke-linejoin="round"/>
571
+ {xlabels}
572
+ </svg>
573
+ </div>"""
574
+
+
+# ─────────────────────────────────────────────
+# Gradio callbacks
+# ─────────────────────────────────────────────
+
+def do_train(epochs_str, progress=gr.Progress()):
+    global MODEL, LOSS_HISTORY, IS_TRAINED
+    try:
+        epochs = int(epochs_str)
+    except (TypeError, ValueError):
+        epochs = 30
+
+    losses = []
+    def cb(ep, total, loss):
+        losses.append(loss)
+        progress((ep/total), desc=f"Epoch {ep}/{total} — loss {loss:.4f}")
+
+    MODEL, LOSS_HISTORY = run_training(epochs=epochs, device=DEVICE, progress_cb=cb)
+    IS_TRAINED = True
+
+    chart_html = _loss_svg(LOSS_HISTORY)
+    return (
+        f"✅ Trained {epochs} epochs. Final loss: {LOSS_HISTORY[-1]:.4f}",
+        chart_html
+    )
+
+
+def do_training_viz(en_sentence, bn_sentence):
+    model = get_or_init_model()
+    if not en_sentence.strip():
+        return "<p style='color:red'>Please enter an English sentence.</p>", "", ""
+    if not bn_sentence.strip():
+        bn_sentence = "আমি তোমাকে ভালোবাসি"
+
+    result = visualize_training_step(model, en_sentence.strip(), bn_sentence.strip(), DEVICE)
+
+    # Attention heatmap (cross-attn layer 0, head 0)
+    meta = result.get("meta", {})
+    attn_html = ""
+    src_tokens = result.get("src_tokens", [])
+    tgt_tokens = result.get("tgt_tokens", [])
+
+    result_banner = f"""
+    <div class="result-banner">
+      <div class="result-label">English Input</div>
+      <div class="result-en">"{en_sentence}"</div>
+      <div class="result-label" style="margin-top:8px">Bengali (Teacher-forced)</div>
+      <div class="result-bn">{bn_sentence}</div>
+      <div style="margin-top:8px;color:var(--loss);font-size:13px">
+        📉 Loss: <strong>{result['loss']:.4f}</strong>
+      </div>
+    </div>"""
+
+    calc_html = f"""
+    <div style="margin-bottom:8px;display:flex;gap:6px;flex-wrap:wrap">
+      <button data-ga="expand" style="background:var(--card);color:var(--accent);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Expand All</button>
+      <button data-ga="collapse" style="background:var(--card);color:var(--muted);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Collapse All</button>
+      {"".join(f'<button data-ga="filter:{cat}" style="background:var(--card);color:var(--cat-{cat},var(--text));border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:10px">{cat}</button>' for cat in ['embed','pe','attn','ffn','norm','loss','cross','mask'])}
+      <button data-ga="filter:" style="background:var(--card);color:var(--muted);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:10px">show all</button>
+    </div>
+    <div class="log-container">
+      {calc_log_to_html(result.get('calc_log', []))}
+    </div>"""
+
+    return result_banner, calc_html, attn_html
+
+
+def do_inference_viz(en_sentence, decode_method):
+    model = get_or_init_model()
+    if not en_sentence.strip():
+        return "<p style='color:red'>Please enter an English sentence.</p>", "", ""
+
+    result = visualize_inference(model, en_sentence.strip(), DEVICE, decode_method)
+
+    result_banner = f"""
+    <div class="result-banner">
+      <div class="result-label">English Input</div>
+      <div class="result-en">"{en_sentence}"</div>
+      <div class="result-label" style="margin-top:8px">Bengali Translation ({decode_method})</div>
+      <div class="result-bn">{result['translation'] or '(no output)'}</div>
+      <div style="margin-top:6px;color:var(--muted);font-size:11px">
+        Tokens: {' → '.join(result['output_tokens'])}
+      </div>
+    </div>"""
+
+    decode_html = decode_steps_html(result.get("step_logs", []), result.get("src_tokens", []))
+
+    calc_html = f"""
+    <div style="margin-bottom:8px;display:flex;gap:6px;flex-wrap:wrap">
+      <button data-ga="expand" style="background:var(--card);color:var(--accent);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Expand All</button>
+      <button data-ga="collapse" style="background:var(--card);color:var(--muted);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Collapse All</button>
+    </div>
+    <div class="log-container">
+      {calc_log_to_html(result.get('calc_log', []))}
+    </div>"""
+
+    return result_banner + decode_html, calc_html, ""
+
+
+# ─────────────────────────────────────────────
+# Build UI
+# ─────────────────────────────────────────────
+
+def build_ui():
+    _theme = gr.themes.Base(primary_hue="teal", secondary_hue="purple", neutral_hue="slate")
+    with gr.Blocks(
+        title="Transformer Visualizer — EN→BN",
+        css=CUSTOM_CSS,
+        js=f"() => {{ {CUSTOM_JS} }}",
+        theme=_theme,
+    ) as demo:
+
+        gr.HTML("""
+        <div style="text-align:center;padding:24px 0 12px;border-bottom:1px solid #1e2d45;margin-bottom:16px">
+          <div style="font-family:'Syne',sans-serif;font-size:28px;font-weight:800;
+                      background:linear-gradient(135deg,#64ffda,#29b6f6,#ce93d8);
+                      -webkit-background-clip:text;-webkit-text-fill-color:transparent;letter-spacing:2px">
+            TRANSFORMER VISUALIZER
+          </div>
+          <div style="color:#64748b;font-size:12px;letter-spacing:3px;margin-top:4px;font-family:'JetBrains Mono',monospace">
+            ENGLISH → BENGALI · EVERY CALCULATION EXPOSED
+          </div>
+        </div>
+        """)
+
+        with gr.Tabs():
+
+            # ── TAB 0: Architecture ──────────────────
+            with gr.Tab("🏗️ Architecture"):
+                gr.HTML(ARCH_SVG)
+                gr.HTML("""
+                <div style="display:grid;grid-template-columns:1fr 1fr;gap:12px;margin-top:12px">
+                  <div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:14px">
+                    <div style="color:#4fc3f7;font-weight:700;margin-bottom:8px">📌 Encoder Flow</div>
+                    <div style="color:#94a3b8;font-size:12px;line-height:1.8">
+                      1. English tokens → Embedding (d_model=64)<br>
+                      2. + Positional Encoding (sin/cos)<br>
+                      3. Multi-Head Self-Attention (4 heads)<br>
+                      4. Add &amp; LayerNorm<br>
+                      5. Feed-Forward (64→128→64)<br>
+                      6. Add &amp; LayerNorm<br>
+                      7. Repeat × 2 layers<br>
+                      8. Output K, V for decoder
+                    </div>
+                  </div>
+                  <div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:14px">
+                    <div style="color:#ce93d8;font-weight:700;margin-bottom:8px">📌 Decoder Flow</div>
+                    <div style="color:#94a3b8;font-size:12px;line-height:1.8">
+                      1. Bengali tokens → Embedding<br>
+                      2. + Positional Encoding<br>
+                      3. Masked MHA (future tokens blocked)<br>
+                      4. Add &amp; LayerNorm<br>
+                      5. Cross-Attention: Q←decoder, K,V←encoder<br>
+                      6. Add &amp; LayerNorm<br>
+                      7. Feed-Forward<br>
+                      8. Linear → Softmax → Bengali token
+                    </div>
+                  </div>
+                </div>
+                """)
+
+
+            # ── TAB 1: Train ─────────────────────────
+            with gr.Tab("🏋️ Train Model"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        gr.HTML('<div style="color:#64ffda;font-size:13px;font-weight:700;margin-bottom:8px">Quick Train</div>')
+                        epochs_in = gr.Textbox(value="50", label="Epochs", max_lines=1)
+                        train_btn = gr.Button("▶ Train on 30 parallel sentences", variant="primary")
+                        train_status = gr.HTML()
+                    with gr.Column(scale=2):
+                        loss_chart = gr.HTML()
+                train_btn.click(do_train, inputs=[epochs_in], outputs=[train_status, loss_chart])
+
+            # ── TAB 2: Training Step Viz ──────────────
+            with gr.Tab("🔬 Training Step"):
+                gr.HTML('<div style="color:#ef9a9a;font-size:12px;margin-bottom:12px">📚 Shows <strong>teacher-forcing</strong>: ground-truth Bengali tokens are fed to decoder, loss + gradients computed.</div>')
+                with gr.Row():
+                    en_in_t = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you")
+                    bn_in_t = gr.Textbox(label="Bengali (ground truth)", placeholder="আমি তোমাকে ভালোবাসি", value="আমি তোমাকে ভালোবাসি")
+                run_train_viz = gr.Button("🔬 Run Training Step & Show All Calculations", variant="primary")
+                result_html_t = gr.HTML()
+                with gr.Row():
+                    with gr.Column(scale=2):
+                        calc_html_t = gr.HTML()
+                    with gr.Column(scale=1):
+                        attn_html_t = gr.HTML()
+                run_train_viz.click(do_training_viz,
+                                    inputs=[en_in_t, bn_in_t],
+                                    outputs=[result_html_t, calc_html_t, attn_html_t])
+
+            # ── TAB 3: Inference Viz ──────────────────
+            with gr.Tab("⚡ Inference"):
+                gr.HTML('<div style="color:#80cbc4;font-size:12px;margin-bottom:12px">🤖 Shows <strong>auto-regressive decoding</strong>: model generates Bengali token by token, no ground truth needed.</div>')
+                with gr.Row():
+                    en_in_i = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you")
+                    decode_radio = gr.Radio(["greedy", "beam"], value="greedy", label="Decode Method")
+                run_infer = gr.Button("⚡ Translate & Show All Calculations", variant="primary")
+                result_html_i = gr.HTML()
+                with gr.Row():
+                    with gr.Column(scale=2):
+                        calc_html_i = gr.HTML()
+                    with gr.Column(scale=1):
+                        attn_html_i = gr.HTML()
+                run_infer.click(do_inference_viz,
+                                inputs=[en_in_i, decode_radio],
+                                outputs=[result_html_i, calc_html_i, attn_html_i])
+
+            # ── TAB 4: Examples ──────────────────────
+            with gr.Tab("📖 Examples"):
+                gr.HTML("""
+                <div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:16px">
+                  <div style="color:#64ffda;font-weight:700;margin-bottom:12px">Try these sentences:</div>
+                  <div style="display:grid;grid-template-columns:1fr 1fr;gap:8px">
+                """ + "".join(
+                    f'<div style="background:#0d1120;border:1px solid #1e2d45;border-radius:6px;padding:8px">'
+                    f'<div style="color:#4fc3f7;font-size:12px">{en}</div>'
+                    f'<div style="color:#ce93d8;font-size:13px;font-weight:600">{bn}</div>'
+                    f'</div>'
+                    for en, bn in PARALLEL_DATA[:12]
+                ) + "</div></div>")
+
+    return demo
+
+
+demo = build_ui()
+demo.launch(server_name="0.0.0.0")
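The pixel mapping inside `_loss_svg` can be checked in isolation. A minimal sketch follows, with the same constants and helper names as the function; the sample `losses` values are made up. The x axis spreads epochs across the plot width, and the y axis is inverted because SVG's origin sits at the top-left corner:

```python
# Map loss values to SVG pixel coordinates, as _loss_svg does above.
W, H = 580, 200
pl, pr, pt, pb = 52, 16, 16, 36
pw, ph = W - pl - pr, H - pt - pb          # 512 x 148 drawable plot area

losses = [2.0, 1.5, 1.0, 0.5]              # hypothetical loss history
mn, mx = min(losses), max(losses)
rng = mx - mn or 1
n = len(losses)

def px(i):
    # epoch index i -> x pixel, spread linearly across the plot width
    return pl + (i / max(n - 1, 1)) * pw

def py(v):
    # loss value v -> y pixel; subtracted from pt + ph because SVG y grows downward
    return pt + ph - ((v - mn) / rng) * ph

print(px(0), px(n - 1))   # first point at the left edge, last at the right
print(py(mx), py(mn))     # max loss at the top, min loss at the bottom
```

So the first epoch always lands on the left border of the plot area and the largest loss on its top border, which is why the curve fills the frame regardless of how many epochs were run.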
inference.py ADDED
@@ -0,0 +1,250 @@
+ """
2
+ inference.py
3
+ Inference (translation) for English→Bengali with full calculation logging.
4
+ Supports greedy decoding and beam search, showing every step.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ import math
11
+ from typing import Dict, List, Tuple, Optional
12
+
13
+ from transformer import Transformer, CalcLog
14
+ from vocab import get_vocabs, PAD_IDX, BOS_IDX, EOS_IDX
15
+
16
+
17
+ # ─────────────────────────────────────────────
18
+ # Greedy decoding with full logging
19
+ # ─────────────────────────────────────────────
20
+
21
+ def greedy_decode(
22
+ model: Transformer,
23
+ src: torch.Tensor,
24
+ max_len: int = 20,
25
+ device: str = "cpu",
26
+ log: Optional[CalcLog] = None,
27
+ ) -> Tuple[List[int], List[Dict]]:
28
+ model.eval()
29
+ src_v, tgt_v = get_vocabs()
30
+
31
+ with torch.no_grad():
32
+ src_mask = model.make_src_mask(src)
33
+
34
+ # ── Encode once ──────────────────────
35
+ src_emb = model.src_embed(src) * math.sqrt(model.d_model)
36
+ enc_x = model.src_pe(src_emb, log=log)
37
+
38
+ enc_attn_weights = []
39
+ for i, layer in enumerate(model.encoder_layers):
40
+ enc_x, ew = layer(enc_x, src_mask=src_mask,
41
+ log=log if i == 0 else None, layer_idx=i)
42
+ enc_attn_weights.append(ew.cpu().numpy())
43
+
44
+ if log:
45
+ log.log("INFERENCE_ENCODER_done", enc_x[0, :, :8],
46
+ note="Encoder finished. Output K,V will be reused for every decoder step.")
47
+
48
+ # ── Auto-regressive decode ────────────
49
+ generated = [BOS_IDX]
50
+ step_logs = []
51
+
+        for step in range(max_len):
+            tgt_so_far = torch.tensor([generated], dtype=torch.long, device=device)
+            tgt_mask = model.make_tgt_mask(tgt_so_far)
+
+            tgt_emb = model.tgt_embed(tgt_so_far) * math.sqrt(model.d_model)
+            dec_x = model.tgt_pe(tgt_emb)
+
+            step_dec_cross = []
+            for i, layer in enumerate(model.decoder_layers):
+                do_log = (log is not None) and (step < 3) and (i == 0)
+                if do_log:
+                    log.log(f"INFERENCE_step{step}_dec_input", dec_x[0, :, :8],
+                            note=f"Decoder input at step {step}: tokens so far = "
+                                 f"{tgt_v.tokens(generated)}")
+                dec_x, mw, cw = layer(
+                    dec_x, enc_x,
+                    tgt_mask=tgt_mask, src_mask=src_mask,
+                    log=log if do_log else None,
+                    layer_idx=i,
+                )
+                step_dec_cross.append(cw.cpu().numpy())
+
+            # Only look at last position
+            last_logits = model.output_linear(dec_x[:, -1, :])  # (1, V)
+            probs = F.softmax(last_logits, dim=-1)[0]
+
+            # Top-5 predictions
+            top5_probs, top5_ids = probs.topk(5)
+            top5 = [
+                {"token": tgt_v.idx2token.get(idx.item(), "?"),
+                 "id": idx.item(),
+                 "prob": round(prob.item(), 4)}
+                for prob, idx in zip(top5_probs, top5_ids)
+            ]
+
+            # Greedy: pick highest
+            next_token = top5_ids[0].item()
+
+            step_info = {
+                "step": step,
+                "tokens_so_far": tgt_v.tokens(generated),
+                "top5": top5,
+                "chosen_token": tgt_v.idx2token.get(next_token, "?"),
+                "chosen_id": next_token,
+                "chosen_prob": round(top5_probs[0].item(), 4),
+                "cross_attn": step_dec_cross[0][0].tolist()
+                              if step_dec_cross else None,
+            }
+            step_logs.append(step_info)
+
+            if log and step < 3:
+                log.log(f"INFERENCE_step{step}_top5", top5,
+                        formula="P(next_token) = softmax(W_out · dec_out[-1])",
+                        note=f"Step {step}: top-5 candidates. Chosen: {step_info['chosen_token']} ({step_info['chosen_prob']:.4f})")
+
+            generated.append(next_token)
+            if next_token == EOS_IDX:
+                break
+
+    return generated, step_logs
+
+
+# ─────────────────────────────────────────────
+# Beam search
+# ─────────────────────────────────────────────
+
+def beam_search(
+    model: Transformer,
+    src: torch.Tensor,
+    beam_size: int = 3,
+    max_len: int = 20,
+    device: str = "cpu",
+    log: Optional[CalcLog] = None,
+) -> Tuple[List[int], List[Dict]]:
+    model.eval()
+    src_v, tgt_v = get_vocabs()
+
+    with torch.no_grad():
+        src_mask = model.make_src_mask(src)
+
+        # Encode
+        src_emb = model.src_embed(src) * math.sqrt(model.d_model)
+        enc_x = model.src_pe(src_emb)
+        for i, layer in enumerate(model.encoder_layers):
+            enc_x, _ = layer(enc_x, src_mask=src_mask)
+
+        # Beams: list of (score, token_ids)
+        beams = [(0.0, [BOS_IDX])]
+        completed = []
+        beam_step_logs = []
+
+        for step in range(max_len):
+            if not beams:
+                break
+            candidates = []
+
+            for beam_idx, (score, tokens) in enumerate(beams):
+                tgt_t = torch.tensor([tokens], dtype=torch.long, device=device)
+                tgt_mask = model.make_tgt_mask(tgt_t)
+                tgt_emb = model.tgt_embed(tgt_t) * math.sqrt(model.d_model)
+                dec_x = model.tgt_pe(tgt_emb)
+                for i, layer in enumerate(model.decoder_layers):
+                    dec_x, _, _ = layer(dec_x, enc_x,
+                                        tgt_mask=tgt_mask, src_mask=src_mask)
+                last_logits = model.output_linear(dec_x[:, -1, :])
+                log_probs = F.log_softmax(last_logits, dim=-1)[0]
+                top_lp, top_id = log_probs.topk(beam_size)
+                for lp, tid in zip(top_lp, top_id):
+                    new_score = score + lp.item()
+                    new_tokens = tokens + [tid.item()]
+                    candidates.append((new_score, new_tokens))
+
+            # Keep top beam_size
+            candidates.sort(key=lambda x: x[0], reverse=True)
+            beams = []
+            step_info = {"step": step, "beams": []}
+            for sc, toks in candidates[:beam_size * 2]:
+                if toks[-1] == EOS_IDX:
+                    completed.append((sc / len(toks), toks))
+                else:
+                    beams.append((sc, toks))
+                step_info["beams"].append({
+                    "score": round(sc, 4),
+                    "tokens": tgt_v.tokens(toks),
+                    "text": tgt_v.decode(toks),
+                })
+                if len(beams) == beam_size:
+                    break
+            beam_step_logs.append(step_info)
+
+            if len(completed) >= beam_size:
+                break
+
+    if completed:
+        best = max(completed, key=lambda x: x[0])
+        return best[1], beam_step_logs
+    elif beams:
+        return beams[0][1] + [EOS_IDX], beam_step_logs
+    else:
+        return [BOS_IDX, EOS_IDX], beam_step_logs
+
+
+# ─────────────────────────────────────────────
+# Full inference pipeline with visualization
+# ─────────────────────────────────────────────
+
+def visualize_inference(
+    model: Transformer,
+    en_sentence: str,
+    device: str = "cpu",
+    decode_method: str = "greedy",
+) -> Dict:
+    src_v, tgt_v = get_vocabs()
+    log = CalcLog()
+
+    src_ids = src_v.encode(en_sentence)
+    log.log("INFERENCE_TOKENIZATION", {
+        "sentence": en_sentence,
+        "tokens": en_sentence.lower().split(),
+        "ids": src_ids,
+    }, formula="word → vocab_id lookup",
+       note="No ground-truth Bengali needed — model generates from scratch")
+
+    src = torch.tensor([src_ids], dtype=torch.long, device=device)
+
+    if decode_method == "beam":
+        output_ids, step_logs = beam_search(model, src, beam_size=3,
+                                            device=device, log=log)
+        log.log("BEAM_SEARCH_complete", {
+            "method": "beam search (beam=3)",
+            "note": "Explores multiple hypotheses simultaneously — generally better quality"
+        })
+    else:
+        output_ids, step_logs = greedy_decode(model, src, device=device, log=log)
+        log.log("GREEDY_complete", {
+            "method": "greedy decoding",
+            "note": "Always picks highest probability token — fast but can miss optimal sequences"
+        })
+
+    translation = tgt_v.decode(output_ids)
+    output_tokens = tgt_v.tokens(output_ids)
+
+    log.log("FINAL_TRANSLATION", {
+        "input": en_sentence,
+        "output_ids": output_ids,
+        "output_tokens": output_tokens,
+        "translation": translation,
+    }, note="Complete English→Bengali translation")
+
+    return {
+        "en_sentence": en_sentence,
+        "translation": translation,
+        "output_tokens": output_tokens,
+        "output_ids": output_ids,
+        "src_tokens": src_v.tokens(src_ids),
+        "step_logs": step_logs,
+        "calc_log": log.to_dict(),
+        "decode_method": decode_method,
+    }
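The difference between the two decoders in this file comes down to scoring: greedy takes the argmax of one softmax at a time, while beam search ranks whole hypotheses by summed log-probabilities. A minimal dependency-free sketch (the logits and continuation probabilities are made-up toy values, not model output):

```python
import math

# Toy logits over a 4-token vocabulary, standing in for
# model.output_linear(dec_x[:, -1, :]) in the code above.
logits = [2.0, 1.0, 0.5, -1.0]

# softmax: P(i) = exp(z_i) / sum_j exp(z_j)
exps = [math.exp(z) for z in logits]
total = sum(exps)
probs = [e / total for e in exps]

# Greedy decoding commits to the single highest-probability token...
greedy_choice = max(range(len(probs)), key=probs.__getitem__)

# ...while beam search compares cumulative log-probs of whole sequences,
# so a locally weaker first token can win on a stronger continuation.
hyp_a = math.log(probs[0]) + math.log(0.10)   # best first token, weak continuation
hyp_b = math.log(probs[1]) + math.log(0.60)   # weaker first token, strong continuation

print(greedy_choice)   # index 0, the argmax token
print(hyp_b > hyp_a)   # True: beam search would prefer hypothesis B
```

This is also why `beam_search` length-normalizes completed hypotheses (`sc / len(toks)`): raw summed log-probs systematically favor shorter sequences.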
requirements.txt ADDED
@@ -0,0 +1,2 @@
+torch>=2.0.0
+numpy>=1.24.0
training.py ADDED
@@ -0,0 +1,261 @@
+ """
2
+ training.py
3
+ Training loop for English→Bengali transformer with full calculation capture.
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ import numpy as np
10
+ import math
11
+ from typing import Dict, List, Tuple, Optional
12
+
13
+ from transformer import Transformer, CalcLog
14
+ from vocab import get_vocabs, PARALLEL_DATA, PAD_IDX, BOS_IDX, EOS_IDX
15
+
16
+
17
+ # ─────────────────────────────────────────────
18
+ # Data helpers
19
+ # ─────────────────────────────────────────────
20
+
21
+ def collate_batch(pairs: List[Tuple[str, str]], src_v, tgt_v, device: str = "cpu"):
22
+ src_seqs, tgt_seqs = [], []
23
+ for en, bn in pairs:
24
+ src_seqs.append(src_v.encode(en))
25
+ tgt_seqs.append(tgt_v.encode(bn))
26
+
27
+ def pad(seqs):
28
+ max_len = max(len(s) for s in seqs)
29
+ padded = [s + [PAD_IDX] * (max_len - len(s)) for s in seqs]
30
+ return torch.tensor(padded, dtype=torch.long, device=device)
31
+
32
+ return pad(src_seqs), pad(tgt_seqs)
33
+
34
+
35
+ # ─────────────────────────────────────────────
36
+ # Label-smoothed cross-entropy
37
+ # ─────────────────────────────────────────────
38
+
39
+ class LabelSmoothingLoss(nn.Module):
40
+ def __init__(self, vocab_size: int, pad_idx: int, smoothing: float = 0.1):
41
+ super().__init__()
42
+ self.vocab_size = vocab_size
43
+ self.pad_idx = pad_idx
44
+ self.smoothing = smoothing
45
+ self.confidence = 1.0 - smoothing
46
+
47
+ def forward(self, logits: torch.Tensor, target: torch.Tensor,
48
+ log: Optional[CalcLog] = None) -> torch.Tensor:
49
+ B, T, V = logits.shape
50
+ logits_flat = logits.reshape(-1, V)
51
+ target_flat = target.reshape(-1)
52
+
53
+ log_probs = torch.log_softmax(logits_flat, dim=-1)
54
+
55
+ with torch.no_grad():
56
+ smooth_dist = torch.full_like(log_probs, self.smoothing / (V - 2))
57
+ smooth_dist.scatter_(1, target_flat.unsqueeze(1), self.confidence)
58
+ smooth_dist[:, self.pad_idx] = 0
59
+ mask = (target_flat == self.pad_idx)
60
+ smooth_dist[mask] = 0
61
+
62
+ loss = -(smooth_dist * log_probs).sum(dim=-1)
63
+ non_pad = (~mask).sum()
64
+ loss = loss.sum() / non_pad.clamp(min=1)
65
+
66
+ if log:
67
+ probs_sample = torch.exp(log_probs[:4])
68
+ log.log("LOSS_log_probs_sample", probs_sample,
69
+ formula="log P(token) = log_softmax(logits)",
70
+ note="Softmax probabilities for first 4 target positions")
71
+ log.log("LOSS_smooth_dist_sample", smooth_dist[:4],
72
+ formula=f"smooth: correct={self.confidence:.2f}, others={self.smoothing/(V-2):.5f}",
73
+ note="Label-smoothed target distribution")
74
+ log.log("LOSS_value", loss.item(),
75
+ formula="L = -Σ smooth_dist · log_probs / n_tokens",
76
+ note=f"Label-smoothed cross-entropy loss = {loss.item():.4f}")
77
+
78
+ return loss
79
+
80
+
+# ─────────────────────────────────────────────
+# Build model
+# ─────────────────────────────────────────────
+
+def build_model(src_vocab_size: int, tgt_vocab_size: int,
+                device: str = "cpu") -> Transformer:
+    model = Transformer(
+        src_vocab_size=src_vocab_size,
+        tgt_vocab_size=tgt_vocab_size,
+        d_model=64,
+        num_heads=4,
+        num_layers=2,
+        d_ff=128,
+        max_len=32,
+        dropout=0.1,
+        pad_idx=PAD_IDX,
+    ).to(device)
+    return model
+
+
+# ─────────────────────────────────────────────
+# Single training step (with full logging)
+# ─────────────────────────────────────────────
+
+def training_step(
+    model: Transformer,
+    src: torch.Tensor,
+    tgt: torch.Tensor,
+    criterion: LabelSmoothingLoss,
+    optimizer: optim.Optimizer,
+    log: CalcLog,
+    step_num: int = 0,
+) -> Dict:
+    model.train()
+    log.clear()
+
+    # Teacher forcing: decoder input = [BOS, token_1, ..., token_{T-1}]
+    tgt_input = tgt[:, :-1]
+    tgt_target = tgt[:, 1:]
+
+    log.log("TRAINING_SETUP", {
+        "mode": "TRAINING",
+        "step": step_num,
+        "src_shape": list(src.shape),
+        "tgt_input_shape": list(tgt_input.shape),
+        "tgt_target_shape": list(tgt_target.shape),
+    }, formula="Teacher Forcing: feed ground-truth Bengali tokens as decoder input",
+       note="During training, decoder sees actual Bengali tokens (not its own predictions)")
+
+    log.log("SRC_sentence_ids", src[0].tolist(),
+            note="Source (English) token IDs fed to encoder")
+    log.log("TGT_input_ids", tgt_input[0].tolist(),
+            note="Target input to decoder (shifted right — starts with <BOS>)")
+    log.log("TGT_target_ids", tgt_target[0].tolist(),
+            note="What decoder must predict (shifted left — ends with <EOS>)")
+
+    # Forward
+    logits, meta = model(src, tgt_input, log=log)
+
+    # Loss
+    loss = criterion(logits, tgt_target, log=log)
+
+    log.log("LOSS_final", loss.item(),
+            formula="Total loss = label-smoothed cross-entropy averaged over tokens",
+            note=f"Loss = {loss.item():.4f} (lower = better prediction)")
+
+    # Backward
+    optimizer.zero_grad()
+    loss.backward()
+
+    # Gradient stats
+    grad_norms = {}
+    for name, param in model.named_parameters():
+        if param.grad is not None:
+            gn = param.grad.norm().item()
+            grad_norms[name] = round(gn, 6)
+
+    log.log("GRADIENTS_norm_sample", dict(list(grad_norms.items())[:8]),
+            formula="∂L/∂W via backpropagation (chain rule)",
+            note="Gradient norms for first 8 parameter tensors")
+
+    # Gradient clipping
+    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+    optimizer.step()
+
+    log.log("OPTIMIZER_step", {
+        "algorithm": "Adam",
+        "lr": optimizer.param_groups[0]["lr"],
+        "note": "W = W - lr × (m̂ / (√v̂ + ε)) (Adam update rule)",
+    }, formula="Adam: adaptive learning rate with momentum",
+       note="Weights updated — model slightly improved")
+
+    return {
+        "loss": loss.item(),
+        "calc_log": log.to_dict(),
+        "meta": {k: v.tolist() if hasattr(v, "tolist") else v
+                 for k, v in meta.items() if k != "enc_attn"},
+    }
+
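The input/target shift at the top of `training_step` is easy to see on a concrete sequence. A minimal sketch with made-up token ids (only `BOS`/`EOS` mirror the real special tokens):

```python
# Teacher forcing splits one target sequence into decoder input and
# prediction target: the input drops the last token, the target drops
# the first, so position t of the input lines up with token t+1.
BOS, EOS = 1, 2
tgt = [BOS, 11, 12, 13, EOS]   # hypothetical Bengali token ids

tgt_input = tgt[:-1]    # [BOS, 11, 12, 13]: what the decoder sees
tgt_target = tgt[1:]    # [11, 12, 13, EOS]: what it must predict

print(tgt_input)
print(tgt_target)
```

Both slices have the same length, so the logits at every decoder position have exactly one ground-truth token to be scored against.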
+
+# ─────────────────────────────────────────────
+# Full training run (quick demo)
+# ─────────────────────────────────────────────
+
+def run_training(
+    epochs: int = 30,
+    device: str = "cpu",
+    progress_cb=None,
+) -> Tuple[Transformer, List[float]]:
+    src_v, tgt_v = get_vocabs()
+    model = build_model(len(src_v), len(tgt_v), device)
+    criterion = LabelSmoothingLoss(len(tgt_v), PAD_IDX, smoothing=0.1)
+    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
+
+    losses = []
+    src_batch, tgt_batch = collate_batch(PARALLEL_DATA, src_v, tgt_v, device)
+
+    for epoch in range(1, epochs + 1):
+        model.train()
+        tgt_input = tgt_batch[:, :-1]
+        tgt_target = tgt_batch[:, 1:]
+        logits, _ = model(src_batch, tgt_input, log=None)
+        loss = criterion(logits, tgt_target)
+        optimizer.zero_grad()
+        loss.backward()
+        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+        optimizer.step()
+        scheduler.step(loss.item())
+        losses.append(loss.item())
+        if progress_cb:
+            progress_cb(epoch, epochs, loss.item())
+
+    return model, losses
+
+
+# ─────────────────────────────────────────────
+# Single-sample step for visualization
+# ─────────────────────────────────────────────
+
+def visualize_training_step(
+    model: Transformer,
+    en_sentence: str,
+    bn_sentence: str,
+    device: str = "cpu",
+) -> Dict:
+    src_v, tgt_v = get_vocabs()
+    log = CalcLog()
+
+    src_ids = src_v.encode(en_sentence)
+    tgt_ids = tgt_v.encode(bn_sentence)
+
+    log.log("TOKENIZATION_EN", {
+        "sentence": en_sentence,
+        "tokens": en_sentence.lower().split(),
+        "ids": src_ids,
+        "vocab_size": len(src_v),
+    }, formula="token_id = vocab[word]",
+       note="English → token IDs (BOS prepended, EOS appended)")
+
+    log.log("TOKENIZATION_BN", {
+        "sentence": bn_sentence,
+        "tokens": bn_sentence.split(),
+        "ids": tgt_ids,
+        "vocab_size": len(tgt_v),
+    }, note="Bengali → token IDs (teacher-forced during training)")
+
+    src = torch.tensor([src_ids], dtype=torch.long, device=device)
+    tgt = torch.tensor([tgt_ids], dtype=torch.long, device=device)
+
+    criterion = LabelSmoothingLoss(len(tgt_v), PAD_IDX)
+    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+    result = training_step(model, src, tgt, criterion, optimizer, log)
+
+    result["src_tokens"] = src_v.tokens(src_ids)
+    result["tgt_tokens"] = tgt_v.tokens(tgt_ids)
+    result["en_sentence"] = en_sentence
+    result["bn_sentence"] = bn_sentence
+    return result
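The smoothed target row that `LabelSmoothingLoss` builds with `full_like` + `scatter_` can be reproduced without torch. A minimal pure-Python sketch; the vocabulary size and token indices are toy values: the true token gets `confidence`, every other non-pad token shares the `smoothing` mass, and the pad column is zeroed:

```python
# Build the label-smoothed target distribution for one non-pad position,
# mirroring LabelSmoothingLoss.forward above.
V = 6                # toy vocabulary size
pad_idx = 0
target = 3           # correct token id at this position
smoothing = 0.1
confidence = 1.0 - smoothing

# V - 2 excludes the pad column and the true-token column from the
# shared smoothing mass, exactly as in self.smoothing / (V - 2).
row = [smoothing / (V - 2)] * V
row[target] = confidence     # scatter_ step: confidence on the true token
row[pad_idx] = 0.0           # never assign probability to <PAD>

print(row)
print(sum(row))   # ≈ 1.0: still (numerically) a probability distribution
```

Compared with a one-hot target, this keeps the loss from pushing the true-token probability to exactly 1, which on a 30-sentence dataset noticeably reduces overconfident memorization.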
transformer.py ADDED
@@ -0,0 +1,516 @@
+ """
2
+ transformer.py
3
+ Full Transformer implementation for English → Bengali translation
4
+ with complete calculation tracking at every step.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+ import math
12
+ from typing import Optional, Tuple, Dict, List
13
+
14
+
15
+ # ─────────────────────────────────────────────
16
+ # Calculation Logger
17
+ # ─────────────────────────────────────────────
18
+
19
+ class CalcLog:
20
+ """Captures every intermediate tensor for visualization."""
21
+ def __init__(self):
22
+ self.steps: List[Dict] = []
23
+
24
+ def log(self, name: str, data, formula: str = "", note: str = ""):
25
+ entry = {
26
+ "name": name,
27
+ "formula": formula,
28
+ "note": note,
29
+ "shape": None,
30
+ "value": None,
31
+ }
32
+ if isinstance(data, torch.Tensor):
33
+ entry["shape"] = list(data.shape)
34
+ entry["value"] = data.detach().cpu().numpy().tolist()
35
+ elif isinstance(data, np.ndarray):
36
+ entry["shape"] = list(data.shape)
37
+ entry["value"] = data.tolist()
38
+ else:
39
+ entry["value"] = data
40
+ self.steps.append(entry)
41
+ return data
42
+
43
+ def clear(self):
44
+ self.steps = []
45
+
46
+ def to_dict(self):
47
+ return self.steps
48
+
49
+
50
+# ─────────────────────────────────────────────
+# Positional Encoding
+# ─────────────────────────────────────────────
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
+        super().__init__()
+        self.d_model = d_model
+        self.dropout = nn.Dropout(dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len).unsqueeze(1).float()
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)
+
+    def forward(self, x: torch.Tensor, log: Optional[CalcLog] = None) -> torch.Tensor:
+        seq_len = x.size(1)
+        pe_slice = self.pe[:, :seq_len, :]
+
+        if log:
+            log.log("PE_matrix", pe_slice[0, :seq_len, :8],
+                    formula="PE(pos,2i)=sin(pos/10000^(2i/d)), PE(pos,2i+1)=cos(...)",
+                    note=f"Showing first 8 dims for {seq_len} positions")
+            log.log("Embedding_before_PE", x[0, :, :8],
+                    note="Token embeddings (first 8 dims)")
+
+        x = x + pe_slice
+        if log:
+            log.log("Embedding_after_PE", x[0, :, :8],
+                    formula="X = Embedding + PE",
+                    note="After adding positional encoding")
+        return self.dropout(x)
+
+
+# ─────────────────────────────────────────────
+# Scaled Dot-Product Attention
+# ─────────────────────────────────────────────
+
+def scaled_dot_product_attention(
+    Q: torch.Tensor,
+    K: torch.Tensor,
+    V: torch.Tensor,
+    mask: Optional[torch.Tensor] = None,
+    log: Optional[CalcLog] = None,
+    head_idx: int = 0,
+    layer_idx: int = 0,
+    attn_type: str = "self",
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    d_k = Q.size(-1)
+    prefix = f"L{layer_idx}_H{head_idx}_{attn_type}"
+
+    # Raw scores
+    scores = torch.matmul(Q, K.transpose(-2, -1))
+    if log:
+        log.log(f"{prefix}_Q", Q[0],
+                formula="Q = X · Wq",
+                note=f"Query matrix head {head_idx}")
+        log.log(f"{prefix}_K", K[0],
+                formula="K = X · Wk",
+                note=f"Key matrix head {head_idx}")
+        log.log(f"{prefix}_V", V[0],
+                formula="V = X · Wv",
+                note=f"Value matrix head {head_idx}")
+        log.log(f"{prefix}_QKt", scores[0],
+                formula="scores = Q · Kᵀ",
+                note="Raw attention scores (before scaling)")
+
+    # Scale
+    scale = math.sqrt(d_k)
+    scores = scores / scale
+    if log:
+        log.log(f"{prefix}_QKt_scaled", scores[0],
+                formula=f"scores = Q·Kᵀ / √{d_k} = Q·Kᵀ / {scale:.3f}",
+                note="Scaled scores — prevents vanishing gradients")
+
+    # Mask
+    # masks arrive as (B,1,1,T) or (B,1,T,T) from make_src/tgt_mask;
+    # scores here are 3-D (B,T_q,T_k) because we loop per-head,
+    # so squeeze the head dim to avoid (B,B,...) broadcasting.
+    if mask is not None:
+        if mask.dim() == 4:
+            mask = mask.squeeze(1)  # (B,1,T,T) or (B,1,1,T) → (B,T,T) or (B,1,T)
+        scores = scores.masked_fill(mask == 0, float("-inf"))
+        if log:
+            log.log(f"{prefix}_mask", mask[0].float(),
+                    formula="mask[i,j]=0 → score=-inf (future token blocked)",
+                    note="Causal mask (training decoder) or padding mask")
+            log.log(f"{prefix}_scores_masked", scores[0],
+                    note="Scores after masking (-inf will become 0 after softmax)")
+
+    # Softmax
+    attn_weights = F.softmax(scores, dim=-1)
+    # replace nan from -inf rows with 0 (edge case)
+    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
+    if log:
+        log.log(f"{prefix}_softmax", attn_weights[0],
+                formula="α = softmax(scores, dim=-1)",
+                note="Attention weights — each row sums to 1.0")
+
+    # Weighted sum
+    output = torch.matmul(attn_weights, V)
+    if log:
+        log.log(f"{prefix}_output", output[0],
+                formula="Attention = α · V",
+                note="Weighted sum of values")
+
+    return output, attn_weights
+
+
+# ─────────────────────────────────────────────
+# Multi-Head Attention
+# ─────────────────────────────────────────────
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model: int, num_heads: int):
+        super().__init__()
+        assert d_model % num_heads == 0
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.d_k = d_model // num_heads
+
+        self.W_q = nn.Linear(d_model, d_model, bias=False)
+        self.W_k = nn.Linear(d_model, d_model, bias=False)
+        self.W_v = nn.Linear(d_model, d_model, bias=False)
+        self.W_o = nn.Linear(d_model, d_model, bias=False)
+
+    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
+        B, T, D = x.shape
+        return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
+        # → (B, num_heads, T, d_k)
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        log: Optional[CalcLog] = None,
+        layer_idx: int = 0,
+        attn_type: str = "self",
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        B = query.size(0)
+        prefix = f"L{layer_idx}_{attn_type}_MHA"
+
+        # Linear projections
+        Q = self.W_q(query)
+        K = self.W_k(key)
+        V = self.W_v(value)
+
+        if log:
+            log.log(f"{prefix}_Wq", self.W_q.weight[:4, :4],
+                    formula="Wq shape: (d_model, d_model)",
+                    note="Query weight matrix (first 4×4 shown)")
+            log.log(f"{prefix}_Q_full", Q[0, :, :8],
+                    formula="Q = input · Wq",
+                    note="Full Q projection (first 8 dims shown)")
+
+        # Split into heads
+        Q = self.split_heads(Q)  # (B, h, T, d_k)
+        K = self.split_heads(K)
+        V = self.split_heads(V)
+
+        if log:
+            log.log(f"{prefix}_Q_head0", Q[0, 0, :, :],
+                    formula=f"Split: (B,T,D) → (B,{self.num_heads},T,{self.d_k})",
+                    note=f"Head 0 queries — d_k={self.d_k}")
+
+        # Per-head attention (log only first 2 heads to avoid bloat)
+        all_attn = []
+        all_weights = []
+        for h in range(self.num_heads):
+            h_log = log if h < 2 else None
+            out_h, w_h = scaled_dot_product_attention(
+                Q[:, h], K[:, h], V[:, h],
+                mask=mask,
+                log=h_log,
+                head_idx=h,
+                layer_idx=layer_idx,
+                attn_type=attn_type,
+            )
+            all_attn.append(out_h)
+            all_weights.append(w_h)
+
+        # Concat heads
+        concat = torch.stack(all_attn, dim=1)         # (B, h, T, d_k)
+        concat = concat.transpose(1, 2).contiguous()  # (B, T, h, d_k)
+        concat = concat.view(B, -1, self.d_model)     # (B, T, D)
+
+        if log:
+            log.log(f"{prefix}_concat", concat[0, :, :8],
+                    formula="concat = [head_1; head_2; ...; head_h]",
+                    note="Concatenated heads (first 8 dims)")
+
+        # Final projection
+        output = self.W_o(concat)
+        if log:
+            log.log(f"{prefix}_output", output[0, :, :8],
+                    formula="MHA_out = concat · Wo",
+                    note="Final multi-head attention output")
+
+        # Stack all attention weights: (B, h, T_q, T_k)
+        attn_weights = torch.stack(all_weights, dim=1)
+        return output, attn_weights
+
+
+# ─────────────────────────────────────────────
+# Feed-Forward Network
+# ─────────────────────────────────────────────
+
+class FeedForward(nn.Module):
+    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
+        super().__init__()
+        self.linear1 = nn.Linear(d_model, d_ff)
+        self.linear2 = nn.Linear(d_ff, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor, log: Optional[CalcLog] = None,
+                layer_idx: int = 0, loc: str = "enc") -> torch.Tensor:
+        prefix = f"L{layer_idx}_{loc}_FFN"
+        h = self.linear1(x)
+        if log:
+            log.log(f"{prefix}_linear1", h[0, :, :8],
+                    formula="h = X · W1 + b1",
+                    note="First linear (d_model→d_ff), showing first 8 dims")
+        h = F.relu(h)
+        if log:
+            log.log(f"{prefix}_relu", h[0, :, :8],
+                    formula="h = ReLU(h) = max(0, h)",
+                    note="Negative values zeroed out")
+        h = self.dropout(h)
+        out = self.linear2(h)
+        if log:
+            log.log(f"{prefix}_linear2", out[0, :, :8],
+                    formula="out = h · W2 + b2",
+                    note="Second linear (d_ff→d_model)")
+        return out
+
+
+# ─────────────────────────────────────────────
+# Layer Norm + Residual
+# ─────────────────────────────────────────────
+
+class AddNorm(nn.Module):
+    def __init__(self, d_model: int, eps: float = 1e-6):
+        super().__init__()
+        self.norm = nn.LayerNorm(d_model, eps=eps)
+
+    def forward(self, x: torch.Tensor, sublayer_out: torch.Tensor,
+                log: Optional[CalcLog] = None, tag: str = "") -> torch.Tensor:
+        residual = x + sublayer_out
+        out = self.norm(residual)
+        if log:
+            log.log(f"{tag}_residual", residual[0, :, :8],
+                    formula="residual = x + sublayer(x)",
+                    note="Residual (skip) connection")
+            log.log(f"{tag}_layernorm", out[0, :, :8],
+                    formula="LayerNorm(x) = γ·(x−μ)/σ + β",
+                    note="Layer normalization output")
+        return out
+
+
+# ─────────────────────────────────────────────
+# Encoder Layer
+# ─────────────────────────────────────────────
+
+class EncoderLayer(nn.Module):
+    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
+        super().__init__()
+        self.self_attn = MultiHeadAttention(d_model, num_heads)
+        self.ffn = FeedForward(d_model, d_ff, dropout)
+        self.add_norm1 = AddNorm(d_model)
+        self.add_norm2 = AddNorm(d_model)
+
+    def forward(self, x: torch.Tensor, src_mask: Optional[torch.Tensor] = None,
+                log: Optional[CalcLog] = None, layer_idx: int = 0):
+        attn_out, attn_w = self.self_attn(
+            x, x, x, mask=src_mask, log=log,
+            layer_idx=layer_idx, attn_type="enc_self"
+        )
+        x = self.add_norm1(x, attn_out, log=log, tag=f"L{layer_idx}_enc_self")
+        ffn_out = self.ffn(x, log=log, layer_idx=layer_idx, loc="enc")
+        x = self.add_norm2(x, ffn_out, log=log, tag=f"L{layer_idx}_enc_ffn")
+        return x, attn_w
+
+
+# ─────────────────────────────────────────────
+# Decoder Layer
+# ─────────────────────────────────────────────
+
+class DecoderLayer(nn.Module):
+    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
+        super().__init__()
+        self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
+        self.cross_attn = MultiHeadAttention(d_model, num_heads)
+        self.ffn = FeedForward(d_model, d_ff, dropout)
+        self.add_norm1 = AddNorm(d_model)
+        self.add_norm2 = AddNorm(d_model)
+        self.add_norm3 = AddNorm(d_model)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        enc_out: torch.Tensor,
+        tgt_mask: Optional[torch.Tensor] = None,
+        src_mask: Optional[torch.Tensor] = None,
+        log: Optional[CalcLog] = None,
+        layer_idx: int = 0,
+    ):
+        # 1. Masked self-attention
+        m_attn_out, m_attn_w = self.masked_self_attn(
+            x, x, x, mask=tgt_mask, log=log,
+            layer_idx=layer_idx, attn_type="dec_masked"
+        )
+        x = self.add_norm1(x, m_attn_out, log=log, tag=f"L{layer_idx}_dec_masked")
+
+        # 2. Cross-attention: Q from decoder, K/V from encoder
+        if log:
+            log.log(f"L{layer_idx}_cross_Q_source", x[0, :, :8],
+                    note="Cross-attn Q comes from DECODER (Bengali context)")
+            log.log(f"L{layer_idx}_cross_KV_source", enc_out[0, :, :8],
+                    note="Cross-attn K,V come from ENCODER (English context)")
+
+        c_attn_out, c_attn_w = self.cross_attn(
+            query=x, key=enc_out, value=enc_out,
+            mask=src_mask, log=log,
+            layer_idx=layer_idx, attn_type="dec_cross"
+        )
+        x = self.add_norm2(x, c_attn_out, log=log, tag=f"L{layer_idx}_dec_cross")
+
+        # 3. FFN
+        ffn_out = self.ffn(x, log=log, layer_idx=layer_idx, loc="dec")
+        x = self.add_norm3(x, ffn_out, log=log, tag=f"L{layer_idx}_dec_ffn")
+
+        return x, m_attn_w, c_attn_w
+
+
+# ─────────────────────────────────────────────
+# Full Transformer
+# ─────────────────────────────────────────────
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        src_vocab_size: int,
+        tgt_vocab_size: int,
+        d_model: int = 128,
+        num_heads: int = 4,
+        num_layers: int = 2,
+        d_ff: int = 256,
+        max_len: int = 64,
+        dropout: float = 0.1,
+        pad_idx: int = 0,
+    ):
+        super().__init__()
+        self.d_model = d_model
+        self.pad_idx = pad_idx
+        self.num_layers = num_layers
+
+        self.src_embed = nn.Embedding(src_vocab_size, d_model, padding_idx=pad_idx)
+        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model, padding_idx=pad_idx)
+        self.src_pe = PositionalEncoding(d_model, max_len, dropout)
+        self.tgt_pe = PositionalEncoding(d_model, max_len, dropout)
+
+        self.encoder_layers = nn.ModuleList(
+            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
+        )
+        self.decoder_layers = nn.ModuleList(
+            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
+        )
+
+        self.output_linear = nn.Linear(d_model, tgt_vocab_size)
+        self._init_weights()
+
+    def _init_weights(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    # ── mask helpers ──────────────────────────
+
+    def make_src_mask(self, src: torch.Tensor) -> torch.Tensor:
+        # (B, 1, 1, T_src) — 1 where not pad
+        return (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
+
+    def make_tgt_mask(self, tgt: torch.Tensor) -> torch.Tensor:
+        T = tgt.size(1)
+        pad_mask = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(2)       # (B,1,1,T)
+        causal = torch.tril(torch.ones(T, T, device=tgt.device)).bool()  # (T,T)
+        return pad_mask & causal  # (B,1,T,T)
+
+    # ── forward ───────────────────────────────
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        tgt: torch.Tensor,
+        log: Optional[CalcLog] = None,
+    ) -> Tuple[torch.Tensor, Dict]:
+        src_mask = self.make_src_mask(src)
+        tgt_mask = self.make_tgt_mask(tgt)
+
+        # ── Encoder ──────────────────────────
+        src_emb = self.src_embed(src) * math.sqrt(self.d_model)
+        if log:
+            log.log("SRC_tokens", src[0],
+                    note="Source token IDs (English)")
+            log.log("SRC_embedding_raw", src_emb[0, :, :8],
+                    formula=f"emb = Embedding(token_id) × √{self.d_model}",
+                    note="Token embeddings (first 8 dims)")
+
+        enc_x = self.src_pe(src_emb, log=log)
+
+        enc_attn_weights = []
+        for i, layer in enumerate(self.encoder_layers):
+            enc_x, ew = layer(enc_x, src_mask=src_mask, log=log, layer_idx=i)
+            enc_attn_weights.append(ew.detach().cpu().numpy())
+
+        if log:
+            log.log("ENCODER_output", enc_x[0, :, :8],
+                    note="Final encoder output — passed as K,V to every decoder cross-attention")
+
+        # ── Decoder ──────────────────────────
+        tgt_emb = self.tgt_embed(tgt) * math.sqrt(self.d_model)
+        if log:
+            log.log("TGT_tokens", tgt[0],
+                    note="Target token IDs (Bengali, teacher-forced in training)")
+            log.log("TGT_embedding_raw", tgt_emb[0, :, :8],
+                    formula=f"emb = Embedding(token_id) × √{self.d_model}",
+                    note="Bengali token embeddings")
+
+        dec_x = self.tgt_pe(tgt_emb, log=log)
+
+        dec_self_attn_w = []
+        dec_cross_attn_w = []
+        for i, layer in enumerate(self.decoder_layers):
+            dec_x, mw, cw = layer(
+                dec_x, enc_x,
+                tgt_mask=tgt_mask, src_mask=src_mask,
+                log=log, layer_idx=i,
+            )
+            dec_self_attn_w.append(mw.detach().cpu().numpy())
+            dec_cross_attn_w.append(cw.detach().cpu().numpy())
+
+        # ── Output projection ─────────────────
+        logits = self.output_linear(dec_x)  # (B, T, vocab)
+        if log:
+            log.log("LOGITS", logits[0, :, :16],
+                    formula="logits = dec_out · W_out (first 16 vocab entries shown)",
+                    note=f"Raw scores over vocab of {logits.size(-1)} Bengali tokens")
+
+            # probability logging stays inside the guard so forward() works when log is None
+            probs = F.softmax(logits[0], dim=-1)
+            log.log("SOFTMAX_probs", probs[:, :16],
+                    formula="P(token) = exp(logit) / Σ exp(logits)",
+                    note="Probability distribution over Bengali vocabulary")
+
+        meta = {
+            "enc_attn": enc_attn_weights,
+            "dec_self_attn": dec_self_attn_w,
+            "dec_cross_attn": dec_cross_attn_w,
+            "src_mask": src_mask.cpu().numpy(),
+            "tgt_mask": tgt_mask.cpu().numpy(),
+        }
+        return logits, meta
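
The core calculation that `scaled_dot_product_attention` logs can be sanity-checked in isolation. Below is a minimal standalone sketch (not part of the commit — the helper name `sdpa` and the toy shapes are illustrative) showing that the softmax rows sum to 1 and that a causal mask zeroes attention to future positions:

```python
import math
import torch
import torch.nn.functional as F

def sdpa(Q, K, V, mask=None):
    # scores = Q·Kᵀ / √d_k, masked positions set to -inf before softmax
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    w = F.softmax(scores, dim=-1)
    return w @ V, w

torch.manual_seed(0)
T, d_k = 4, 8
Q, K, V = (torch.randn(1, T, d_k) for _ in range(3))

# causal mask: position i may only attend to positions ≤ i
causal = torch.tril(torch.ones(T, T)).bool()
out, w = sdpa(Q, K, V, mask=causal)

print(out.shape)                                     # torch.Size([1, 4, 8])
print(torch.allclose(w.sum(-1), torch.ones(1, T)))   # True: each row sums to 1
print((w[0].triu(1) == 0).all().item())              # True: no weight on future tokens
```

The same invariants are what the "Softmax attention weights (heatmap)" view in the Training Step tab visualizes.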
vocab.py ADDED
@@ -0,0 +1,146 @@
+"""
+vocab.py
+Simple word-level tokenizer and toy vocabularies for the English→Bengali demo.
+"""
+
+import json
+import re
+from pathlib import Path
+from typing import List, Dict, Tuple
+
+
+# ── Special tokens ───────────────────────────────────────────────────────────
+
+PAD_TOKEN = "<pad>"
+BOS_TOKEN = "<bos>"
+EOS_TOKEN = "<eos>"
+UNK_TOKEN = "<unk>"
+
+PAD_IDX = 0
+BOS_IDX = 1
+EOS_IDX = 2
+UNK_IDX = 3
+
+
+# ── English word-level vocab ─────────────────────────────────────────────────
+
+EN_WORDS = [
+    "i", "you", "he", "she", "we", "they", "it",
+    "love", "like", "eat", "drink", "go", "come", "see", "know",
+    "want", "need", "have", "am", "is", "are", "was", "were",
+    "do", "does", "did", "will", "can", "may", "should",
+    "a", "an", "the", "this", "that", "my", "your", "his", "her",
+    "good", "bad", "happy", "sad", "big", "small", "new", "old",
+    "food", "water", "home", "work", "school", "book", "name",
+    "rice", "fish", "milk", "tea", "coffee",
+    "hello", "bye", "yes", "no", "please", "thank", "thanks",
+    "how", "what", "where", "when", "why", "who",
+    "today", "tomorrow", "now", "always", "never", "very",
+    "bengal", "india", "english", "bengali",
+    "beautiful", "wonderful", "great", "nice", "fine",
+]
+
+# ── Bengali word-level vocab ──────────────────────────────────────────────────
+
+BN_WORDS = [
+    "আমি", "তুমি", "তুই", "সে", "আমরা", "তারা", "এটা",
+    "ভালোবাসি", "পছন্দ", "খাই", "পান", "যাই", "আসি", "দেখি", "জানি",
+    "চাই", "দরকার", "আছে", "হই", "হয়", "ছিলাম", "ছিল",
+    "করি", "করে", "করেছি", "করব", "পারি", "পারে", "উচিত",
+    "একটা", "একটি", "এই", "সেই", "আমার", "তোমার", "তার",
+    "ভালো", "খারাপ", "খুশি", "দুঃখী", "বড়", "ছোট", "নতুন", "পুরনো",
+    "খাবার", "পানি", "বাড়ি", "কাজ", "স্কুল", "বই", "নাম",
+    "ভাত", "মাছ", "দুধ", "চা", "কফি",
+    "হ্যালো", "বিদায়", "হ্যাঁ", "না", "দয়া", "ধন্যবাদ",
+    "কেমন", "কি", "কোথায়", "কখন", "কেন", "কে",
+    "আজ", "আগামীকাল", "এখন", "সবসময়", "কখনো", "খুব",
+    "বাংলা", "ভারত", "ইংরেজি",
+    "সুন্দর", "চমৎকার", "দারুণ",
+    "তোমাকে", "আপনাকে", "তাকে", "আমাকে",
+    "করছি", "করছে", "হচ্ছে", "পড়ি", "লিখি", "বলি",
+    "আছ", "সকাল", "জানে", "দেখে",
+]
+
+
+class Vocab:
+    def __init__(self, words: List[str], name: str = "vocab"):
+        self.name = name
+        self.token2idx: Dict[str, int] = {
+            PAD_TOKEN: PAD_IDX,
+            BOS_TOKEN: BOS_IDX,
+            EOS_TOKEN: EOS_IDX,
+            UNK_TOKEN: UNK_IDX,
+        }
+        for w in words:
+            if w not in self.token2idx:
+                self.token2idx[w] = len(self.token2idx)
+        self.idx2token: Dict[int, str] = {v: k for k, v in self.token2idx.items()}
+
+    def __len__(self):
+        return len(self.token2idx)
+
+    def encode(self, sentence: str) -> List[int]:
+        tokens = sentence.lower().strip().split()
+        ids = [self.token2idx.get(t, UNK_IDX) for t in tokens]
+        return [BOS_IDX] + ids + [EOS_IDX]
+
+    def decode(self, ids: List[int], skip_special: bool = True) -> str:
+        skip = {PAD_IDX, BOS_IDX, EOS_IDX} if skip_special else set()
+        return " ".join(
+            self.idx2token.get(i, UNK_TOKEN)
+            for i in ids
+            if i not in skip
+        )
+
+    def tokens(self, ids: List[int]) -> List[str]:
+        return [self.idx2token.get(i, UNK_TOKEN) for i in ids]
+
+
+# ── Toy parallel corpus ──────────────────────────────────────────────────────
+
+PARALLEL_DATA: List[Tuple[str, str]] = [
+    ("i love you", "আমি তোমাকে ভালোবাসি"),
+    ("i like food", "আমি খাবার পছন্দ"),
+    ("you are good", "তুমি ভালো"),
+    ("he is happy", "সে খুশি"),
+    ("i want water", "আমি পানি চাই"),
+    ("she is beautiful", "সে সুন্দর"),
+    ("i eat rice", "আমি ভাত খাই"),
+    ("i drink tea", "আমি চা পান"),
+    ("i know you", "আমি তোমাকে জানি"),
+    ("we are happy", "আমরা খুশি"),
+    ("i see you", "আমি তোমাকে দেখি"),
+    ("you are beautiful", "তুমি সুন্দর"),
+    ("i love bengali", "আমি বাংলা ভালোবাসি"),
+    ("hello how are you", "হ্যালো কেমন আছ"),  # আছ added to vocab
+    ("thank you very much", "ধন্যবাদ খুব"),
+    ("i need you", "আমি তোমাকে দরকার"),
+    ("he knows bengali", "সে বাংলা জানে"),  # জানে added to vocab
+    ("she drinks milk", "সে দুধ পান"),
+    ("we go home", "আমরা বাড়ি যাই"),
+    ("i am happy today", "আমি আজ খুশি"),
+    ("i love my home", "আমি আমার বাড়ি ভালোবাসি"),
+    ("she sees you", "সে তোমাকে দেখে"),
+    ("they eat rice", "তারা ভাত খাই"),
+    ("i want tea", "আমি চা চাই"),
+    ("this is good", "এটা ভালো"),
+    ("he is sad", "সে দুঃখী"),
+    ("i like school", "আমি স্কুল পছন্দ"),
+    ("you know me", "তুমি আমাকে জানে"),
+    ("good morning", "ভালো সকাল"),
+    ("i love india", "আমি ভারত ভালোবাসি"),
+]
+
+
+def build_vocabs() -> Tuple[Vocab, Vocab]:
+    src_v = Vocab(EN_WORDS, "english")
+    tgt_v = Vocab(BN_WORDS, "bengali")
+    return src_v, tgt_v
+
+
+# singleton
+_src_vocab, _tgt_vocab = build_vocabs()
+
+
+def get_vocabs() -> Tuple[Vocab, Vocab]:
+    return _src_vocab, _tgt_vocab
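
The `encode`/`decode` round trip is easiest to see with a toy vocabulary. A minimal self-contained restatement of the logic above (the three-word vocab here is illustrative, not the module's `EN_WORDS`):

```python
# Same scheme as vocab.py: ids 0-3 reserved for <pad>/<bos>/<eos>/<unk>,
# words appended in order after the specials.
PAD_IDX, BOS_IDX, EOS_IDX, UNK_IDX = 0, 1, 2, 3
token2idx = {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3,
             "i": 4, "love": 5, "you": 6}
idx2token = {v: k for k, v in token2idx.items()}

def encode(sentence):
    # lowercase, whitespace-split, map unknown words to <unk>, wrap in <bos>/<eos>
    ids = [token2idx.get(t, UNK_IDX) for t in sentence.lower().strip().split()]
    return [BOS_IDX] + ids + [EOS_IDX]

def decode(ids, skip_special=True):
    skip = {PAD_IDX, BOS_IDX, EOS_IDX} if skip_special else set()
    return " ".join(idx2token.get(i, "<unk>") for i in ids if i not in skip)

ids = encode("I love you")
print(ids)                      # [1, 4, 5, 6, 2]
print(decode(ids))              # i love you
print(encode("i love pizza"))   # [1, 4, 5, 3, 2] — out-of-vocab word becomes <unk>
```

This is the same ID sequence the Training Step tab shows in its Tokenization stage before the embedding lookup.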