Commit dc138e1 · priyadip committed
Fix: js in gr.Blocks(), event delegation for card clicks, SVG loss curve
Browse files
- README.md +91 -0
- app.py +800 -0
- inference.py +250 -0
- requirements.txt +2 -0
- training.py +261 -0
- transformer.py +516 -0
- vocab.py +146 -0
README.md
ADDED
@@ -0,0 +1,91 @@
---
title: Transformer Visualizer EN→BN
emoji: 🔬
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: 6.9.0
app_file: app.py
pinned: true
license: mit
---

# 🔬 Transformer Visualizer — English → Bengali

**See every single calculation inside a Transformer, live.**

## What this Space does

Type any English sentence and watch every number flow through the Transformer architecture step by step — from raw token IDs all the way to Bengali output.

---

## 🗂️ Tabs

### 🏗️ Architecture
- Full SVG diagram of encoder + decoder
- Color-coded: self-attention / cross-attention / masked attention / FFN
- Explains K,V flow from encoder to decoder

### 🏋️ Train Model
- Trains a small Transformer on 30 English→Bengali sentence pairs
- Live loss curve rendered on canvas
- Configurable epochs

### 🔬 Training Step
Shows a **single training forward pass** with teacher forcing:

1. **Tokenization** — English + Bengali → token ID arrays
2. **Embedding** — `token_id → vector × √d_model`
3. **Positional Encoding** — `sin(pos/10000^(2i/d))` / `cos(...)` matrix shown
4. **Encoder**:
   - Q, K, V projection matrices shown
   - `scores = Q·Kᵀ / √d_k` with actual numbers
   - Softmax attention weights (heatmap)
   - Residual + LayerNorm
   - FFN: `max(0, xW₁+b₁)W₂+b₂`
5. **Decoder**:
   - Masked self-attention with causal mask matrix
   - Cross-attention: Q from decoder, K/V from encoder
6. **Loss** — label-smoothed cross-entropy, gradient norms, Adam update
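The formulas in steps 2–5 fit in a few lines of NumPy. This is a minimal sketch of the math the tab visualizes (assuming an even `d_model`), not the Space's actual implementation, which lives in `transformer.py` with CalcLog hooks:

```python
import numpy as np

def positional_encoding(seq_len, d_model):
    """PE[pos, 2i] = sin(pos/10000^(2i/d)), PE[pos, 2i+1] = cos(...)."""
    pos = np.arange(seq_len)[:, None]          # (seq_len, 1)
    i = np.arange(0, d_model, 2)[None, :]      # (1, d_model/2)
    angle = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angle)                # even dims: sine
    pe[:, 1::2] = np.cos(angle)                # odd dims: cosine
    return pe

def attention(Q, K, V, causal=False):
    """softmax(Q·Kᵀ/√d_k)·V — with the decoder's causal mask if asked."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)            # (T_q, T_k)
    if causal:                                  # block attention to the future
        scores = scores + np.triu(np.full_like(scores, -1e9), k=1)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)         # softmax: each row sums to 1
    return w @ V, w

x = positional_encoding(4, 8) + np.random.randn(4, 8) * 0.1
out, w = attention(x, x, x, causal=True)
assert np.allclose(w.sum(axis=-1), 1.0)        # valid distribution per row
assert np.allclose(np.triu(w, k=1), 0.0)       # no weight on future tokens
```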
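Step 6's label smoothing can be illustrated the same way — a sketch, not the `training.py` code. The one-hot target is replaced by `1 − ε` on the true class with `ε` spread uniformly over the remaining vocabulary entries:

```python
import numpy as np

def label_smoothed_ce(logits, target, eps=0.1):
    """Cross-entropy against a smoothed target distribution."""
    V = logits.shape[-1]
    m = logits.max()
    logp = logits - (m + np.log(np.exp(logits - m).sum()))  # log-softmax
    smooth = np.full(V, eps / (V - 1))     # ε spread over non-target classes
    smooth[target] = 1.0 - eps             # most mass stays on the true class
    return -(smooth * logp).sum()

logits = np.array([2.0, 0.5, -1.0, 0.0])
loss_smooth = label_smoothed_ce(logits, target=0, eps=0.1)
loss_hard = label_smoothed_ce(logits, target=0, eps=0.0)  # plain CE
# Smoothing penalizes over-confidence: when the model concentrates on
# the target class, the smoothed loss exceeds the hard one-hot loss.
assert loss_smooth > loss_hard > 0
```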
### ⚡ Inference
Shows **auto-regressive decoding**:

- No ground truth needed
- Tokens generated one at a time
- Top-5 candidates + probabilities at every step
- Cross-attention heatmap: which Bengali token attends to which English word
- Greedy vs Beam Search comparison
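The greedy loop behind these steps can be sketched abstractly; `step_fn` here is a hypothetical stand-in for the model's next-token distribution (the real loop, plus beam search, is in `inference.py`):

```python
import numpy as np

def greedy_decode(step_fn, bos, eos, max_len=20):
    """step_fn(tokens) -> probability vector over the target vocab.
    Greedily append the argmax token until EOS or max_len."""
    tokens = [bos]
    for _ in range(max_len):
        probs = step_fn(tokens)
        nxt = int(np.argmax(probs))    # greedy: single best candidate
        tokens.append(nxt)
        if nxt == eos:
            break
    return tokens

# Toy "model": always predicts the next token id, capped at id 3 (EOS).
toy = lambda toks: np.eye(5)[min(toks[-1] + 1, 3)]
print(greedy_decode(toy, bos=0, eos=3))    # [0, 1, 2, 3]
```

Beam search differs only in keeping the top-k partial hypotheses per step instead of a single argmax, which is what the Greedy vs Beam comparison in this tab contrasts.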

---

## 📁 File Structure

```
app.py           — Gradio UI + HTML/CSS/JS rendering
transformer.py   — Full Transformer with CalcLog hooks
training.py      — Training loop + single-step visualization
inference.py     — Greedy & beam search with logging
vocab.py         — English/Bengali vocabularies + parallel corpus
requirements.txt
```

---

## ⚙️ Model Config

| Parameter  | Value             |
|------------|-------------------|
| d_model    | 64                |
| num_heads  | 4                 |
| num_layers | 2                 |
| d_ff       | 128               |
| vocab (EN) | ~100              |
| vocab (BN) | ~90               |
| Optimizer  | Adam              |
| Loss       | Label-smoothed CE |

---

*Built for educational purposes — every matrix operation is logged and displayed.*
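As a rough sanity check on the config table above, the size of the encoder/decoder stacks can be estimated by hand. This assumes standard projections with biases and two (encoder) or three (decoder) LayerNorms per layer, so the Space's exact count may differ:

```python
# Hyper-parameters from the Model Config table above.
d_model, d_ff, num_layers = 64, 128, 2

attn = 4 * (d_model * d_model + d_model)       # Q, K, V, O projections + biases
ffn = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
norm = 2 * d_model                              # one LayerNorm: scale + bias
enc_layer = attn + ffn + 2 * norm               # self-attn + FFN + 2 norms
dec_layer = 2 * attn + ffn + 3 * norm           # + cross-attention + 3rd norm
total = num_layers * (enc_layer + dec_layer)
print(f"≈{total:,} parameters in the encoder/decoder stacks")  # ≈167,424
```

Embeddings and the output projection add roughly another 18K on top, so the whole model is tiny — small enough to train live on CPU, which is the point of the 🏋️ Train Model tab.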
app.py
ADDED
@@ -0,0 +1,800 @@
"""
app.py
Gradio Space: Interactive Transformer Visualizer — English → Bengali
"""

import gradio as gr
import torch
import json
import os
import numpy as np
from pathlib import Path

from transformer import Transformer, CalcLog
from training import build_model, run_training, visualize_training_step, collate_batch
from inference import visualize_inference
from vocab import get_vocabs, PARALLEL_DATA, PAD_IDX

# ─────────────────────────────────────────────
# Global state
# ─────────────────────────────────────────────
DEVICE = "cpu"
src_v, tgt_v = get_vocabs()
MODEL: Transformer = None
LOSS_HISTORY = []
IS_TRAINED = False


def get_or_init_model():
    global MODEL
    if MODEL is None:
        MODEL = build_model(len(src_v), len(tgt_v), DEVICE)
    return MODEL

# ─────────────────────────────────────────────
# HTML renderer for calc log
# ─────────────────────────────────────────────

def render_matrix_html(val, max_rows=6, max_cols=8):
    """Convert a nested list / scalar to an HTML matrix table."""
    if isinstance(val, (int, float)):
        return f'<span class="scalar-val">{val:.5f}</span>'
    if isinstance(val, dict):
        rows = "".join(
            f'<tr><td class="dict-key">{k}</td><td class="dict-val">{v}</td></tr>'
            for k, v in val.items()
        )
        return f'<table class="dict-table">{rows}</table>'
    if isinstance(val, list):
        # 0-D or scalar list
        if len(val) == 0:
            return "<em>empty</em>"
        # 1-D
        if not isinstance(val[0], list):
            clipped = val[:max_cols * 2]
            cells = "".join(
                f'<td class="mat-cell">{v:.4f}</td>'
                if isinstance(v, float) else f'<td class="mat-cell">{v}</td>'
                for v in clipped
            )
            suffix = f'<td class="mat-more">…+{len(val) - len(clipped)}</td>' if len(val) > len(clipped) else ""
            return f'<table class="matrix-1d"><tr>{cells}{suffix}</tr></table>'
        # 2-D
        rows_html = ""
        display_rows = val[:max_rows]
        for row in display_rows:
            display_cols = row[:max_cols]
            cells = "".join(
                f'<td class="mat-cell" style="--v:{min(max(float(c), -1), 1):.3f}">'
                f'{float(c):.3f}</td>'
                if isinstance(c, (int, float)) else f'<td class="mat-cell">{c}</td>'
                for c in display_cols
            )
            suffix = '<td class="mat-more">…</td>' if len(row) > max_cols else ""
            rows_html += f"<tr>{cells}{suffix}</tr>"
        if len(val) > max_rows:
            rows_html += f'<tr><td colspan="{max_cols + 1}" class="mat-more">…{len(val) - max_rows} more rows</td></tr>'
        return f'<table class="matrix-2d">{rows_html}</table>'
    return f'<code>{str(val)[:200]}</code>'


def calc_log_to_html(steps):
    """Turn CalcLog steps into rich HTML accordion."""
    if not steps:
        return "<p style='color:#888'>No calculation log yet.</p>"

    cards = []
    for i, step in enumerate(steps):
        name = step.get("name", f"step_{i}")
        formula = step.get("formula", "")
        note = step.get("note", "")
        shape = step.get("shape")
        val = step.get("value")

        shape_badge = f'<span class="shape-badge">{shape}</span>' if shape else ""
        formula_html = f'<div class="formula">⟨ {formula} ⟩</div>' if formula else ""
        note_html = f'<div class="step-note">ℹ {note}</div>' if note else ""
        matrix_html = render_matrix_html(val) if val is not None else ""

        # Color category by name prefix
        cat = "default"
        n = name.upper()
        if "EMBED" in n or "TOKEN" in n: cat = "embed"
        elif "PE" in n or "POSITIONAL" in n: cat = "pe"
        elif "SOFTMAX" in n or "ATTN" in n or "_Q" in n or "_K" in n or "_V" in n: cat = "attn"
        elif "FFN" in n or "LINEAR" in n or "RELU" in n: cat = "ffn"
        elif "NORM" in n or "RESIDUAL" in n: cat = "norm"
        elif "LOSS" in n or "GRAD" in n or "OPTIM" in n: cat = "loss"
        elif "INFERENCE" in n or "GREEDY" in n or "BEAM" in n: cat = "infer"
        elif "CROSS" in n: cat = "cross"
        elif "MASK" in n: cat = "mask"

        cards.append(f"""
<div class="calc-card cat-{cat}" data-idx="{i}">
  <div class="calc-header">
    <span class="step-num">#{i+1}</span>
    <span class="step-name cat-label-{cat}">{name.replace('_', ' ')}</span>
    {shape_badge}
    <span class="toggle-arrow">▶</span>
  </div>
  <div class="calc-body" style="display:none">
    {formula_html}
    {note_html}
    <div class="matrix-wrap">{matrix_html}</div>
  </div>
</div>""")

    return "\n".join(cards)

# ─────────────────────────────────────────────
# Attention heatmap HTML
# ─────────────────────────────────────────────

def attention_heatmap_html(weights, row_labels, col_labels, title="Attention"):
    """weights: 2D list [tgt, src]"""
    if not weights:
        return ""
    rows_html = ""
    for i, row in enumerate(weights):
        cells = ""
        for j, w in enumerate(row):
            alpha = min(float(w), 1.0)
            cells += f'<td class="heat-cell" style="--a:{alpha:.3f}" title="{row_labels[i] if i < len(row_labels) else i}→{col_labels[j] if j < len(col_labels) else j}: {alpha:.3f}">{alpha:.2f}</td>'
        lbl = row_labels[i] if i < len(row_labels) else str(i)
        rows_html += f'<tr><td class="heat-label">{lbl}</td>{cells}</tr>'
    header = '<tr><td></td>' + "".join(f'<td class="heat-col-label">{c}</td>' for c in col_labels) + '</tr>'
    return f"""
<div class="heatmap-container">
  <div class="heatmap-title">{title}</div>
  <table class="heatmap">{header}{rows_html}</table>
</div>"""


# ─────────────────────────────────────────────
# Decoding steps HTML
# ─────────────────────────────────────────────

def decode_steps_html(step_logs, src_tokens):
    if not step_logs:
        return ""
    html = '<div class="decode-steps"><div class="decode-title">🔁 Auto-regressive Decoding Steps</div>'
    for s in step_logs:
        step = s.get("step", 0)
        tokens_so_far = s.get("tokens_so_far", [])
        top5 = s.get("top5", [])
        chosen = s.get("chosen_token", "?")
        prob = s.get("chosen_prob", 0)

        bars = ""
        if top5:
            max_p = max(t["prob"] for t in top5) or 1
            for t in top5:
                pct = t["prob"] / max_p * 100
                is_chosen = "chosen" if t["token"] == chosen else ""
                bars += f"""<div class="bar-row {is_chosen}">
  <span class="bar-label">{t['token']}</span>
  <div class="bar" style="width:{pct:.1f}%"></div>
  <span class="bar-prob">{t['prob']:.3f}</span>
</div>"""

        cross_heat = ""
        if s.get("cross_attn") and src_tokens:
            attn_mat = s["cross_attn"]  # [num_heads][T_q][T_src]
            if attn_mat and attn_mat[0]:
                # Take head-0, last decoded position → [T_src] floats
                last_pos_attn = attn_mat[0][-1]  # [T_src]
                last_row = [last_pos_attn]  # [[T_src]] — 2D for heatmap
                cross_heat = attention_heatmap_html(
                    last_row, [chosen], src_tokens,
                    title=f"Cross-Attn: '{chosen}' → English"
                )

        html += f"""
<div class="decode-step">
  <div class="decode-step-header">
    <span class="step-badge">Step {step + 1}</span>
    <span class="step-ctx">Context: {' '.join(tokens_so_far)}</span>
    <span class="step-arrow">→</span>
    <span class="step-chosen">'{chosen}'</span>
    <span class="step-prob">{prob:.3f}</span>
  </div>
  <div class="step-bars">{bars}</div>
  {cross_heat}
</div>"""
    html += "</div>"
    return html

# ─────────────────────────────────────────────
# Architecture SVG
# ─────────────────────────────────────────────

ARCH_SVG = """
<div id="arch-diagram">
<svg viewBox="0 0 820 900" xmlns="http://www.w3.org/2000/svg" style="width:100%;max-width:820px;margin:auto;display:block">
  <defs>
    <marker id="arr" markerWidth="8" markerHeight="8" refX="6" refY="3" orient="auto">
      <path d="M0,0 L0,6 L8,3 z" fill="#64ffda"/>
    </marker>
    <filter id="glow">
      <feGaussianBlur stdDeviation="2" result="blur"/>
      <feMerge><feMergeNode in="blur"/><feMergeNode in="SourceGraphic"/></feMerge>
    </filter>
  </defs>

  <!-- Background -->
  <rect width="820" height="900" fill="#0a0f1e" rx="12"/>

  <!-- Title -->
  <text x="410" y="35" text-anchor="middle" fill="#64ffda" font-size="16" font-family="monospace" font-weight="bold">Transformer Architecture — English → Bengali</text>

  <!-- ── ENCODER (left) ── -->
  <rect x="40" y="60" width="330" height="720" rx="10" fill="#0d1b2a" stroke="#1e4d6b" stroke-width="1.5"/>
  <text x="205" y="90" text-anchor="middle" fill="#4fc3f7" font-size="13" font-weight="bold">ENCODER</text>

  <!-- Input Embedding -->
  <rect x="70" y="110" width="270" height="40" rx="6" fill="#1a3a5c" stroke="#4fc3f7" stroke-width="1.5"/>
  <text x="205" y="135" text-anchor="middle" fill="#e0f7fa" font-size="11">Input Embedding + Positional Encoding</text>

  <!-- Encoder Layer Box -->
  <rect x="60" y="175" width="290" height="340" rx="8" fill="#112233" stroke="#1e4d6b" stroke-width="1" stroke-dasharray="4"/>
  <text x="100" y="198" fill="#607d8b" font-size="10">Encoder Layer × N</text>

  <!-- Multi-Head Self-Attention -->
  <rect x="80" y="210" width="250" height="50" rx="6" fill="#1b3a4b" stroke="#26c6da" stroke-width="1.5"/>
  <text x="205" y="232" text-anchor="middle" fill="#e0f7fa" font-size="11" font-weight="bold">Multi-Head Self-Attention</text>
  <text x="205" y="248" text-anchor="middle" fill="#80deea" font-size="9">Q = K = V = encoder input</text>

  <!-- Add & Norm 1 -->
  <rect x="80" y="278" width="250" height="30" rx="5" fill="#1a2a3a" stroke="#607d8b" stroke-width="1"/>
  <text x="205" y="298" text-anchor="middle" fill="#b0bec5" font-size="10">Add & Norm</text>

  <!-- FFN -->
  <rect x="80" y="328" width="250" height="50" rx="6" fill="#1b3a4b" stroke="#26c6da" stroke-width="1.5"/>
  <text x="205" y="350" text-anchor="middle" fill="#e0f7fa" font-size="11" font-weight="bold">Feed-Forward Network</text>
  <text x="205" y="366" text-anchor="middle" fill="#80deea" font-size="9">FFN(x) = max(0, xW₁+b₁)W₂+b₂</text>

  <!-- Add & Norm 2 -->
  <rect x="80" y="396" width="250" height="30" rx="5" fill="#1a2a3a" stroke="#607d8b" stroke-width="1"/>
  <text x="205" y="416" text-anchor="middle" fill="#b0bec5" font-size="10">Add & Norm</text>

  <!-- Encoder output arrow down -->
  <line x1="205" y1="455" x2="205" y2="550" stroke="#64ffda" stroke-width="1.5" marker-end="url(#arr)"/>
  <text x="215" y="510" fill="#64ffda" font-size="9">K, V to</text>
  <text x="215" y="522" fill="#64ffda" font-size="9">decoder</text>

  <!-- Encoder output box -->
  <rect x="70" y="555" width="270" height="40" rx="6" fill="#0d2b1a" stroke="#00e676" stroke-width="1.5"/>
  <text x="205" y="580" text-anchor="middle" fill="#a5d6a7" font-size="11">Encoder Output (K, V)</text>

  <!-- ── DECODER (right) ── -->
  <rect x="450" y="60" width="330" height="720" rx="10" fill="#1a0d2a" stroke="#4a1b6b" stroke-width="1.5"/>
  <text x="615" y="90" text-anchor="middle" fill="#ce93d8" font-size="13" font-weight="bold">DECODER</text>

  <!-- Target Embedding -->
  <rect x="480" y="110" width="270" height="40" rx="6" fill="#3a1a5c" stroke="#ce93d8" stroke-width="1.5"/>
  <text x="615" y="135" text-anchor="middle" fill="#f3e5f5" font-size="11">Target Embedding + Positional Encoding</text>

  <!-- Decoder Layer Box -->
  <rect x="470" y="175" width="290" height="460" rx="8" fill="#1a1133" stroke="#4a1b6b" stroke-width="1" stroke-dasharray="4"/>
  <text x="510" y="198" fill="#607d8b" font-size="10">Decoder Layer × N</text>

  <!-- Masked MHA -->
  <rect x="490" y="210" width="250" height="50" rx="6" fill="#2b1b3a" stroke="#ab47bc" stroke-width="1.5"/>
  <text x="615" y="232" text-anchor="middle" fill="#f3e5f5" font-size="11" font-weight="bold">Masked Multi-Head Self-Attention</text>
  <text x="615" y="248" text-anchor="middle" fill="#ce93d8" font-size="9">Q = K = V = decoder input (causal mask)</text>

  <!-- Add & Norm D1 -->
  <rect x="490" y="278" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
  <text x="615" y="298" text-anchor="middle" fill="#b0bec5" font-size="10">Add & Norm</text>

  <!-- Cross-Attention -->
  <rect x="490" y="328" width="250" height="60" rx="6" fill="#1b2b4b" stroke="#29b6f6" stroke-width="2" filter="url(#glow)"/>
  <text x="615" y="350" text-anchor="middle" fill="#e1f5fe" font-size="11" font-weight="bold">Cross-Attention</text>
  <text x="615" y="366" text-anchor="middle" fill="#81d4fa" font-size="9">Q = decoder | K, V = encoder</text>
  <text x="615" y="380" text-anchor="middle" fill="#29b6f6" font-size="9" font-weight="bold">← KEY CONNECTION</text>

  <!-- Add & Norm D2 -->
  <rect x="490" y="408" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
  <text x="615" y="428" text-anchor="middle" fill="#b0bec5" font-size="10">Add & Norm</text>

  <!-- FFN Decoder -->
  <rect x="490" y="458" width="250" height="50" rx="6" fill="#2b1b3a" stroke="#ab47bc" stroke-width="1.5"/>
  <text x="615" y="480" text-anchor="middle" fill="#f3e5f5" font-size="11" font-weight="bold">Feed-Forward Network</text>
  <text x="615" y="496" text-anchor="middle" fill="#ce93d8" font-size="9">FFN(x) = max(0, xW₁+b₁)W₂+b₂</text>

  <!-- Add & Norm D3 -->
  <rect x="490" y="526" width="250" height="30" rx="5" fill="#2a1a3a" stroke="#607d8b" stroke-width="1"/>
  <text x="615" y="546" text-anchor="middle" fill="#b0bec5" font-size="10">Add & Norm</text>

  <!-- Output Linear + Softmax -->
  <rect x="480" y="600" width="270" height="40" rx="6" fill="#2b1b0a" stroke="#ffb300" stroke-width="1.5"/>
  <text x="615" y="625" text-anchor="middle" fill="#fff8e1" font-size="11">Linear + Softmax → Bengali Token</text>

  <!-- Cross-attention arrow from encoder to decoder -->
  <path d="M340,590 Q410,480 490,368" stroke="#29b6f6" stroke-width="2" fill="none"
        stroke-dasharray="6,3" marker-end="url(#arr)"/>
  <text x="390" y="500" fill="#29b6f6" font-size="9" transform="rotate(-50,390,500)">K, V flow</text>

  <!-- Input arrow -->
  <line x1="205" y1="840" x2="205" y2="780" stroke="#4fc3f7" stroke-width="1.5" marker-end="url(#arr)"/>
  <text x="205" y="858" text-anchor="middle" fill="#4fc3f7" font-size="11">English Input</text>

  <line x1="615" y1="840" x2="615" y2="660" stroke="#ce93d8" stroke-width="1.5" marker-end="url(#arr)"/>
  <text x="615" y="858" text-anchor="middle" fill="#ce93d8" font-size="11">Bengali Output</text>

  <!-- Legend -->
  <rect x="60" y="870" width="700" height="20" rx="4" fill="#0a1520" stroke="#1e2d3d" stroke-width="1"/>
  <circle cx="80" cy="880" r="4" fill="#26c6da"/><text x="88" y="884" fill="#80deea" font-size="8">Self-Attention</text>
  <circle cx="160" cy="880" r="4" fill="#29b6f6"/><text x="168" y="884" fill="#81d4fa" font-size="8">Cross-Attention</text>
  <circle cx="250" cy="880" r="4" fill="#ab47bc"/><text x="258" y="884" fill="#ce93d8" font-size="8">Masked Attn</text>
  <circle cx="350" cy="880" r="4" fill="#00e676"/><text x="358" y="884" fill="#a5d6a7" font-size="8">Enc→Dec K,V</text>
  <circle cx="450" cy="880" r="4" fill="#ffb300"/><text x="458" y="884" fill="#fff8e1" font-size="8">Output Layer</text>
</svg>
</div>
"""

# ─────────────────────────────────────────────
# CSS + JS
# ─────────────────────────────────────────────

CUSTOM_CSS = """
/* ── fonts ── */
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;600&family=Syne:wght@400;700;800&display=swap');

:root {
  --bg: #07090f;
  --bg2: #0d1120;
  --bg3: #111827;
  --card: #141c2e;
  --border: #1e2d45;
  --accent: #64ffda;
  --accent2: #29b6f6;
  --accent3: #ce93d8;
  --accent4: #ffb300;
  --text: #e2e8f0;
  --muted: #64748b;
  --embed: #4fc3f7;
  --pe: #26c6da;
  --attn: #f06292;
  --ffn: #aed581;
  --norm: #90a4ae;
  --loss: #ef9a9a;
  --infer: #80cbc4;
  --cross: #29b6f6;
  --mask: #ffb300;
}

body, .gradio-container { background: var(--bg) !important; color: var(--text) !important; font-family: 'JetBrains Mono', monospace !important; }

h1, h2, h3 { font-family: 'Syne', sans-serif !important; }

/* ── tabs ── */
.tab-nav button { background: var(--bg3) !important; color: var(--muted) !important; border: 1px solid var(--border) !important; font-family: 'JetBrains Mono', monospace !important; letter-spacing: 1px; }
.tab-nav button.selected { background: var(--card) !important; color: var(--accent) !important; border-color: var(--accent) !important; box-shadow: 0 0 8px rgba(100,255,218,0.2); }

/* ── inputs ── */
input[type=text], textarea { background: var(--bg3) !important; color: var(--text) !important; border: 1px solid var(--border) !important; border-radius: 6px !important; font-family: 'JetBrains Mono', monospace !important; }
input[type=text]:focus, textarea:focus { border-color: var(--accent) !important; box-shadow: 0 0 6px rgba(100,255,218,0.2) !important; }

button.primary { background: linear-gradient(135deg, #0d3d30, #0d3d4d) !important; color: var(--accent) !important; border: 1px solid var(--accent) !important; font-family: 'JetBrains Mono', monospace !important; font-weight: 600 !important; letter-spacing: 1px; transition: all 0.2s; }
button.primary:hover { background: linear-gradient(135deg, #1a5c4a, #1a5c6d) !important; box-shadow: 0 0 12px rgba(100,255,218,0.3) !important; }

/* ── calc cards ── */
.calc-card { border-radius: 8px; margin: 4px 0; border: 1px solid var(--border); background: var(--card); overflow: hidden; }
.calc-header { display: flex; align-items: center; gap: 8px; padding: 8px 12px; cursor: pointer; user-select: none; transition: background 0.15s; }
.calc-header:hover { background: rgba(255,255,255,0.03); }
.calc-body { padding: 10px 14px; background: var(--bg2); border-top: 1px solid var(--border); }
.step-num { color: var(--muted); font-size: 11px; min-width: 28px; }
.step-name { font-weight: 600; font-size: 12px; flex: 1; }
.toggle-arrow { color: var(--muted); font-size: 10px; transition: transform 0.2s; }
.toggle-arrow.open { transform: rotate(90deg); }
.shape-badge { background: var(--bg3); color: var(--muted); font-size: 10px; padding: 1px 6px; border-radius: 4px; border: 1px solid var(--border); }
.formula { color: var(--accent); font-size: 11px; font-style: italic; margin-bottom: 4px; background: rgba(100,255,218,0.05); padding: 4px 8px; border-radius: 4px; border-left: 2px solid var(--accent); }
.step-note { color: var(--muted); font-size: 11px; margin-bottom: 6px; }

/* category colors */
.cat-label-embed { color: var(--embed); }
.cat-label-pe { color: var(--pe); }
.cat-label-attn { color: var(--attn); }
.cat-label-ffn { color: var(--ffn); }
.cat-label-norm { color: var(--norm); }
.cat-label-loss { color: var(--loss); }
.cat-label-infer { color: var(--infer); }
.cat-label-cross { color: var(--cross); }
.cat-label-mask { color: var(--mask); }
.cat-label-default{ color: var(--text); }

.cat-embed { border-left: 3px solid var(--embed); }
.cat-pe { border-left: 3px solid var(--pe); }
.cat-attn { border-left: 3px solid var(--attn); }
.cat-ffn { border-left: 3px solid var(--ffn); }
.cat-norm { border-left: 3px solid var(--norm); }
.cat-loss { border-left: 3px solid var(--loss); }
.cat-infer { border-left: 3px solid var(--infer); }
.cat-cross { border-left: 3px solid var(--cross); }
.cat-mask { border-left: 3px solid var(--mask); }
.cat-default{ border-left: 3px solid var(--border); }

/* ── matrix tables ── */
.matrix-wrap { overflow-x: auto; }
.matrix-2d, .matrix-1d { border-collapse: collapse; font-size: 10px; font-family: 'JetBrains Mono', monospace; }
.mat-cell {
  padding: 2px 5px; text-align: right; min-width: 48px;
  background: color-mix(in srgb, #29b6f6 calc((var(--v,0) + 1) * 30%), #0d1120 calc(100% - (var(--v,0) + 1) * 30%));
  color: #e2e8f0; border: 1px solid rgba(255,255,255,0.05);
}
.mat-more { color: var(--muted); font-style: italic; font-size: 9px; padding: 2px 6px; }
.dict-table { font-size: 11px; width: 100%; }
.dict-key { color: var(--accent); padding: 2px 8px 2px 0; }
.dict-val { color: var(--text); padding: 2px; }
.scalar-val { color: var(--accent4); font-size: 13px; font-weight: 600; }

/* ── heatmap ── */
.heatmap-container { margin: 8px 0; }
.heatmap-title { color: var(--accent2); font-size: 11px; margin-bottom: 4px; font-weight: 600; }
.heatmap { border-collapse: collapse; font-size: 10px; }
.heat-cell {
  width: 36px; height: 24px; text-align: center;
  background: rgba(41, 182, 246, calc(var(--a, 0)));
  border: 1px solid rgba(255,255,255,0.04);
  color: color-mix(in srgb, #fff calc(var(--a,0)*100%), #4a5568 calc(100% - var(--a,0)*100%));
  font-size: 9px; cursor: default;
}
.heat-cell:hover { outline: 1px solid var(--accent); }
.heat-label { color: var(--accent3); font-size: 10px; padding-right: 6px; white-space: nowrap; }
.heat-col-label { color: var(--embed); font-size: 9px; text-align: center; padding-bottom: 2px; }

/* ── decode steps ── */
|
| 452 |
+
.decode-steps { margin-top: 12px; }
|
| 453 |
+
.decode-title { color: var(--accent); font-size: 13px; font-weight: 700; margin-bottom: 10px; padding-bottom: 4px; border-bottom: 1px solid var(--border); }
|
| 454 |
+
.decode-step { border: 1px solid var(--border); border-radius: 8px; margin: 6px 0; padding: 10px; background: var(--card); }
|
| 455 |
+
.decode-step-header { display: flex; align-items: center; gap: 8px; flex-wrap: wrap; margin-bottom: 8px; }
|
| 456 |
+
.step-badge { background: var(--accent); color: var(--bg); font-size: 10px; font-weight: 700; padding: 2px 8px; border-radius: 20px; }
|
| 457 |
+
.step-ctx { color: var(--muted); font-size: 11px; }
|
| 458 |
+
.step-arrow { color: var(--accent4); }
|
| 459 |
+
.step-chosen { color: var(--accent3); font-size: 13px; font-weight: 700; }
|
| 460 |
+
.step-prob { color: var(--accent4); font-size: 11px; }
|
| 461 |
+
.step-bars { margin: 4px 0; }
|
| 462 |
+
.bar-row { display: flex; align-items: center; gap: 6px; margin: 2px 0; }
|
| 463 |
+
.bar-row.chosen .bar-label { color: var(--accent3); font-weight: 700; }
|
| 464 |
+
.bar-row.chosen .bar { background: var(--accent3) !important; }
|
| 465 |
+
.bar-label { width: 100px; text-align: right; font-size: 11px; color: var(--text); white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
|
| 466 |
+
.bar { height: 14px; background: var(--accent2); border-radius: 2px; transition: width 0.4s; min-width: 2px; }
|
| 467 |
+
.bar-prob { font-size: 10px; color: var(--muted); }
|
| 468 |
+
|
| 469 |
+
/* ── loss chart ── */
|
| 470 |
+
#loss-chart-container { background: var(--bg2); border: 1px solid var(--border); border-radius: 8px; padding: 12px; margin-top: 8px; }
|
| 471 |
+
|
| 472 |
+
/* ── arch diagram ── */
|
| 473 |
+
#arch-diagram { background: var(--bg2); border: 1px solid var(--border); border-radius: 10px; padding: 12px; margin: 8px 0; }
|
| 474 |
+
|
| 475 |
+
/* ── result banner ── */
|
| 476 |
+
.result-banner { background: linear-gradient(135deg, #0d3d30, #1a1a3d); border: 1px solid var(--accent); border-radius: 10px; padding: 16px 20px; margin: 10px 0; }
|
| 477 |
+
.result-en { color: var(--embed); font-size: 14px; margin-bottom: 4px; }
|
| 478 |
+
.result-bn { color: var(--accent3); font-size: 22px; font-weight: 700; letter-spacing: 1px; }
|
| 479 |
+
.result-label { color: var(--muted); font-size: 10px; text-transform: uppercase; letter-spacing: 1px; }
|
| 480 |
+
|
| 481 |
+
/* ── misc ── */
|
| 482 |
+
.gradio-html { background: transparent !important; }
|
| 483 |
+
.panel { background: var(--card) !important; border: 1px solid var(--border) !important; border-radius: 10px !important; }
|
| 484 |
+
.log-container { max-height: 600px; overflow-y: auto; padding: 8px; scrollbar-width: thin; scrollbar-color: var(--border) transparent; }
|
| 485 |
+
"""
|
| 486 |
+
|
| 487 |
+
CUSTOM_JS = """
// Card toggle
window._toggleCard = function(header) {
    const body = header.nextElementSibling;
    const arrow = header.querySelector('.toggle-arrow');
    if (!body) return;
    const open = body.style.display === 'block';
    body.style.display = open ? 'none' : 'block';
    if (arrow) arrow.classList.toggle('open', !open);
};
window._expandAll = function() {
    document.querySelectorAll('.calc-body').forEach(b => b.style.display='block');
    document.querySelectorAll('.toggle-arrow').forEach(a => a.classList.add('open'));
};
window._collapseAll = function() {
    document.querySelectorAll('.calc-body').forEach(b => b.style.display='none');
    document.querySelectorAll('.toggle-arrow').forEach(a => a.classList.remove('open'));
};
window._filterCards = function(cat) {
    document.querySelectorAll('.calc-card').forEach(c => {
        c.style.display = (!cat || c.classList.contains('cat-'+cat)) ? '' : 'none';
    });
};
// Event delegation — works even if Gradio strips onclick attrs
document.addEventListener('click', function(e) {
    const header = e.target.closest('.calc-header');
    if (header) { window._toggleCard(header); return; }
    const btn = e.target.closest('[data-ga]');
    if (btn) {
        const a = btn.dataset.ga;
        if (a === 'expand') window._expandAll();
        else if (a === 'collapse') window._collapseAll();
        else if (a.startsWith('filter:')) window._filterCards(a.slice(7));
    }
}, true);
"""

# ─────────────────────────────────────────────
# Pure-SVG loss curve (no JS/canvas needed)
# ─────────────────────────────────────────────

def _loss_svg(losses):
    if not losses:
        return ""
    W, H = 580, 200
    pl, pr, pt, pb = 52, 16, 16, 36
    pw, ph = W - pl - pr, H - pt - pb
    mn, mx = min(losses), max(losses)
    rng = mx - mn or 1
    n = len(losses)

    def px(i): return pl + (i / max(n - 1, 1)) * pw
    def py(v): return pt + ph - ((v - mn) / rng) * ph

    # Grid + Y labels
    grid = ""
    for k in range(5):
        v = mn + (k / 4) * rng
        y = py(v)
        grid += f'<line x1="{pl}" y1="{y:.1f}" x2="{pl+pw}" y2="{y:.1f}" stroke="#1e2d45" stroke-width="0.5"/>'
        grid += f'<text x="{pl-4}" y="{y+4:.1f}" text-anchor="end" fill="#64748b" font-size="9" font-family="monospace">{v:.3f}</text>'

    # Polyline points
    pts = " ".join(f"{px(i):.1f},{py(v):.1f}" for i, v in enumerate(losses))
    fill_pts = f"{pl:.1f},{pt+ph:.1f} {pts} {pl+pw:.1f},{pt+ph:.1f}"

    # X labels
    xlabels = ""
    for idx in ([0, n//4, n//2, 3*n//4, n-1] if n > 4 else range(n)):
        xlabels += f'<text x="{px(idx):.1f}" y="{H-4}" text-anchor="middle" fill="#64748b" font-size="9" font-family="monospace">E{idx+1}</text>'

    return f"""
    <div style="background:#0d1120;border:1px solid #1e2d45;border-radius:8px;padding:12px;margin-top:8px">
      <div style="color:#64ffda;font-size:13px;font-weight:700;margin-bottom:8px">📉 Training Loss Curve</div>
      <svg width="{W}" height="{H}" style="display:block;max-width:100%">
        <defs>
          <linearGradient id="lcg" x1="0" y1="0" x2="{W}" y2="0" gradientUnits="userSpaceOnUse">
            <stop offset="0%" stop-color="#64ffda"/><stop offset="100%" stop-color="#29b6f6"/>
          </linearGradient>
        </defs>
        <rect width="{W}" height="{H}" fill="#0d1120"/>
        {grid}
        <polygon points="{fill_pts}" fill="rgba(100,255,218,0.08)"/>
        <polyline points="{pts}" fill="none" stroke="url(#lcg)" stroke-width="2.5" stroke-linejoin="round"/>
        {xlabels}
      </svg>
    </div>"""
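The key idea in `_loss_svg` is a plain linear mapping from (epoch index, loss value) to SVG pixel coordinates. A minimal standalone sketch of that mapping, using the same canvas size and padding constants as above (illustrative only, not part of the Space's source):

```python
# Loss→pixel mapping as used by _loss_svg (same constants; illustrative sketch).
W, H = 580, 200                       # SVG canvas size
pl, pr, pt, pb = 52, 16, 16, 36       # left/right/top/bottom padding
pw, ph = W - pl - pr, H - pt - pb     # usable plot area

def px(i, n):
    # Epoch index i of n, spread evenly across the plot width.
    return pl + (i / max(n - 1, 1)) * pw

def py(v, mn, mx):
    # Loss value v mapped so mx sits at the top edge, mn at the bottom.
    rng = (mx - mn) or 1
    return pt + ph - ((v - mn) / rng) * ph

losses = [4.0, 2.5, 1.2]
points = [(px(i, len(losses)), py(v, min(losses), max(losses)))
          for i, v in enumerate(losses)]
```

The `max(n - 1, 1)` guard keeps a single-point curve from dividing by zero, and `(mx - mn) or 1` does the same when the loss is flat.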


# ─────────────────────────────────────────────
# Gradio callbacks
# ─────────────────────────────────────────────

def do_train(epochs_str, progress=gr.Progress()):
    global MODEL, LOSS_HISTORY, IS_TRAINED
    try:
        epochs = int(epochs_str)
    except (TypeError, ValueError):
        epochs = 30

    losses = []

    def cb(ep, total, loss):
        losses.append(loss)
        progress(ep / total, desc=f"Epoch {ep}/{total} — loss {loss:.4f}")

    MODEL, LOSS_HISTORY = run_training(epochs=epochs, device=DEVICE, progress_cb=cb)
    IS_TRAINED = True

    chart_html = _loss_svg(LOSS_HISTORY)
    return (
        f"✅ Trained {epochs} epochs. Final loss: {LOSS_HISTORY[-1]:.4f}",
        chart_html,
    )


def do_training_viz(en_sentence, bn_sentence):
    model = get_or_init_model()
    if not en_sentence.strip():
        return "<p style='color:red'>Please enter an English sentence.</p>", "", ""
    if not bn_sentence.strip():
        bn_sentence = "আমি তোমাকে ভালোবাসি"

    result = visualize_training_step(model, en_sentence.strip(), bn_sentence.strip(), DEVICE)

    # Attention heatmap (cross-attn layer 0, head 0)
    meta = result.get("meta", {})
    attn_html = ""
    src_tokens = result.get("src_tokens", [])
    tgt_tokens = result.get("tgt_tokens", [])

    result_banner = f"""
    <div class="result-banner">
      <div class="result-label">English Input</div>
      <div class="result-en">"{en_sentence}"</div>
      <div class="result-label" style="margin-top:8px">Bengali (Teacher-forced)</div>
      <div class="result-bn">{bn_sentence}</div>
      <div style="margin-top:8px;color:var(--loss);font-size:13px">
        📉 Loss: <strong>{result['loss']:.4f}</strong>
      </div>
    </div>"""

    calc_html = f"""
    <div style="margin-bottom:8px;display:flex;gap:6px;flex-wrap:wrap">
      <button data-ga="expand" style="background:var(--card);color:var(--accent);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Expand All</button>
      <button data-ga="collapse" style="background:var(--card);color:var(--muted);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Collapse All</button>
      {"".join(f'<button data-ga="filter:{cat}" style="background:var(--card);color:var(--cat-{cat},var(--text));border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:10px">{cat}</button>' for cat in ['embed','pe','attn','ffn','norm','loss','cross','mask'])}
      <button data-ga="filter:" style="background:var(--card);color:var(--muted);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:10px">show all</button>
    </div>
    <div class="log-container">
      {calc_log_to_html(result.get('calc_log', []))}
    </div>"""

    return result_banner, calc_html, attn_html


def do_inference_viz(en_sentence, decode_method):
    model = get_or_init_model()
    if not en_sentence.strip():
        return "<p style='color:red'>Please enter an English sentence.</p>", "", ""

    result = visualize_inference(model, en_sentence.strip(), DEVICE, decode_method)

    result_banner = f"""
    <div class="result-banner">
      <div class="result-label">English Input</div>
      <div class="result-en">"{en_sentence}"</div>
      <div class="result-label" style="margin-top:8px">Bengali Translation ({decode_method})</div>
      <div class="result-bn">{result['translation'] or '(no output)'}</div>
      <div style="margin-top:6px;color:var(--muted);font-size:11px">
        Tokens: {' → '.join(result['output_tokens'])}
      </div>
    </div>"""

    decode_html = decode_steps_html(result.get("step_logs", []), result.get("src_tokens", []))

    calc_html = f"""
    <div style="margin-bottom:8px;display:flex;gap:6px;flex-wrap:wrap">
      <button data-ga="expand" style="background:var(--card);color:var(--accent);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Expand All</button>
      <button data-ga="collapse" style="background:var(--card);color:var(--muted);border:1px solid var(--border);padding:3px 10px;border-radius:4px;cursor:pointer;font-size:11px">Collapse All</button>
    </div>
    <div class="log-container">
      {calc_log_to_html(result.get('calc_log', []))}
    </div>"""

    return result_banner + decode_html, calc_html, ""


# ─────────────────────────────────────────────
# Build UI
# ─────────────────────────────────────────────

def build_ui():
    _theme = gr.themes.Base(primary_hue="teal", secondary_hue="purple", neutral_hue="slate")
    with gr.Blocks(
        title="Transformer Visualizer — EN→BN",
        css=CUSTOM_CSS,
        js=f"() => {{ {CUSTOM_JS} }}",
        theme=_theme,
    ) as demo:

        gr.HTML("""
        <div style="text-align:center;padding:24px 0 12px;border-bottom:1px solid #1e2d45;margin-bottom:16px">
          <div style="font-family:'Syne',sans-serif;font-size:28px;font-weight:800;
                      background:linear-gradient(135deg,#64ffda,#29b6f6,#ce93d8);
                      -webkit-background-clip:text;-webkit-text-fill-color:transparent;letter-spacing:2px">
            TRANSFORMER VISUALIZER
          </div>
          <div style="color:#64748b;font-size:12px;letter-spacing:3px;margin-top:4px;font-family:'JetBrains Mono',monospace">
            ENGLISH → BENGALI · EVERY CALCULATION EXPOSED
          </div>
        </div>
        """)

        with gr.Tabs():

            # ── TAB 0: Architecture ──────────────────
            with gr.Tab("🏗️ Architecture"):
                gr.HTML(ARCH_SVG)
                gr.HTML("""
                <div style="display:grid;grid-template-columns:1fr 1fr;gap:12px;margin-top:12px">
                  <div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:14px">
                    <div style="color:#4fc3f7;font-weight:700;margin-bottom:8px">📌 Encoder Flow</div>
                    <div style="color:#94a3b8;font-size:12px;line-height:1.8">
                      1. English tokens → Embedding (d_model=64)<br>
                      2. + Positional Encoding (sin/cos)<br>
                      3. Multi-Head Self-Attention (4 heads)<br>
                      4. Add &amp; LayerNorm<br>
                      5. Feed-Forward (64→128→64)<br>
                      6. Add &amp; LayerNorm<br>
                      7. Repeat × 2 layers<br>
                      8. Output K, V for decoder
                    </div>
                  </div>
                  <div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:14px">
                    <div style="color:#ce93d8;font-weight:700;margin-bottom:8px">📌 Decoder Flow</div>
                    <div style="color:#94a3b8;font-size:12px;line-height:1.8">
                      1. Bengali tokens → Embedding<br>
                      2. + Positional Encoding<br>
                      3. Masked MHA (future tokens blocked)<br>
                      4. Add &amp; LayerNorm<br>
                      5. Cross-Attention: Q←decoder, K,V←encoder<br>
                      6. Add &amp; LayerNorm<br>
                      7. Feed-Forward<br>
                      8. Linear → Softmax → Bengali token
                    </div>
                  </div>
                </div>
                """)

            # ── TAB 1: Train ─────────────────────────
            with gr.Tab("🏋️ Train Model"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.HTML('<div style="color:#64ffda;font-size:13px;font-weight:700;margin-bottom:8px">Quick Train</div>')
                        epochs_in = gr.Textbox(value="50", label="Epochs", max_lines=1)
                        train_btn = gr.Button("▶ Train on 30 parallel sentences", variant="primary")
                        train_status = gr.HTML()
                    with gr.Column(scale=2):
                        loss_chart = gr.HTML()
                train_btn.click(do_train, inputs=[epochs_in], outputs=[train_status, loss_chart])

            # ── TAB 2: Training Step Viz ──────────────
            with gr.Tab("🔬 Training Step"):
                gr.HTML('<div style="color:#ef9a9a;font-size:12px;margin-bottom:12px">📚 Shows <strong>teacher-forcing</strong>: ground-truth Bengali tokens are fed to decoder, loss + gradients computed.</div>')
                with gr.Row():
                    en_in_t = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you")
                    bn_in_t = gr.Textbox(label="Bengali (ground truth)", placeholder="আমি তোমাকে ভালোবাসি", value="আমি তোমাকে ভালোবাসি")
                run_train_viz = gr.Button("🔬 Run Training Step & Show All Calculations", variant="primary")
                result_html_t = gr.HTML()
                with gr.Row():
                    with gr.Column(scale=2):
                        calc_html_t = gr.HTML()
                    with gr.Column(scale=1):
                        attn_html_t = gr.HTML()
                run_train_viz.click(do_training_viz,
                                    inputs=[en_in_t, bn_in_t],
                                    outputs=[result_html_t, calc_html_t, attn_html_t])

            # ── TAB 3: Inference Viz ──────────────────
            with gr.Tab("⚡ Inference"):
                gr.HTML('<div style="color:#80cbc4;font-size:12px;margin-bottom:12px">🤖 Shows <strong>auto-regressive decoding</strong>: model generates Bengali token by token, no ground truth needed.</div>')
                with gr.Row():
                    en_in_i = gr.Textbox(label="English Sentence", placeholder="i love you", value="i love you")
                    decode_radio = gr.Radio(["greedy", "beam"], value="greedy", label="Decode Method")
                run_infer = gr.Button("⚡ Translate & Show All Calculations", variant="primary")
                result_html_i = gr.HTML()
                with gr.Row():
                    with gr.Column(scale=2):
                        calc_html_i = gr.HTML()
                    with gr.Column(scale=1):
                        attn_html_i = gr.HTML()
                run_infer.click(do_inference_viz,
                                inputs=[en_in_i, decode_radio],
                                outputs=[result_html_i, calc_html_i, attn_html_i])

            # ── TAB 4: Examples ──────────────────────
            with gr.Tab("📖 Examples"):
                gr.HTML("""
                <div style="background:#141c2e;border:1px solid #1e2d45;border-radius:8px;padding:16px">
                  <div style="color:#64ffda;font-weight:700;margin-bottom:12px">Try these sentences:</div>
                  <div style="display:grid;grid-template-columns:1fr 1fr;gap:8px">
                """ + "".join(
                    f'<div style="background:#0d1120;border:1px solid #1e2d45;border-radius:6px;padding:8px">'
                    f'<div style="color:#4fc3f7;font-size:12px">{en}</div>'
                    f'<div style="color:#ce93d8;font-size:13px;font-weight:600">{bn}</div>'
                    f'</div>'
                    for en, bn in PARALLEL_DATA[:12]
                ) + "</div></div>")

    return demo


demo = build_ui()
demo.launch(server_name="0.0.0.0")
inference.py
ADDED
@@ -0,0 +1,250 @@
"""
inference.py
Inference (translation) for English→Bengali with full calculation logging.
Supports greedy decoding and beam search, showing every step.
"""

import torch
import torch.nn.functional as F
import numpy as np
import math
from typing import Dict, List, Tuple, Optional

from transformer import Transformer, CalcLog
from vocab import get_vocabs, PAD_IDX, BOS_IDX, EOS_IDX


# ─────────────────────────────────────────────
# Greedy decoding with full logging
# ─────────────────────────────────────────────

def greedy_decode(
    model: Transformer,
    src: torch.Tensor,
    max_len: int = 20,
    device: str = "cpu",
    log: Optional[CalcLog] = None,
) -> Tuple[List[int], List[Dict]]:
    model.eval()
    src_v, tgt_v = get_vocabs()

    with torch.no_grad():
        src_mask = model.make_src_mask(src)

        # ── Encode once ──────────────────────
        src_emb = model.src_embed(src) * math.sqrt(model.d_model)
        enc_x = model.src_pe(src_emb, log=log)

        enc_attn_weights = []
        for i, layer in enumerate(model.encoder_layers):
            enc_x, ew = layer(enc_x, src_mask=src_mask,
                              log=log if i == 0 else None, layer_idx=i)
            enc_attn_weights.append(ew.cpu().numpy())

        if log:
            log.log("INFERENCE_ENCODER_done", enc_x[0, :, :8],
                    note="Encoder finished. Output K,V will be reused for every decoder step.")

        # ── Auto-regressive decode ────────────
        generated = [BOS_IDX]
        step_logs = []

        for step in range(max_len):
            tgt_so_far = torch.tensor([generated], dtype=torch.long, device=device)
            tgt_mask = model.make_tgt_mask(tgt_so_far)

            tgt_emb = model.tgt_embed(tgt_so_far) * math.sqrt(model.d_model)
            dec_x = model.tgt_pe(tgt_emb)

            step_dec_cross = []
            for i, layer in enumerate(model.decoder_layers):
                do_log = (log is not None) and (step < 3) and (i == 0)
                if do_log:
                    log.log(f"INFERENCE_step{step}_dec_input", dec_x[0, :, :8],
                            note=f"Decoder input at step {step}: tokens so far = "
                                 f"{tgt_v.tokens(generated)}")
                dec_x, mw, cw = layer(
                    dec_x, enc_x,
                    tgt_mask=tgt_mask, src_mask=src_mask,
                    log=log if do_log else None,
                    layer_idx=i,
                )
                step_dec_cross.append(cw.cpu().numpy())

            # Only look at last position
            last_logits = model.output_linear(dec_x[:, -1, :])  # (1, V)
            probs = F.softmax(last_logits, dim=-1)[0]

            # Top-5 predictions
            top5_probs, top5_ids = probs.topk(5)
            top5 = [
                {"token": tgt_v.idx2token.get(idx.item(), "?"),
                 "id": idx.item(),
                 "prob": round(prob.item(), 4)}
                for prob, idx in zip(top5_probs, top5_ids)
            ]

            # Greedy: pick highest
            next_token = top5_ids[0].item()

            step_info = {
                "step": step,
                "tokens_so_far": tgt_v.tokens(generated),
                "top5": top5,
                "chosen_token": tgt_v.idx2token.get(next_token, "?"),
                "chosen_id": next_token,
                "chosen_prob": round(top5_probs[0].item(), 4),
                "cross_attn": step_dec_cross[0][0].tolist()
                              if step_dec_cross else None,
            }
            step_logs.append(step_info)

            if log and step < 3:
                log.log(f"INFERENCE_step{step}_top5", top5,
                        formula="P(next_token) = softmax(W_out · dec_out[-1])",
                        note=f"Step {step}: top-5 candidates. Chosen: {step_info['chosen_token']} ({step_info['chosen_prob']:.4f})")

            generated.append(next_token)
            if next_token == EOS_IDX:
                break

    return generated, step_logs


# ─────────────────────────────────────────────
# Beam search
# ─────────────────────────────────────────────

def beam_search(
    model: Transformer,
    src: torch.Tensor,
    beam_size: int = 3,
    max_len: int = 20,
    device: str = "cpu",
    log: Optional[CalcLog] = None,
) -> Tuple[List[int], List[Dict]]:
    model.eval()
    src_v, tgt_v = get_vocabs()

    with torch.no_grad():
        src_mask = model.make_src_mask(src)

        # Encode
        src_emb = model.src_embed(src) * math.sqrt(model.d_model)
        enc_x = model.src_pe(src_emb)
        for i, layer in enumerate(model.encoder_layers):
            enc_x, _ = layer(enc_x, src_mask=src_mask)

        # Beams: list of (score, token_ids)
        beams = [(0.0, [BOS_IDX])]
        completed = []
        beam_step_logs = []

        for step in range(max_len):
            if not beams:
                break
            candidates = []

            for beam_idx, (score, tokens) in enumerate(beams):
                tgt_t = torch.tensor([tokens], dtype=torch.long, device=device)
                tgt_mask = model.make_tgt_mask(tgt_t)
                tgt_emb = model.tgt_embed(tgt_t) * math.sqrt(model.d_model)
                dec_x = model.tgt_pe(tgt_emb)
                for i, layer in enumerate(model.decoder_layers):
                    dec_x, _, _ = layer(dec_x, enc_x,
                                        tgt_mask=tgt_mask, src_mask=src_mask)
                last_logits = model.output_linear(dec_x[:, -1, :])
                log_probs = F.log_softmax(last_logits, dim=-1)[0]
                top_lp, top_id = log_probs.topk(beam_size)
                for lp, tid in zip(top_lp, top_id):
                    new_score = score + lp.item()
                    new_tokens = tokens + [tid.item()]
                    candidates.append((new_score, new_tokens))

            # Keep top beam_size
            candidates.sort(key=lambda x: x[0], reverse=True)
            beams = []
            step_info = {"step": step, "beams": []}
            for sc, toks in candidates[:beam_size * 2]:
                if toks[-1] == EOS_IDX:
                    completed.append((sc / len(toks), toks))
                else:
                    beams.append((sc, toks))
                step_info["beams"].append({
                    "score": round(sc, 4),
                    "tokens": tgt_v.tokens(toks),
                    "text": tgt_v.decode(toks),
                })
                if len(beams) == beam_size:
                    break
            beam_step_logs.append(step_info)

            if len(completed) >= beam_size:
                break

    if completed:
        best = max(completed, key=lambda x: x[0])
        return best[1], beam_step_logs
    elif beams:
        return beams[0][1] + [EOS_IDX], beam_step_logs
    else:
        return [BOS_IDX, EOS_IDX], beam_step_logs


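Why keep multiple hypotheses at all? Because greedy decoding commits to the locally best token and can miss a sequence whose total probability is higher. A simplified self-contained sketch (no length normalization or candidate cap, unlike `beam_search` above; the token table is hypothetical) where the beam's runner-up first token "b" leads to a better full sequence than greedy's first pick "a":

```python
import math

BOS, EOS = "<bos>", "<eos>"

# Hypothetical next-token distributions. Greedy takes "a" (p=0.6), reaching
# "a x <eos>" with total prob 0.36; the best sequence is "b <eos>" at 0.38.
TABLE = {
    (BOS,):          {"a": 0.6, "b": 0.4},
    (BOS, "a"):      {"x": 0.6, EOS: 0.4},
    (BOS, "a", "x"): {EOS: 1.0},
    (BOS, "b"):      {EOS: 0.95, "x": 0.05},
    (BOS, "b", "x"): {EOS: 1.0},
}

def toy_beam(beam_size=2, max_len=5):
    beams = [(0.0, [BOS])]            # (summed log-prob, tokens)
    completed = []
    for _ in range(max_len):
        candidates = []
        for score, toks in beams:     # expand every live hypothesis
            for tok, p in TABLE[tuple(toks)].items():
                candidates.append((score + math.log(p), toks + [tok]))
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for sc, toks in candidates:
            if toks[-1] == EOS:
                completed.append((sc, toks))      # finished hypothesis
            elif len(beams) < beam_size:
                beams.append((sc, toks))          # keep top-k live beams
        if not beams:
            break
    return max(completed, key=lambda c: c[0])[1]  # best finished sequence
```

With `beam_size=2` the hypothesis through "b" survives the first step and wins, which is exactly the failure mode greedy decoding cannot recover from.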
# ─────────────────────────────────────────────
# Full inference pipeline with visualization
# ─────────────────────────────────────────────

def visualize_inference(
    model: Transformer,
    en_sentence: str,
    device: str = "cpu",
    decode_method: str = "greedy",
) -> Dict:
    src_v, tgt_v = get_vocabs()
    log = CalcLog()

    src_ids = src_v.encode(en_sentence)
    log.log("INFERENCE_TOKENIZATION", {
        "sentence": en_sentence,
        "tokens": en_sentence.lower().split(),
        "ids": src_ids,
    }, formula="word → vocab_id lookup",
       note="No ground-truth Bengali needed — model generates from scratch")

    src = torch.tensor([src_ids], dtype=torch.long, device=device)

    if decode_method == "beam":
        output_ids, step_logs = beam_search(model, src, beam_size=3,
                                            device=device, log=log)
        log.log("BEAM_SEARCH_complete", {
            "method": "beam search (beam=3)",
            "note": "Explores multiple hypotheses simultaneously — generally better quality"
        })
    else:
        output_ids, step_logs = greedy_decode(model, src, device=device, log=log)
        log.log("GREEDY_complete", {
            "method": "greedy decoding",
            "note": "Always picks highest probability token — fast but can miss optimal sequences"
        })

    translation = tgt_v.decode(output_ids)
    output_tokens = tgt_v.tokens(output_ids)

    log.log("FINAL_TRANSLATION", {
        "input": en_sentence,
        "output_ids": output_ids,
        "output_tokens": output_tokens,
        "translation": translation,
    }, note="Complete English→Bengali translation")

    return {
        "en_sentence": en_sentence,
        "translation": translation,
        "output_tokens": output_tokens,
        "output_ids": output_ids,
        "src_tokens": src_v.tokens(src_ids),
        "step_logs": step_logs,
        "calc_log": log.to_dict(),
        "decode_method": decode_method,
    }
requirements.txt
ADDED
@@ -0,0 +1,2 @@
torch>=2.0.0
numpy>=1.24.0
**training.py** (ADDED, `@@ -0,0 +1,261 @@`)
```python
"""
training.py
Training loop for English→Bengali transformer with full calculation capture.
"""

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
from typing import Dict, List, Tuple, Optional

from transformer import Transformer, CalcLog
from vocab import get_vocabs, PARALLEL_DATA, PAD_IDX, BOS_IDX, EOS_IDX


# ─────────────────────────────────────────────
# Data helpers
# ─────────────────────────────────────────────

def collate_batch(pairs: List[Tuple[str, str]], src_v, tgt_v, device: str = "cpu"):
    src_seqs, tgt_seqs = [], []
    for en, bn in pairs:
        src_seqs.append(src_v.encode(en))
        tgt_seqs.append(tgt_v.encode(bn))

    def pad(seqs):
        max_len = max(len(s) for s in seqs)
        padded = [s + [PAD_IDX] * (max_len - len(s)) for s in seqs]
        return torch.tensor(padded, dtype=torch.long, device=device)

    return pad(src_seqs), pad(tgt_seqs)


# ─────────────────────────────────────────────
# Label-smoothed cross-entropy
# ─────────────────────────────────────────────

class LabelSmoothingLoss(nn.Module):
    def __init__(self, vocab_size: int, pad_idx: int, smoothing: float = 0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.pad_idx = pad_idx
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, logits: torch.Tensor, target: torch.Tensor,
                log: Optional[CalcLog] = None) -> torch.Tensor:
        B, T, V = logits.shape
        logits_flat = logits.reshape(-1, V)
        target_flat = target.reshape(-1)

        log_probs = torch.log_softmax(logits_flat, dim=-1)

        with torch.no_grad():
            smooth_dist = torch.full_like(log_probs, self.smoothing / (V - 2))
            smooth_dist.scatter_(1, target_flat.unsqueeze(1), self.confidence)
            smooth_dist[:, self.pad_idx] = 0
            mask = (target_flat == self.pad_idx)
            smooth_dist[mask] = 0

        loss = -(smooth_dist * log_probs).sum(dim=-1)
        non_pad = (~mask).sum()
        loss = loss.sum() / non_pad.clamp(min=1)

        if log:
            probs_sample = torch.exp(log_probs[:4])
            log.log("LOSS_log_probs_sample", probs_sample,
                    formula="log P(token) = log_softmax(logits)",
                    note="Softmax probabilities for first 4 target positions")
            log.log("LOSS_smooth_dist_sample", smooth_dist[:4],
                    formula=f"smooth: correct={self.confidence:.2f}, others={self.smoothing/(V-2):.5f}",
                    note="Label-smoothed target distribution")
            log.log("LOSS_value", loss.item(),
                    formula="L = -Σ smooth_dist · log_probs / n_tokens",
                    note=f"Label-smoothed cross-entropy loss = {loss.item():.4f}")

        return loss


# ─────────────────────────────────────────────
# Build model
# ─────────────────────────────────────────────

def build_model(src_vocab_size: int, tgt_vocab_size: int,
                device: str = "cpu") -> Transformer:
    model = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=64,
        num_heads=4,
        num_layers=2,
        d_ff=128,
        max_len=32,
        dropout=0.1,
        pad_idx=PAD_IDX,
    ).to(device)
    return model


# ─────────────────────────────────────────────
# Single training step (with full logging)
# ─────────────────────────────────────────────

def training_step(
    model: Transformer,
    src: torch.Tensor,
    tgt: torch.Tensor,
    criterion: LabelSmoothingLoss,
    optimizer: optim.Optimizer,
    log: CalcLog,
    step_num: int = 0,
) -> Dict:
    model.train()
    log.clear()

    # Teacher forcing: decoder input = [BOS, token_1, ..., token_{T-1}]
    tgt_input = tgt[:, :-1]
    tgt_target = tgt[:, 1:]

    log.log("TRAINING_SETUP", {
        "mode": "TRAINING",
        "step": step_num,
        "src_shape": list(src.shape),
        "tgt_input_shape": list(tgt_input.shape),
        "tgt_target_shape": list(tgt_target.shape),
    }, formula="Teacher Forcing: feed ground-truth Bengali tokens as decoder input",
       note="During training, decoder sees actual Bengali tokens (not its own predictions)")

    log.log("SRC_sentence_ids", src[0].tolist(),
            note="Source (English) token IDs fed to encoder")
    log.log("TGT_input_ids", tgt_input[0].tolist(),
            note="Target input to decoder (shifted right — starts with <BOS>)")
    log.log("TGT_target_ids", tgt_target[0].tolist(),
            note="What decoder must predict (shifted left — ends with <EOS>)")

    # Forward
    logits, meta = model(src, tgt_input, log=log)

    # Loss
    loss = criterion(logits, tgt_target, log=log)

    log.log("LOSS_final", loss.item(),
            formula="Total loss = label-smoothed cross-entropy averaged over tokens",
            note=f"Loss = {loss.item():.4f} (lower = better prediction)")

    # Backward
    optimizer.zero_grad()
    loss.backward()

    # Gradient stats
    grad_norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            gn = param.grad.norm().item()
            grad_norms[name] = round(gn, 6)

    log.log("GRADIENTS_norm_sample", dict(list(grad_norms.items())[:8]),
            formula="∂L/∂W via backpropagation (chain rule)",
            note="Gradient norms for first 8 parameter tensors")

    # Gradient clipping
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    log.log("OPTIMIZER_step", {
        "algorithm": "Adam",
        "lr": optimizer.param_groups[0]["lr"],
        "note": "W = W - lr × (m̂ / (√v̂ + ε)) (Adam update rule)",
    }, formula="Adam: adaptive learning rate with momentum",
       note="Weights updated — model slightly improved")

    return {
        "loss": loss.item(),
        "calc_log": log.to_dict(),
        "meta": {k: v.tolist() if hasattr(v, "tolist") else v for k, v in meta.items()
                 if k != "enc_attn"},
    }


# ─────────────────────────────────────────────
# Full training run (quick demo)
# ─────────────────────────────────────────────

def run_training(
    epochs: int = 30,
    device: str = "cpu",
    progress_cb=None,
) -> Tuple[Transformer, List[float]]:
    src_v, tgt_v = get_vocabs()
    model = build_model(len(src_v), len(tgt_v), device)
    criterion = LabelSmoothingLoss(len(tgt_v), PAD_IDX, smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

    losses = []
    src_batch, tgt_batch = collate_batch(PARALLEL_DATA, src_v, tgt_v, device)

    for epoch in range(1, epochs + 1):
        model.train()
        tgt_input = tgt_batch[:, :-1]
        tgt_target = tgt_batch[:, 1:]
        logits, _ = model(src_batch, tgt_input, log=None)
        loss = criterion(logits, tgt_target)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step(loss.item())
        losses.append(loss.item())
        if progress_cb:
            progress_cb(epoch, epochs, loss.item())

    return model, losses


# ─────────────────────────────────────────────
# Single-sample step for visualization
# ─────────────────────────────────────────────

def visualize_training_step(
    model: Transformer,
    en_sentence: str,
    bn_sentence: str,
    device: str = "cpu",
) -> Dict:
    src_v, tgt_v = get_vocabs()
    log = CalcLog()

    src_ids = src_v.encode(en_sentence)
    tgt_ids = tgt_v.encode(bn_sentence)

    log.log("TOKENIZATION_EN", {
        "sentence": en_sentence,
        "tokens": en_sentence.lower().split(),
        "ids": src_ids,
        "vocab_size": len(src_v),
    }, formula="token_id = vocab[word]",
       note="English → token IDs (BOS prepended, EOS appended)")

    log.log("TOKENIZATION_BN", {
        "sentence": bn_sentence,
        "tokens": bn_sentence.split(),
        "ids": tgt_ids,
        "vocab_size": len(tgt_v),
    }, note="Bengali → token IDs (teacher-forced during training)")

    src = torch.tensor([src_ids], dtype=torch.long, device=device)
    tgt = torch.tensor([tgt_ids], dtype=torch.long, device=device)

    criterion = LabelSmoothingLoss(len(tgt_v), PAD_IDX)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    result = training_step(model, src, tgt, criterion, optimizer, log)

    src_v_obj, tgt_v_obj = get_vocabs()
    result["src_tokens"] = src_v_obj.tokens(src_ids)
    result["tgt_tokens"] = tgt_v_obj.tokens(tgt_ids)
    result["en_sentence"] = en_sentence
    result["bn_sentence"] = bn_sentence
    return result
```
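The smoothed target distribution built inside `LabelSmoothingLoss` can be sanity-checked in isolation: every class starts at ε/(V−2), the gold class is overwritten with 1−ε, the PAD column is zeroed, and whole rows whose target is PAD are zeroed. A dependency-free sketch of just that construction (plain lists instead of tensors):

```python
# Mirror of the smooth_dist construction in LabelSmoothingLoss, in pure Python.
def smooth_targets(targets, vocab_size, pad_idx=0, eps=0.1):
    rows = []
    for t in targets:
        if t == pad_idx:
            rows.append([0.0] * vocab_size)   # padding positions contribute no loss
            continue
        row = [eps / (vocab_size - 2)] * vocab_size
        row[t] = 1.0 - eps                    # gold class keeps confidence 1 - eps
        row[pad_idx] = 0.0                    # never assign mass to PAD
        rows.append(row)
    return rows

d = smooth_targets([3, 7, 0], vocab_size=10)
print(round(sum(d[0]), 6))  # → 1.0  (1-eps for the gold class + (V-2)·eps/(V-2))
print(sum(d[2]))            # → 0.0  (target was PAD, row zeroed)
```

This is why the denominator is `V - 2`: one slot is taken by the gold class and one by PAD, leaving V−2 classes to share the smoothing mass, so each real-token row still sums to one.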
**transformer.py** (ADDED, `@@ -0,0 +1,516 @@`)
```python
"""
transformer.py
Full Transformer implementation for English → Bengali translation
with complete calculation tracking at every step.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Optional, Tuple, Dict, List


# ─────────────────────────────────────────────
# Calculation Logger
# ─────────────────────────────────────────────

class CalcLog:
    """Captures every intermediate tensor for visualization."""
    def __init__(self):
        self.steps: List[Dict] = []

    def log(self, name: str, data, formula: str = "", note: str = ""):
        entry = {
            "name": name,
            "formula": formula,
            "note": note,
            "shape": None,
            "value": None,
        }
        if isinstance(data, torch.Tensor):
            entry["shape"] = list(data.shape)
            entry["value"] = data.detach().cpu().numpy().tolist()
        elif isinstance(data, np.ndarray):
            entry["shape"] = list(data.shape)
            entry["value"] = data.tolist()
        else:
            entry["value"] = data
        self.steps.append(entry)
        return data

    def clear(self):
        self.steps = []

    def to_dict(self):
        return self.steps


# ─────────────────────────────────────────────
# Positional Encoding
# ─────────────────────────────────────────────

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor, log: Optional[CalcLog] = None) -> torch.Tensor:
        seq_len = x.size(1)
        pe_slice = self.pe[:, :seq_len, :]

        if log:
            log.log("PE_matrix", pe_slice[0, :seq_len, :8],
                    formula="PE(pos,2i)=sin(pos/10000^(2i/d)), PE(pos,2i+1)=cos(...)",
                    note=f"Showing first 8 dims for {seq_len} positions")
            log.log("Embedding_before_PE", x[0, :, :8],
                    note="Token embeddings (first 8 dims)")

        x = x + pe_slice
        if log:
            log.log("Embedding_after_PE", x[0, :, :8],
                    formula="X = Embedding + PE",
                    note="After adding positional encoding")
        return self.dropout(x)


# ─────────────────────────────────────────────
# Scaled Dot-Product Attention
# ─────────────────────────────────────────────

def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    log: Optional[CalcLog] = None,
    head_idx: int = 0,
    layer_idx: int = 0,
    attn_type: str = "self",
) -> Tuple[torch.Tensor, torch.Tensor]:
    d_k = Q.size(-1)
    prefix = f"L{layer_idx}_H{head_idx}_{attn_type}"

    # Raw scores
    scores = torch.matmul(Q, K.transpose(-2, -1))
    if log:
        log.log(f"{prefix}_Q", Q[0],
                formula="Q = X · Wq",
                note=f"Query matrix head {head_idx}")
        log.log(f"{prefix}_K", K[0],
                formula="K = X · Wk",
                note=f"Key matrix head {head_idx}")
        log.log(f"{prefix}_V", V[0],
                formula="V = X · Wv",
                note=f"Value matrix head {head_idx}")
        log.log(f"{prefix}_QKt", scores[0],
                formula="scores = Q · Kᵀ",
                note="Raw attention scores (before scaling)")

    # Scale
    scale = math.sqrt(d_k)
    scores = scores / scale
    if log:
        log.log(f"{prefix}_QKt_scaled", scores[0],
                formula=f"scores = Q·Kᵀ / √{d_k} = Q·Kᵀ / {scale:.3f}",
                note="Scaled scores — prevents vanishing gradients")

    # Mask
    # masks arrive as (B,1,1,T) or (B,1,T,T) from make_src/tgt_mask;
    # scores here are 3-D (B,T_q,T_k) because we loop per-head,
    # so squeeze the head dim to avoid (B,B,...) broadcasting.
    if mask is not None:
        if mask.dim() == 4:
            mask = mask.squeeze(1)  # (B,1,T,T) or (B,1,1,T) → (B,T,T) or (B,1,T)
        scores = scores.masked_fill(mask == 0, float("-inf"))
        if log:
            log.log(f"{prefix}_mask", mask[0].float(),
                    formula="mask[i,j]=0 → score=-inf (future token blocked)",
                    note="Causal mask (training decoder) or padding mask")
            log.log(f"{prefix}_scores_masked", scores[0],
                    note="Scores after masking (-inf will become 0 after softmax)")

    # Softmax
    attn_weights = F.softmax(scores, dim=-1)
    # replace nan from -inf rows with 0 (edge case)
    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
    if log:
        log.log(f"{prefix}_softmax", attn_weights[0],
                formula="α = softmax(scores, dim=-1)",
                note="Attention weights — each row sums to 1.0")

    # Weighted sum
    output = torch.matmul(attn_weights, V)
    if log:
        log.log(f"{prefix}_output", output[0],
                formula="Attention = α · V",
                note="Weighted sum of values")

    return output, attn_weights


# ─────────────────────────────────────────────
# Multi-Head Attention
# ─────────────────────────────────────────────

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        # → (B, num_heads, T, d_k)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        log: Optional[CalcLog] = None,
        layer_idx: int = 0,
        attn_type: str = "self",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        B = query.size(0)
        prefix = f"L{layer_idx}_{attn_type}_MHA"

        # Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        if log:
            log.log(f"{prefix}_Wq", self.W_q.weight[:4, :4],
                    formula="Wq shape: (d_model, d_model)",
                    note="Query weight matrix (first 4×4 shown)")
            log.log(f"{prefix}_Q_full", Q[0, :, :8],
                    formula="Q = input · Wq",
                    note="Full Q projection (first 8 dims shown)")

        # Split into heads
        Q = self.split_heads(Q)  # (B, h, T, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)

        if log:
            log.log(f"{prefix}_Q_head0", Q[0, 0, :, :],
                    formula=f"Split: (B,T,D) → (B,{self.num_heads},T,{self.d_k})",
                    note=f"Head 0 queries — d_k={self.d_k}")

        # Per-head attention (log only first 2 heads to avoid bloat)
        all_attn = []
        all_weights = []
        for h in range(self.num_heads):
            h_log = log if h < 2 else None
            out_h, w_h = scaled_dot_product_attention(
                Q[:, h], K[:, h], V[:, h],
                mask=mask,
                log=h_log,
                head_idx=h,
                layer_idx=layer_idx,
                attn_type=attn_type,
            )
            all_attn.append(out_h)
            all_weights.append(w_h)

        # Concat heads
        concat = torch.stack(all_attn, dim=1)         # (B, h, T, d_k)
        concat = concat.transpose(1, 2).contiguous()  # (B, T, h, d_k)
        concat = concat.view(B, -1, self.d_model)     # (B, T, D)

        if log:
            log.log(f"{prefix}_concat", concat[0, :, :8],
                    formula="concat = [head_1; head_2; ...; head_h]",
                    note="Concatenated heads (first 8 dims)")

        # Final projection
        output = self.W_o(concat)
        if log:
            log.log(f"{prefix}_output", output[0, :, :8],
                    formula="MHA_out = concat · Wo",
                    note="Final multi-head attention output")

        # Stack all attention weights: (B, h, T_q, T_k)
        attn_weights = torch.stack(all_weights, dim=1)
        return output, attn_weights


# ─────────────────────────────────────────────
# Feed-Forward Network
# ─────────────────────────────────────────────

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, log: Optional[CalcLog] = None,
                layer_idx: int = 0, loc: str = "enc") -> torch.Tensor:
        prefix = f"L{layer_idx}_{loc}_FFN"
        h = self.linear1(x)
        if log:
            log.log(f"{prefix}_linear1", h[0, :, :8],
                    formula="h = X · W1 + b1",
                    note="First linear (d_model→d_ff), showing first 8 dims")
        h = F.relu(h)
        if log:
            log.log(f"{prefix}_relu", h[0, :, :8],
                    formula="h = ReLU(h) = max(0, h)",
                    note="Negative values zeroed out")
        h = self.dropout(h)
        out = self.linear2(h)
        if log:
            log.log(f"{prefix}_linear2", out[0, :, :8],
                    formula="out = h · W2 + b2",
                    note="Second linear (d_ff→d_model)")
        return out


# ─────────────────────────────────────────────
# Layer Norm + Residual
# ─────────────────────────────────────────────

class AddNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=eps)

    def forward(self, x: torch.Tensor, sublayer_out: torch.Tensor,
                log: Optional[CalcLog] = None, tag: str = "") -> torch.Tensor:
        residual = x + sublayer_out
        out = self.norm(residual)
        if log:
            log.log(f"{tag}_residual", residual[0, :, :8],
                    formula="residual = x + sublayer(x)",
                    note="Residual (skip) connection")
            log.log(f"{tag}_layernorm", out[0, :, :8],
                    formula="LayerNorm(x) = γ·(x−μ)/σ + β",
                    note="Layer normalization output")
        return out


# ─────────────────────────────────────────────
# Encoder Layer
# ─────────────────────────────────────────────

class EncoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm1 = AddNorm(d_model)
        self.add_norm2 = AddNorm(d_model)

    def forward(self, x: torch.Tensor, src_mask: Optional[torch.Tensor] = None,
                log: Optional[CalcLog] = None, layer_idx: int = 0):
        attn_out, attn_w = self.self_attn(
            x, x, x, mask=src_mask, log=log,
            layer_idx=layer_idx, attn_type="enc_self"
        )
        x = self.add_norm1(x, attn_out, log=log, tag=f"L{layer_idx}_enc_self")
        ffn_out = self.ffn(x, log=log, layer_idx=layer_idx, loc="enc")
        x = self.add_norm2(x, ffn_out, log=log, tag=f"L{layer_idx}_enc_ffn")
        return x, attn_w


# ─────────────────────────────────────────────
# Decoder Layer
# ─────────────────────────────────────────────

class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm1 = AddNorm(d_model)
        self.add_norm2 = AddNorm(d_model)
        self.add_norm3 = AddNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        enc_out: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        src_mask: Optional[torch.Tensor] = None,
        log: Optional[CalcLog] = None,
        layer_idx: int = 0,
    ):
        # 1. Masked self-attention
        m_attn_out, m_attn_w = self.masked_self_attn(
            x, x, x, mask=tgt_mask, log=log,
            layer_idx=layer_idx, attn_type="dec_masked"
        )
        x = self.add_norm1(x, m_attn_out, log=log, tag=f"L{layer_idx}_dec_masked")

        # 2. Cross-attention: Q from decoder, K/V from encoder
        if log:
            log.log(f"L{layer_idx}_cross_Q_source", x[0, :, :8],
                    note="Cross-attn Q comes from DECODER (Bengali context)")
            log.log(f"L{layer_idx}_cross_KV_source", enc_out[0, :, :8],
                    note="Cross-attn K,V come from ENCODER (English context)")

        c_attn_out, c_attn_w = self.cross_attn(
            query=x, key=enc_out, value=enc_out,
            mask=src_mask, log=log,
            layer_idx=layer_idx, attn_type="dec_cross"
        )
        x = self.add_norm2(x, c_attn_out, log=log, tag=f"L{layer_idx}_dec_cross")

        # 3. FFN
        ffn_out = self.ffn(x, log=log, layer_idx=layer_idx, loc="dec")
        x = self.add_norm3(x, ffn_out, log=log, tag=f"L{layer_idx}_dec_ffn")

        return x, m_attn_w, c_attn_w


# ─────────────────────────────────────────────
# Full Transformer
# ─────────────────────────────────────────────

class Transformer(nn.Module):
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        d_ff: int = 256,
        max_len: int = 64,
        dropout: float = 0.1,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx
        self.num_layers = num_layers

        self.src_embed = nn.Embedding(src_vocab_size, d_model, padding_idx=pad_idx)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model, padding_idx=pad_idx)
        self.src_pe = PositionalEncoding(d_model, max_len, dropout)
        self.tgt_pe = PositionalEncoding(d_model, max_len, dropout)

        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

        self.output_linear = nn.Linear(d_model, tgt_vocab_size)
        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    # ── mask helpers ──────────────────────────

    def make_src_mask(self, src: torch.Tensor) -> torch.Tensor:
        # (B, 1, 1, T_src) — 1 where not pad
        return (src != self.pad_idx).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt: torch.Tensor) -> torch.Tensor:
        T = tgt.size(1)
        pad_mask = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(2)       # (B,1,1,T)
        causal = torch.tril(torch.ones(T, T, device=tgt.device)).bool()  # (T,T)
        return pad_mask & causal  # (B,1,T,T)

    # ── forward ───────────────────────────────

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        log: Optional[CalcLog] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)

        # ── Encoder ──────────────────────────
        src_emb = self.src_embed(src) * math.sqrt(self.d_model)
        if log:
            log.log("SRC_tokens", src[0],
                    note="Source token IDs (English)")
            log.log("SRC_embedding_raw", src_emb[0, :, :8],
```
+
formula=f"emb = Embedding(token_id) × √{self.d_model}",
|
| 462 |
+
note="Token embeddings (first 8 dims)")
|
| 463 |
+
|
| 464 |
+
enc_x = self.src_pe(src_emb, log=log)
|
| 465 |
+
|
| 466 |
+
enc_attn_weights = []
|
| 467 |
+
for i, layer in enumerate(self.encoder_layers):
|
| 468 |
+
enc_x, ew = layer(enc_x, src_mask=src_mask, log=log, layer_idx=i)
|
| 469 |
+
enc_attn_weights.append(ew.detach().cpu().numpy())
|
| 470 |
+
|
| 471 |
+
if log:
|
| 472 |
+
log.log("ENCODER_output", enc_x[0, :, :8],
|
| 473 |
+
note="Final encoder output — passed as K,V to every decoder cross-attention")
|
| 474 |
+
|
| 475 |
+
# ── Decoder ──────────────────────────
|
| 476 |
+
tgt_emb = self.tgt_embed(tgt) * math.sqrt(self.d_model)
|
| 477 |
+
if log:
|
| 478 |
+
log.log("TGT_tokens", tgt[0],
|
| 479 |
+
note="Target token IDs (Bengali, teacher-forced in training)")
|
| 480 |
+
log.log("TGT_embedding_raw", tgt_emb[0, :, :8],
|
| 481 |
+
formula=f"emb = Embedding(token_id) × √{self.d_model}",
|
| 482 |
+
note="Bengali token embeddings")
|
| 483 |
+
|
| 484 |
+
dec_x = self.tgt_pe(tgt_emb, log=log)
|
| 485 |
+
|
| 486 |
+
dec_self_attn_w = []
|
| 487 |
+
dec_cross_attn_w = []
|
| 488 |
+
for i, layer in enumerate(self.decoder_layers):
|
| 489 |
+
dec_x, mw, cw = layer(
|
| 490 |
+
dec_x, enc_x,
|
| 491 |
+
tgt_mask=tgt_mask, src_mask=src_mask,
|
| 492 |
+
log=log, layer_idx=i,
|
| 493 |
+
)
|
| 494 |
+
dec_self_attn_w.append(mw.detach().cpu().numpy())
|
| 495 |
+
dec_cross_attn_w.append(cw.detach().cpu().numpy())
|
| 496 |
+
|
| 497 |
+
# ── Output projection ─────────────────
|
| 498 |
+
logits = self.output_linear(dec_x) # (B, T, vocab)
|
| 499 |
+
if log:
|
| 500 |
+
log.log("LOGITS", logits[0, :, :16],
|
| 501 |
+
formula="logits = dec_out · W_out (first 16 vocab entries shown)",
|
| 502 |
+
note=f"Raw scores over vocab of {logits.size(-1)} Bengali tokens")
|
| 503 |
+
|
| 504 |
+
probs = F.softmax(logits[0], dim=-1)
|
| 505 |
+
log.log("SOFTMAX_probs", probs[:, :16],
|
| 506 |
+
formula="P(token) = exp(logit) / Σ exp(logits)",
|
| 507 |
+
note="Probability distribution over Bengali vocabulary")
|
| 508 |
+
|
| 509 |
+
meta = {
|
| 510 |
+
"enc_attn": enc_attn_weights,
|
| 511 |
+
"dec_self_attn": dec_self_attn_w,
|
| 512 |
+
"dec_cross_attn": dec_cross_attn_w,
|
| 513 |
+
"src_mask": src_mask.cpu().numpy(),
|
| 514 |
+
"tgt_mask": tgt_mask.cpu().numpy(),
|
| 515 |
+
}
|
| 516 |
+
return logits, meta
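The two mask helpers above combine a per-key padding mask with a lower-triangular causal mask. A minimal pure-Python sketch of the same logic (illustrative only — plain lists instead of torch tensors, single sequence instead of a batch, assuming `PAD_IDX = 0` as in vocab.py):

```python
PAD_IDX = 0  # matches vocab.py's padding index

def make_tgt_mask(tgt):
    """Mirror of Transformer.make_tgt_mask for one sequence.

    tgt: list of token IDs.
    Returns a T x T boolean matrix where mask[i][j] is True iff query
    position i may attend to key position j: j <= i (causal) AND
    tgt[j] is not padding (pad mask over key positions).
    """
    T = len(tgt)
    pad_ok = [t != PAD_IDX for t in tgt]  # which key positions are real tokens
    return [[pad_ok[j] and j <= i for j in range(T)] for i in range(T)]

# <bos> w w <eos> <pad> <pad>
mask = make_tgt_mask([1, 5, 7, 2, 0, 0])
# Row 0 attends only to position 0; pad columns (4, 5) are always blocked.
```

In the torch version the pad mask has shape (B,1,1,T) and the causal mask (T,T), so `pad_mask & causal` broadcasts to (B,1,T,T) — i.e. padding is masked along the key axis for every query position, exactly as in the list version.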
vocab.py
ADDED
@@ -0,0 +1,146 @@
"""
vocab.py
Simple character/subword tokenizer for English→Bengali demo.
"""

import json
import re
from pathlib import Path
from typing import List, Dict, Tuple


# ── Special tokens ───────────────────────────────────────────────────────────

PAD_TOKEN = "<pad>"
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"
UNK_TOKEN = "<unk>"

PAD_IDX = 0
BOS_IDX = 1
EOS_IDX = 2
UNK_IDX = 3


# ── English word-level vocab ─────────────────────────────────────────────────

EN_WORDS = [
    "i", "you", "he", "she", "we", "they", "it",
    "love", "like", "eat", "drink", "go", "come", "see", "know",
    "want", "need", "have", "am", "is", "are", "was", "were",
    "do", "does", "did", "will", "can", "may", "should",
    "a", "an", "the", "this", "that", "my", "your", "his", "her",
    "good", "bad", "happy", "sad", "big", "small", "new", "old",
    "food", "water", "home", "work", "school", "book", "name",
    "rice", "fish", "milk", "tea", "coffee",
    "hello", "bye", "yes", "no", "please", "thank", "thanks",
    "how", "what", "where", "when", "why", "who",
    "today", "tomorrow", "now", "always", "never", "very",
    "bengal", "india", "english", "bengali",
    "beautiful", "wonderful", "great", "nice", "fine",
]

# ── Bengali word-level vocab ──────────────────────────────────────────────────

BN_WORDS = [
    "আমি", "তুমি", "তুই", "সে", "আমরা", "তারা", "এটা",
    "ভালোবাসি", "পছন্দ", "খাই", "পান", "যাই", "আসি", "দেখি", "জানি",
    "চাই", "দরকার", "আছে", "হই", "হয়", "ছিলাম", "ছিল",
    "করি", "করে", "করেছি", "করব", "পারি", "পারে", "উচিত",
    "একটা", "একটি", "এই", "সেই", "আমার", "তোমার", "তার",
    "ভালো", "খারাপ", "খুশি", "দুঃখী", "বড়", "ছোট", "নতুন", "পুরনো",
    "খাবার", "পানি", "বাড়ি", "কাজ", "স্কুল", "বই", "নাম",
    "ভাত", "মাছ", "দুধ", "চা", "কফি",
    "হ্যালো", "বিদায়", "হ্যাঁ", "না", "দয়া", "ধন্যবাদ",
    "কেমন", "কি", "কোথায়", "কখন", "কেন", "কে",
    "আজ", "আগামীকাল", "এখন", "সবসময়", "কখনো", "খুব",
    "বাংলা", "ভারত", "ইংরেজি",
    "সুন্দর", "চমৎকার", "দারুণ",
    "তোমাকে", "আপনাকে", "তাকে", "আমাকে",
    "করছি", "করছে", "হচ্ছে", "পড়ি", "লিখি", "বলি",
    "আছ", "সকাল", "জানে", "দেখে",
]


class Vocab:
    def __init__(self, words: List[str], name: str = "vocab"):
        self.name = name
        self.token2idx: Dict[str, int] = {
            PAD_TOKEN: PAD_IDX,
            BOS_TOKEN: BOS_IDX,
            EOS_TOKEN: EOS_IDX,
            UNK_TOKEN: UNK_IDX,
        }
        for w in words:
            if w not in self.token2idx:
                self.token2idx[w] = len(self.token2idx)
        self.idx2token: Dict[int, str] = {v: k for k, v in self.token2idx.items()}

    def __len__(self):
        return len(self.token2idx)

    def encode(self, sentence: str) -> List[int]:
        tokens = sentence.lower().strip().split()
        ids = [self.token2idx.get(t, UNK_IDX) for t in tokens]
        return [BOS_IDX] + ids + [EOS_IDX]

    def decode(self, ids: List[int], skip_special: bool = True) -> str:
        skip = {PAD_IDX, BOS_IDX, EOS_IDX} if skip_special else set()
        return " ".join(
            self.idx2token.get(i, UNK_TOKEN)
            for i in ids
            if i not in skip
        )

    def tokens(self, ids: List[int]) -> List[str]:
        return [self.idx2token.get(i, UNK_TOKEN) for i in ids]


# ── Toy parallel corpus ──────────────────────────────────────────────────────

PARALLEL_DATA: List[Tuple[str, str]] = [
    ("i love you", "আমি তোমাকে ভালোবাসি"),
    ("i like food", "আমি খাবার পছন্দ"),
    ("you are good", "তুমি ভালো"),
    ("he is happy", "সে খুশি"),
    ("i want water", "আমি পানি চাই"),
    ("she is beautiful", "সে সুন্দর"),
    ("i eat rice", "আমি ভাত খাই"),
    ("i drink tea", "আমি চা পান"),
    ("i know you", "আমি তোমাকে জানি"),
    ("we are happy", "আমরা খুশি"),
    ("i see you", "আমি তোমাকে দেখি"),
    ("you are beautiful", "তুমি সুন্দর"),
    ("i love bengali", "আমি বাংলা ভালোবাসি"),
    ("hello how are you", "হ্যালো কেমন আছ"),  # আছ added to vocab
    ("thank you very much", "ধন্যবাদ খুব"),
    ("i need you", "আমি তোমাকে দরকার"),
    ("he knows bengali", "সে বাংলা জানে"),  # জানে added to vocab
    ("she drinks milk", "সে দুধ পান"),
    ("we go home", "আমরা বাড়ি যাই"),
    ("i am happy today", "আমি আজ খুশি"),
    ("i love my home", "আমি আমার বাড়ি ভালোবাসি"),
    ("she sees you", "সে তোমাকে দেখে"),
    ("they eat rice", "তারা ভাত খাই"),
    ("i want tea", "আমি চা চাই"),
    ("this is good", "এটা ভালো"),
    ("he is sad", "সে দুঃখী"),
    ("i like school", "আমি স্কুল পছন্দ"),
    ("you know me", "তুমি আমাকে জানে"),
    ("good morning", "ভালো সকাল"),
    ("i love india", "আমি ভারত ভালোবাসি"),
]


def build_vocabs() -> Tuple[Vocab, Vocab]:
    src_v = Vocab(EN_WORDS, "english")
    tgt_v = Vocab(BN_WORDS, "bengali")
    return src_v, tgt_v


# singleton
_src_vocab, _tgt_vocab = build_vocabs()


def get_vocabs() -> Tuple[Vocab, Vocab]:
    return _src_vocab, _tgt_vocab
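The tokenizer is deliberately simple: lowercase, whitespace-split, map each word through `token2idx` (unknowns to `<unk>`), and wrap in `<bos>`/`<eos>`. A standalone sketch of that encode/decode contract, using a tiny three-word vocabulary for illustration (plain functions rather than the `Vocab` class, same index layout):

```python
PAD_IDX, BOS_IDX, EOS_IDX, UNK_IDX = 0, 1, 2, 3  # same layout as vocab.py

def build_token2idx(words):
    # Special tokens claim IDs 0-3; corpus words follow in insertion order.
    t2i = {"<pad>": PAD_IDX, "<bos>": BOS_IDX, "<eos>": EOS_IDX, "<unk>": UNK_IDX}
    for w in words:
        t2i.setdefault(w, len(t2i))
    return t2i

def encode(t2i, sentence):
    # Lowercase + whitespace split; out-of-vocab words map to <unk>.
    ids = [t2i.get(t, UNK_IDX) for t in sentence.lower().strip().split()]
    return [BOS_IDX] + ids + [EOS_IDX]

def decode(i2t, ids):
    # Drop the special tokens on the way out.
    return " ".join(i2t.get(i, "<unk>") for i in ids
                    if i not in {PAD_IDX, BOS_IDX, EOS_IDX})

t2i = build_token2idx(["i", "love", "you"])
i2t = {v: k for k, v in t2i.items()}
ids = encode(t2i, "I love you")   # → [1, 4, 5, 6, 2]
text = decode(i2t, ids)           # → "i love you"
```

Note the round trip is lossy by design: capitalization is discarded and any word outside the vocabulary comes back as `<unk>`, which is fine for this fixed 30-pair demo corpus.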