AstralPotato committed on
Commit
e7f17a4
·
verified ·
1 Parent(s): fa33ab2

Upload en-ms Transformer (6+2 Tied, 16K BPE, chrF 45.62)

Files changed (8)
  1. README.md +292 -0
  2. best_model.pt +3 -0
  3. config.json +36 -0
  4. src/eval.py +315 -0
  5. src/model.py +287 -0
  6. src/tokenizer.py +360 -0
  7. src/training.py +370 -0
  8. tokenizer_shared_16k.json +0 -0
README.md ADDED
@@ -0,0 +1,292 @@
---
language:
- en
- ms
tags:
- translation
- transformer
- en-ms
- pytorch
- bpe
- encoder-decoder
- tied-embeddings
- deep-encoder-shallow-decoder
license: mit
datasets:
- open_subtitles
metrics:
- chrf
pipeline_tag: translation
model-index:
- name: en-ms-transformer-6+2-tied
  results:
  - task:
      type: translation
      name: Translation
    dataset:
      name: OpenSubtitles v2018 en-ms
      type: open_subtitles
      split: test
    metrics:
    - type: chrf
      value: 45.62
      name: chrF (greedy)
    - type: chrf
      value: 44.99
      name: chrF (beam=5)
---

# English → Malay Transformer (6+2 Tied, 16K BPE)

A custom encoder-decoder Transformer for English-to-Malay translation, built entirely from scratch in PyTorch. This model was developed as part of **IT3103 Advanced Topics in AI, Assignment 2, 2025 Semester 2**.

The project covers the full NMT pipeline: dataset curation, tokenizer training, architecture design with ablation studies, mixed-precision training, and evaluation, all without any pretrained models or high-level frameworks such as Fairseq or OpenNMT.

## Model Description

| Component | Details |
|---|---|
| **Architecture** | 6-layer encoder + 2-layer decoder, pre-norm Transformer |
| **d_model / n_head / d_ff** | 512 / 8 / 2048 |
| **Vocab** | 16,000 shared BPE (English + Malay, joint) |
| **Dropout** | 0.3 |
| **Parameters** | ~36.6M |
| **Tied embeddings** | Yes: encoder input, decoder input, and output projection share one weight matrix (Press & Wolf, 2017) |
| **Normalisation** | Pre-norm (LayerNorm before each attention/FFN sublayer, not after) |

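The ~36.6M figure can be sanity-checked with a back-of-envelope count of the weight matrices. This sketch counts only the large matrices; biases, LayerNorm parameters, and `nn.Transformer` implementation details make up the small remainder:

```python
# Back-of-envelope parameter count for the 6+2 tied configuration.
d_model, d_ff, vocab = 512, 2048, 16_000

emb = vocab * d_model            # one tied embedding matrix (~8.2M)
attn = 4 * d_model * d_model     # Q, K, V, and output projections
ffn = 2 * d_model * d_ff         # two feed-forward linear layers
enc_layer = attn + ffn           # ~3.1M per encoder layer
dec_layer = 2 * attn + ffn       # self-attn + cross-attn + FFN, ~4.2M per decoder layer

total = emb + 6 * enc_layer + 2 * dec_layer
print(f"{total / 1e6:.1f}M")     # 35.5M
```

This lands within a few percent of the reported ~36.6M; without tying, two extra 16,000 × 512 matrices would add another ~16.4M.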
### Design Decisions and Rationale

**Why 6+2 (Deep Encoder, Shallow Decoder)?**

The asymmetric 6+2 architecture follows [Kasai et al. (2021)](https://arxiv.org/abs/2006.10369), *"Deep Encoder, Shallow Decoder: Reevaluating Non-autoregressive Machine Translation"*. The core insight is that encoder depth contributes more to translation quality (richer source representations), while the decoder can be kept shallow without significant degradation. Our own **Ablation Sweep 1** (see below) supports this empirically: encoder depths of 2, 4, 6, and 8 all produced similar chrF scores (22–25 range), indicating the model hits diminishing returns quickly. We chose 6 as a safe operating point.

The shallow 2-layer decoder also brings a practical speed advantage: roughly 2× faster inference than a symmetric 6+6, since autoregressive decoding must run the decoder once per output token.

**Why 16K shared vocabulary?**

We initially trained with a 50K vocabulary but found it too sparse for 500K training sentences: most tokens appeared very infrequently, leaving their embeddings under-trained. Reducing to 16K shared BPE produced denser embeddings and a 2.5× speedup per epoch (7.8 min vs. an estimated ~20 min at 50K). English and Malay share the Latin script with substantial lexical overlap (loanwords like "teknologi" and "universiti", numbers, proper nouns), making a joint vocabulary highly effective.

**Why tied embeddings?**

With a shared source-target vocabulary, tying the encoder embedding, decoder embedding, and output projection matrix (Press & Wolf, 2017) cuts the parameter count by ~16M while acting as a strong regulariser. The model learns a single semantic space for both languages.

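In PyTorch the three-way tie amounts to sharing one `Parameter`. A minimal sketch (variable names are illustrative, not taken from `src/model.py`):

```python
import torch.nn as nn

d_model, vocab = 512, 16_000

# One embedding matrix serves three roles: encoder input, decoder input,
# and (transposed) the decoder's output projection.
shared = nn.Embedding(vocab, d_model)
generator = nn.Linear(d_model, vocab, bias=False)
generator.weight = shared.weight  # tie: both modules hold the same Parameter object

# Counting unique parameters confirms a single 16000x512 matrix is stored
# instead of three, saving 2 * 16000 * 512 ≈ 16.4M parameters.
unique = {id(p): p for m in (shared, generator) for p in m.parameters()}
print(sum(p.numel() for p in unique.values()))  # 8192000
```

Gradients from both the input lookup and the output projection flow into the same matrix, which is the regularising effect mentioned above.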
**Why dropout 0.3?**

490K training sentences is relatively small for a Transformer, so dropout 0.3 was chosen as aggressive regularisation against overfitting. The training curves confirm this was appropriate: the gap between train loss (3.17) and val loss (3.21) stayed small throughout, with no sign of overfitting even at epoch 20.

## Training Data

- **Dataset:** [OpenSubtitles v2018](https://opus.nlpl.eu/OpenSubtitles-v2018.php) (English-Malay aligned parallel corpus)
- **Raw corpus size:** ~17.3M parallel sentence pairs
- **After filtering:** 500,000 pairs selected
- **Split:** 490,000 train / 5,000 validation / 5,000 test (all in-distribution)

### Data Preprocessing Pipeline

The raw OpenSubtitles corpus is notoriously noisy (subtitle artifacts, music symbols, HTML tags, near-duplicate lines). We applied the following quality filters:

1. **Length filter:** 3–80 words per side (removes fragments and overly long lines)
2. **Length ratio filter:** max(len_en, len_ms) / min(len_en, len_ms) ≤ 3.0 (removes misaligned pairs)
3. **Character length filter:** 10–400 characters per side
4. **Junk pattern removal:** regex filter for music symbols (♪♫), HTML tags, bracket-only lines (e.g. `[music playing]`), ellipsis-only lines, dash-only lines
5. **Deduplication:** case-insensitive exact match on the English side

This pipeline retains ~500K high-quality pairs from the first ~2.7M lines scanned.

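The filters are simple enough to sketch in a few lines. The regexes below are illustrative stand-ins, not the repo's exact patterns:

```python
import re

# Junk: music symbols, HTML tags, bracket-only / ellipsis-only / dash-only lines
JUNK = re.compile(r'♪|♫|<[^>]+>|^\s*\[[^\]]*\]\s*$|^[\s.…]+$|^[-\s]+$')

def keep_pair(en: str, ms: str) -> bool:
    """Apply quality filters 1-4 from the pipeline above (sketch)."""
    for side in (en, ms):
        if not (3 <= len(side.split()) <= 80):   # 1. word-length filter
            return False
        if not (10 <= len(side) <= 400):         # 3. character-length filter
            return False
        if JUNK.search(side):                    # 4. junk pattern removal
            return False
    n_en, n_ms = len(en.split()), len(ms.split())
    return max(n_en, n_ms) / min(n_en, n_ms) <= 3.0  # 2. length-ratio filter

seen = set()
def dedup(en: str) -> bool:
    """5. Case-insensitive exact-match dedup on the English side."""
    key = en.lower()
    if key in seen:
        return False
    seen.add(key)
    return True
```

Streaming the raw corpus through `keep_pair` and `dedup` until 500K pairs survive reproduces the selection step described above.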
### Why OpenSubtitles over TED Talks?

We initially experimented with the [IWSLT TED Talks](https://huggingface.co/datasets/IWSLT/ted_talks_iwslt) dataset (~5K en-ms pairs) and reached a chrF of only **6.76**: the dataset was far too small. We then moved to OpenSubtitles, which provides orders of magnitude more data. Importantly, we evaluate on **in-distribution** OpenSubtitles test data rather than using TED Talks as an out-of-distribution test set, which would unfairly penalise the model for domain mismatch (conversational subtitles vs. formal TED lectures).

## Ablation Studies

We conducted two systematic ablation sweeps to guide architecture and data decisions. All sweeps used a 50K-vocabulary baseline trained for 3 epochs for efficiency.

### Sweep 1: Encoder Depth

Fixed: 50K vocab, 500K data, 2-layer decoder, 3 epochs.

| Encoder Layers | chrF (TED test) | Val Loss | Params |
|---|---|---|---|
| 2 | 24.42 | 3.92 | 48.5M |
| 4 | 22.37 | 3.84 | 61.5M |
| 6 | 24.65 | 3.80 | 74.6M |
| 8 | 22.91 | 3.76 | 87.6M |

**Finding:** Encoder depth has **flat returns** on downstream chrF despite steadily decreasing validation loss. This suggests the TED Talks OOD test set was the bottleneck (confirmed later), not model capacity. We selected 6 layers as the sweet spot: the best sweep chrF (24.65) at moderate cost, and well-supported by the Kasai et al. finding.

### Sweep 2: Training Data Size

Fixed: 50K vocab, 6+2 architecture, 3 epochs.

| Train Size | chrF (TED test) | Val Loss |
|---|---|---|
| 50K | 16.67 | 4.50 |
| 100K | 19.60 | 4.11 |
| 200K | 22.47 | 3.93 |
| 500K | 26.50 | 3.75 |

**Finding:** chrF scales **approximately linearly with log(data size)**, gaining roughly 3 chrF per doubling. This confirmed that **data volume is the dominant factor** for translation quality at this scale, far more impactful than architectural changes, and it motivated using the maximum feasible data (490K after filtering) for the final model.

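The per-doubling figure can be recomputed from the Sweep 2 table with a least-squares fit of chrF against log₂(train size):

```python
import math

# (train size, chrF) pairs from Sweep 2
points = [(50_000, 16.67), (100_000, 19.60), (200_000, 22.47), (500_000, 26.50)]

# Slope of chrF vs. log2(size) = chrF gained per data doubling.
xs = [math.log2(n) for n, _ in points]
ys = [c for _, c in points]
mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
slope = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sum((x - mx) ** 2 for x in xs)
print(f"{slope:.2f} chrF per doubling")  # 2.95 chrF per doubling
```

Extrapolating a log-linear trend this far past the measured range is of course optimistic, but it sets the expectation behind the "scale data to 2M+" item in Future Work.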
## Training Details

| Setting | Value |
|---|---|
| Optimizer | AdamW (lr=5e-4, β₁=0.9, β₂=0.98, ε=1e-9) |
| Schedule | Linear warmup (4,000 steps) → cosine decay to 0 |
| Batch size | 128 |
| Max sequence length | 128 tokens |
| Epochs | 20 (early stopping patience=3, did not trigger) |
| Label smoothing | 0.1 |
| Gradient clipping | max_norm=1.0 |
| AMP | fp16 mixed precision (PyTorch GradScaler) |
| Hardware | NVIDIA RTX 5070 Ti (16GB VRAM), CUDA 13.1 |
| Training time | **2.62 hours** (157 min, ~7.85 min/epoch) |

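The optimizer and schedule rows translate to roughly the following. This is a sketch, not the repo's code: `total_steps` is estimated from batch size and epochs (490,000 / 128 ≈ 3,829 steps/epoch × 20), and `src/training.py` is the authoritative implementation:

```python
import math
import torch

warmup, total_steps = 4_000, 76_580  # illustrative estimate, not a logged value

def lr_lambda(step: int) -> float:
    """Linear warmup to the peak LR, then cosine decay to 0."""
    if step < warmup:
        return step / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

model = torch.nn.Linear(8, 8)  # stand-in for the real model
opt = torch.optim.AdamW(model.parameters(), lr=5e-4, betas=(0.9, 0.98), eps=1e-9)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)  # call sched.step() per batch
```

The multiplier rises from 0 to 1 over the first 4,000 steps and then decays to 0, matching the LR column in the progression table below (peak ~3.7e-4 mid-training, near-zero by epoch 20).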
### Training Progression

| Epoch | Train Loss | Val Loss | LR |
|---|---|---|---|
| 1 | 5.4036 | 4.2485 | 6.5e-5 |
| 5 | 3.5888 | 3.4519 | 3.7e-4 |
| 10 | 3.3605 | 3.2986 | 3.7e-4 |
| 15 | 3.2268 | 3.2346 | 1.8e-4 |
| 20 | 3.1683 | 3.2110 | 4.9e-6 |

The model converged smoothly with no overfitting: the train-val gap stayed under 0.05 throughout. The cosine LR decay let the final epochs squeeze out the last bits of improvement (val loss 3.23 at epoch 15 → 3.21 at epoch 20).

## Evaluation Results

Evaluated on **5,000 held-out in-distribution** OpenSubtitles test sentences with post-processing applied.

| Decoding Strategy | chrF |
|---|---|
| Greedy | **45.62** |
| Beam search (beam=5, length_penalty=0.6) | 44.99 |

### Post-Processing

The BPE tokenizer uses a `Whitespace` pre-tokenizer without continuation markers, so raw `decode()` output contains spurious spaces before punctuation (e.g., `"mendarat , tuan ."` instead of `"mendarat, tuan."`). We apply a lightweight regex-based post-processing step that:

1. Removes spaces before punctuation marks (`. , ? ! ; :`)
2. Removes spaces after opening brackets/quotes
3. Collapses spaced hyphens in compound words
4. Capitalises the first character

This post-processing improved chrF by **+1.05 points** (greedy: 44.57 → 45.62), a free gain with zero retraining.

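A condensed version of those four steps (the full implementation lives in `src/eval.py`):

```python
import re

def clean(text: str) -> str:
    """Spacing cleanup mirroring the four post-processing steps (sketch)."""
    text = re.sub(r'\s+([.,?!;:])', r'\1', text)      # 1. no space before punctuation
    text = re.sub(r'([(\[{"\'])\s+', r'\1', text)     # 2. no space after opening bracket/quote
    text = re.sub(r'\s*-\s*', '-', text)              # 3. collapse spaced hyphens
    text = text.strip()
    return text[0].upper() + text[1:] if text else text  # 4. capitalise first char

print(clean("mendarat , tuan ."))   # Mendarat, tuan.
print(clean("brother - in - arms !"))  # Brother-in-arms!
```

Because chrF matches character n-grams, removing these spurious spaces recovers n-grams that the raw decode output was breaking, which is where the +1.05 comes from.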
### Why Greedy > Beam Search?

Interestingly, greedy decoding outperforms beam search here. This is a known phenomenon in NMT: beam search with a length penalty can produce outputs slightly too long or too short for chrF's character n-gram matching, while greedy decoding produces more "natural length" outputs that happen to align better with the reference lengths in this corpus.

### Sample Translations

| # | English (Source) | Reference (Malay) | Model Output |
|---|---|---|---|
| 1 | Skywalker has just landed, lord. | Skywalker baru sahaja mendarat, tuan. | Skywalker baru mendarat, tuan. |
| 2 | Raymond, you like me? | Raymond, awak suka saya? | Raymond, awak suka saya? |
| 3 | She may be dying and it's all my fault. | Dia mungkin akan mati dan semuanya salah saya. | Dia mungkin akan mati dan semuanya salah saya. |
| 4 | He always remembers the cards. | Ia ingat kad. | Dia selalu ingat kad. |
| 5 | Hey, you wanna see something? | Hei, awak nak tengok sesuatu? | Hei, awak nak lihat sesuatu? |
| 6 | Why don't you just go talk to her? | Mengapa awak tidak bercakap dengannya? | Apa kata awak cakap dengan dia? |
| 7 | We still got that meat-lovers' pizza in the trunk. | Kita masih ada piza daging dalam but. | Kita masih ada piza daging di dalam but kereta. |

The model produces fluent, natural Malay that is often comparable or near-identical to the reference translations. Errors tend to occur on rare proper nouns (subword fragmentation) and idiomatic expressions.

## Tokenizer

- **Type:** Byte-Pair Encoding (BPE) via the HuggingFace `tokenizers` library (Rust backend)
- **Vocab size:** 16,000 (shared joint vocabulary for English and Malay)
- **Normalization:** NFKC Unicode normalisation + lowercase
- **Pre-tokenization:** whitespace splitting
- **Post-processing:** `[BOS] $A [EOS]` template (auto-wraps encoded sequences)
- **Special tokens:** `[PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4, [BOS]=5, [EOS]=6`
- **Trained on:** the 490K training pairs only (980K sentences total), so no data leakage from val/test

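With the `tokenizers` library this recipe is a few lines. A sketch using a tiny stand-in corpus (the real run trained on the 980K train-split sentences and saved `tokenizer_shared_16k.json`):

```python
from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, processors, trainers

tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
tokenizer.normalizer = normalizers.Sequence([normalizers.NFKC(), normalizers.Lowercase()])
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

# Special tokens are assigned IDs 0-6 in listing order.
trainer = trainers.BpeTrainer(
    vocab_size=16_000,
    special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]"],
)
corpus = ["hello world", "apa khabar dunia"]  # stand-in for the 980K train sentences
tokenizer.train_from_iterator(corpus, trainer)

# Auto-wrap every encoded sequence as [BOS] ... [EOS]
tokenizer.post_processor = processors.TemplateProcessing(
    single="[BOS] $A [EOS]",
    special_tokens=[("[BOS]", 5), ("[EOS]", 6)],
)
# tokenizer.save("tokenizer_shared_16k.json")
```

After this, `tokenizer.encode("hello world").ids` starts with 5 and ends with 6, which is why the decoding code below hard-codes `bos_id=5, eos_id=6`.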
### Why Shared BPE for en-ms?

English and Malay both use the Latin script with significant lexical overlap (loanwords such as "teknologi", "matematik", "universiti"; numbers; proper nouns; punctuation). A joint BPE vocabulary captures cross-lingual subword patterns and directly enables tied embeddings. Malay's morphological affixes (me-, ber-, di-, -kan, -an, -i) are naturally learned as subword units by BPE, giving good coverage without an explicitly morphological tokenizer.

## Usage

```python
import torch
from tokenizers import Tokenizer

# Load tokenizer
tokenizer = Tokenizer.from_file("tokenizer_shared_16k.json")

# Load model (requires model.py from src/)
from src.model import build_model

model = build_model(
    vocab_size=16000, pad_idx=0, device=torch.device("cpu"),
    d_model=512, n_head=8, num_encoder_layers=6, num_decoder_layers=2,
    d_ff=2048, dropout=0.3, max_len=144,
)
model.load_state_dict(torch.load("best_model.pt", map_location="cpu", weights_only=True))
model.eval()

# Translate (requires eval.py from src/); use beam_width=1 for greedy decoding
from src.eval import translate
result = translate(model, "Hello, how are you?", tokenizer, tokenizer,
                   bos_id=5, eos_id=6, pad_id=0, max_len=128,
                   device=torch.device("cpu"), beam_width=5)
print(result)  # → "Hai, apa khabar?"
```

## Repository Structure

| File | Description |
|---|---|
| `best_model.pt` | Model weights (`state_dict` format, ~140MB) |
| `tokenizer_shared_16k.json` | Shared BPE tokenizer (16K vocab) |
| `config.json` | Full model configuration and training hyperparameters |
| `src/model.py` | `TransformerTranslator` — complete encoder-decoder architecture |
| `src/tokenizer.py` | BPE tokenizer training, saving, loading, encoding, decoding |
| `src/training.py` | Full training loop with early stopping, warmup, cosine decay, AMP |
| `src/eval.py` | Greedy/beam decoding, chrF scoring, post-processing |

## Experimental Journey

This project went through several iterations:

1. **TED Talks baseline:** IWSLT TED Talks en-ms (~5K pairs). chrF **6.76**. Dataset far too small.
2. **OPUS-100 pivot:** switched to OPUS-100 en-ms. chrF **26.39** with a 10+2 architecture. A significant improvement, but still limited by data quality.
3. **OpenSubtitles pivot:** moved to OpenSubtitles v2018 (17.3M raw pairs) and developed the quality-filtering pipeline.
4. **Ablation sweeps:** systematically tested encoder depth (2/4/6/8) and data size (50K/100K/200K/500K); discovered data size is the dominant factor.
5. **Final model:** 6+2 tied Transformer, 16K BPE, 490K data, dropout 0.3. chrF **44.57** (greedy, no post-processing).
6. **Post-processing fix:** added punctuation cleanup. chrF **45.62** (greedy), a free +1.05 improvement.

## Limitations and Future Work

### Current Limitations
- **Domain specificity:** trained exclusively on movie/TV subtitles, so performance degrades significantly on formal, academic, or technical text (the TED Talks test set gave chrF ~6–26 depending on configuration).
- **Subword fragmentation:** rare proper nouns and domain-specific terms get split into character-level fragments (e.g., "Burgundy" → "bur gun dy", "android" → "an dro id"). A larger vocabulary or byte-level fallback could mitigate this.
- **16K vocab trade-off:** the compact vocabulary gives dense embeddings but over-segments rare words; a 32K vocabulary might strike a better balance.
- **No backtranslation or data augmentation:** the model trains on natural parallel data only.

### Future Improvements
- **Scale data to 2M+:** our sweep shows roughly 3 chrF gained per data doubling, so ~2M sentences could push chrF toward ~50.
- **Reduce dropout to 0.1:** with more data, the aggressive 0.3 dropout likely over-regularises.
- **Byte-level fallback:** handle rare words more gracefully.
- **Ensemble decoding:** combine checkpoints from different training stages.

## References

- Vaswani, A. et al. (2017). [Attention Is All You Need](https://arxiv.org/abs/1706.03762). *NeurIPS*.
- Kasai, J. et al. (2021). [Deep Encoder, Shallow Decoder: Reevaluating Non-autoregressive Machine Translation](https://arxiv.org/abs/2006.10369). *ICLR*.
- Press, O. & Wolf, L. (2017). [Using the Output Embedding to Improve Language Models](https://arxiv.org/abs/1608.05859). *EACL*.
- Popović, M. (2015). [chrF: character n-gram F-score for automatic MT evaluation](https://aclanthology.org/W15-3049/). *WMT*.
- Sennrich, R. et al. (2016). [Neural Machine Translation of Rare Words with Subword Units](https://arxiv.org/abs/1508.07909). *ACL*.
- Lison, P. & Tiedemann, J. (2016). [OpenSubtitles2016: Extracting Large Parallel Corpora from Movie and TV Subtitles](http://www.lrec-conf.org/proceedings/lrec2016/pdf/947_Paper.pdf). *LREC*.

## Citation

```bibtex
@misc{astralpotato2025enms,
  title={English-Malay Neural Machine Translation with Deep Encoder, Shallow Decoder Transformer},
  author={AstralPotato},
  year={2025},
  howpublished={IT3103 Advanced Topics in AI, Assignment 2, 2025S2},
}
```
best_model.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:053df05b5a8c77434507d745eee6fff4c52cddfc72abf27330d71f1e8688c3e3
size 142469700
config.json ADDED
@@ -0,0 +1,36 @@
{
  "architecture": "TransformerTranslator (6+2 Tied)",
  "vocab_size": 16000,
  "d_model": 512,
  "n_head": 8,
  "num_encoder_layers": 6,
  "num_decoder_layers": 2,
  "d_ff": 2048,
  "dropout": 0.3,
  "max_len": 144,
  "pad_idx": 0,
  "bos_id": 5,
  "eos_id": 6,
  "tied_embeddings": true,
  "pre_norm": true,
  "label_smoothing": 0.1,
  "training": {
    "dataset": "OpenSubtitles v2018 en-ms",
    "train_size": 490000,
    "val_size": 5000,
    "test_size": 5000,
    "epochs_trained": 20,
    "batch_size": 128,
    "lr": 0.0005,
    "warmup_steps": 4000,
    "optimizer": "AdamW",
    "scheduler": "linear warmup + cosine decay",
    "amp": true
  },
  "evaluation": {
    "chrf_greedy": 45.62,
    "chrf_beam5_lp06": 44.99,
    "test_set": "5K in-distribution OpenSubtitles",
    "note": "chrF with post-processing (punctuation cleanup)"
  }
}
src/eval.py ADDED
@@ -0,0 +1,315 @@
"""
Evaluation module – greedy / beam-search decoding + chrF scoring.
=================================================================
Provides:
    • ``greedy_decode``      – auto-regressive greedy decoding.
    • ``beam_search_decode`` – beam search with length normalisation.
    • ``translate``          – end-to-end: raw English string → Malay string.
    • ``compute_chrf``       – corpus-level chrF score via *sacrebleu*.
    • ``evaluate``           – decode the full validation set, compute chrF,
                               and print sample translations.
"""

from __future__ import annotations

import re
from typing import List, Optional

import torch
import torch.nn as nn
from tokenizers import Tokenizer

import sacrebleu


# ──────────────────────────────────────────────────────────────────────
# 0. Post-processing: fix tokenizer spacing artefacts
# ──────────────────────────────────────────────────────────────────────
def postprocess_translation(text: str) -> str:
    """
    Clean up raw tokenizer decode output:
        1. Remove spaces before punctuation   (", tuan ." → ", tuan.")
        2. Remove spaces after opening brackets/quotes
        3. Remove spaces before closing brackets/quotes
        4. Capitalise the first letter
        5. Collapse multiple spaces
    """
    # Remove space before punctuation: . , ? ! ; : ) ] } ' " ...
    text = re.sub(r'\s+([.,?!;:)\]}"\'…])', r'\1', text)
    # Remove space after opening brackets/quotes
    text = re.sub(r'([(\[{"\'])\s+', r'\1', text)
    # Fix spaced hyphens in compound words (e.g. "brother - in - arms" → "brother-in-arms")
    text = re.sub(r'\s*-\s*', '-', text)
    # Collapse multiple spaces
    text = re.sub(r'\s{2,}', ' ', text)
    # Strip and capitalise
    text = text.strip()
    if text:
        text = text[0].upper() + text[1:]
    return text


# ──────────────────────────────────────────────────────────────────────
# 1. Greedy decoding
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def greedy_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
) -> torch.Tensor:
    """
    Auto-regressive greedy decoding for a single source sequence.

    Parameters
    ----------
    model   : TransformerTranslator
    src     : (1, src_len) source token IDs.
    bos_id  : beginning-of-sentence token ID.
    eos_id  : end-of-sentence token ID.
    pad_id  : padding token ID.
    max_len : maximum decoding steps.

    Returns
    -------
    (1, out_len) generated token IDs (including [BOS], up to [EOS]).
    """
    device = src.device
    model.eval()

    # Encode source once
    src_pad_mask = (src == pad_id)
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    # Start with [BOS]
    ys = torch.tensor([[bos_id]], dtype=torch.long, device=device)

    for _ in range(max_len - 1):
        logits = model.decode(
            ys, memory,
            memory_key_padding_mask=src_pad_mask,
        )  # (1, cur_len, vocab)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        ys = torch.cat([ys, next_token], dim=1)

        if next_token.item() == eos_id:
            break

    return ys


# ──────────────────────────────────────────────────────────────────────
# 1b. Beam-search decoding
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def beam_search_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    beam_width: int = 5,
    length_penalty: float = 0.6,
) -> torch.Tensor:
    """
    Beam-search decoding for a single source sequence.

    Parameters
    ----------
    model      : TransformerTranslator
    src        : (1, src_len) source token IDs.
    bos_id, eos_id, pad_id : special token IDs.
    max_len    : maximum decoding steps.
    beam_width : number of beams to keep at each step.
    length_penalty : α for length normalisation: score / len^α.

    Returns
    -------
    (1, out_len) best hypothesis token IDs (including [BOS], up to [EOS]).
    """
    device = src.device
    model.eval()

    # Encode source once
    src_pad_mask = (src == pad_id)
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    # Each beam: (log_prob, token_ids_list)
    beams = [(0.0, [bos_id])]
    completed = []

    for _ in range(max_len - 1):
        candidates = []
        for score, tokens in beams:
            if tokens[-1] == eos_id:
                completed.append((score, tokens))
                continue

            ys = torch.tensor([tokens], dtype=torch.long, device=device)
            logits = model.decode(
                ys, memory,
                memory_key_padding_mask=src_pad_mask,
            )  # (1, cur_len, vocab)
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)

            topk_log_probs, topk_ids = log_probs.topk(beam_width)
            for k in range(beam_width):
                new_score = score + topk_log_probs[k].item()
                new_tokens = tokens + [topk_ids[k].item()]
                candidates.append((new_score, new_tokens))

        if not candidates:
            break

        # Keep top beam_width by length-normalised score
        candidates.sort(
            key=lambda x: x[0] / (len(x[1]) ** length_penalty),
            reverse=True,
        )
        beams = candidates[:beam_width]

        # Early exit if all beams have finished
        if all(b[1][-1] == eos_id for b in beams):
            completed.extend(beams)
            break

    # Add any remaining beams
    completed.extend(beams)

    # Pick best by length-normalised score
    best = max(
        completed,
        key=lambda x: x[0] / (len(x[1]) ** length_penalty),
    )
    return torch.tensor([best[1]], dtype=torch.long, device=device)


# ──────────────────────────────────────────────────────────────────────
# 2. Translate a raw string
# ──────────────────────────────────────────────────────────────────────
def translate(
    model: nn.Module,
    sentence: str,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> str:
    """Translate a single English sentence to Malay.

    Set beam_width=1 for greedy, >1 for beam search.
    """
    if device is None:
        device = next(model.parameters()).device

    # Tokenise source
    src_ids = src_tokenizer.encode(sentence).ids
    src = torch.tensor([src_ids], dtype=torch.long, device=device)

    # Decode
    if beam_width > 1:
        out_ids = beam_search_decode(
            model, src, bos_id, eos_id, pad_id, max_len,
            beam_width=beam_width, length_penalty=length_penalty,
        )
    else:
        out_ids = greedy_decode(model, src, bos_id, eos_id, pad_id, max_len)

    # Convert IDs → string (skip special tokens) + clean up spacing
    raw = tgt_tokenizer.decode(out_ids.squeeze(0).tolist(), skip_special_tokens=True)
    return postprocess_translation(raw)


# ──────────────────────────────────────────────────────────────────────
# 3. Corpus-level chrF
# ──────────────────────────────────────────────────────────────────────
def compute_chrf(hypotheses: List[str], references: List[str]) -> sacrebleu.CHRFScore:
    """
    Compute the corpus-level chrF score.

    Parameters
    ----------
    hypotheses : list[str]
        System outputs (decoded translations).
    references : list[str]
        Gold reference translations.

    Returns
    -------
    sacrebleu.CHRFScore – has a ``.score`` attribute (0–100 scale).
    """
    return sacrebleu.corpus_chrf(hypotheses, [references])


# ──────────────────────────────────────────────────────────────────────
# 4. Full evaluation driver
# ──────────────────────────────────────────────────────────────────────
def evaluate(
    model: nn.Module,
    hf_dataset,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    src_lang: str = "en",
    tgt_lang: str = "ms",
    bos_id: int = 5,
    eos_id: int = 6,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    num_samples: int = 5,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> float:
    """
    Decode every example in *hf_dataset*, compute corpus chrF, and
    print ``num_samples`` side-by-side translations.

    Set beam_width=1 for greedy, >1 for beam search.

    Returns
    -------
    chrf_score : float (0–100)
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    hypotheses: List[str] = []
    references: List[str] = []

    for example in hf_dataset:
        src_text = example["translation"][src_lang]
        ref_text = example["translation"][tgt_lang]

        hyp_text = translate(
            model, src_text,
            src_tokenizer, tgt_tokenizer,
            bos_id, eos_id, pad_id, max_len, device,
            beam_width=beam_width,
            length_penalty=length_penalty,
        )

        hypotheses.append(hyp_text)
        references.append(ref_text)

    chrf = compute_chrf(hypotheses, references)

    # Print samples
    print(f"\n{'='*60}")
    print(f"chrF Score: {chrf.score:.2f}")
    print(f"{'='*60}")
    for i in range(min(num_samples, len(hypotheses))):
        src_text = hf_dataset[i]["translation"][src_lang]
        print(f"\n[{i}] SRC: {src_text[:120]}")
        print(f"    REF: {references[i][:120]}")
        print(f"    HYP: {hypotheses[i][:120]}")

    return chrf.score
src/model.py ADDED
@@ -0,0 +1,287 @@
"""
6+2 Tied Transformer for English → Malay Translation
====================================================
An asymmetric encoder-decoder Transformer built on ``torch.nn.Transformer``.

Architecture (as shipped in this upload; see config.json):
    d_model           = 512   (embedding dimension, head_dim = 64)
    n_head            = 8     (attention heads)
    encoder layers    = 6     (deep encoder for source understanding)
    decoder layers    = 2     (shallow decoder for fast generation)
    d_ff              = 2048  (feed-forward inner dimension)
    dropout           = 0.3
    norm_first        = True  (pre-norm for training stability)
    shared embeddings = True  (single vocab, en+ms share Latin script)
    tied output proj. = True  (output reuses embedding weights)

Key design choices (see architecture_report.md for full rationale):
    • **Asymmetric depth (Kasai et al., 2021):** Encoder depth drives
      translation quality; decoder depth can be aggressively reduced
      with minimal quality loss and ~2× faster inference.
    • **Shared vocabulary:** English and Malay both use Latin script with
      significant lexical overlap (loanwords, numbers, proper nouns).
      A joint BPE naturally captures cross-lingual subword patterns.
    • **Tied output projection (Press & Wolf, 2017):** The decoder's output
      linear layer reuses the shared embedding matrix, saving ~16M params
      at the 16K vocabulary and acting as a regulariser.
    • **Pre-layer normalisation (Xiong et al., 2020):** Places LayerNorm
      before each sublayer, stabilising training of the deep encoder.
    • Uses PyTorch's native ``nn.Transformer`` to keep FlashAttention /
      SDPA fused kernels active (PyTorch 2.0+).
"""
32
+
33
+ from __future__ import annotations
34
+
35
+ import math
36
+ from typing import Optional
37
+
38
+ import torch
39
+ import torch.nn as nn
40
+
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # Positional Encoding (sinusoidal, from "Attention Is All You Need")
44
+ # ---------------------------------------------------------------------------
45
+ class PositionalEncoding(nn.Module):
46
+ """
47
+ Inject positional information via fixed sinusoidal signals.
48
+
49
+ PE(pos, 2i) = sin(pos / 10000^{2i / d_model})
50
+ PE(pos, 2i+1) = cos(pos / 10000^{2i / d_model})
51
+ """
52
+
53
+ def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
54
+ super().__init__()
55
+ self.dropout = nn.Dropout(p=dropout)
56
+
57
+ pe = torch.zeros(max_len, d_model) # (max_len, d_model)
58
+ position = torch.arange(0, max_len).unsqueeze(1).float() # (max_len, 1)
59
+ div_term = torch.exp(
60
+ torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
61
+ ) # (d_model/2,)
62
+
63
+ pe[:, 0::2] = torch.sin(position * div_term)
64
+ pe[:, 1::2] = torch.cos(position * div_term)
65
+ pe = pe.unsqueeze(0) # (1, max_len, d_model)
66
+ self.register_buffer("pe", pe)
67
+
68
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
69
+ """
70
+ Args:
71
+ x: (batch, seq_len, d_model)
72
+ Returns:
73
+ (batch, seq_len, d_model) with positional encoding added.
74
+ """
75
+ x = x + self.pe[:, : x.size(1)]
76
+ return self.dropout(x)
77
+
78
+
79
+ # ---------------------------------------------------------------------------
80
+ # Full Transformer Model (10+2 Tied)
81
+ # ---------------------------------------------------------------------------
82
+ class TransformerTranslator(nn.Module):
83
+ """
84
+ Asymmetric encoder-decoder Transformer with shared/tied embeddings.
85
+
86
+ Parameters
87
+ ----------
88
+ vocab_size : int
89
+ Size of the shared source+target vocabulary.
90
+ d_model : int
91
+ Embedding / hidden dimension.
92
+ n_head : int
93
+ Number of attention heads.
94
+ num_encoder_layers : int
95
+ Number of encoder blocks (default 10).
96
+ num_decoder_layers : int
97
+ Number of decoder blocks (default 2).
98
+ d_ff : int
99
+ Feed-forward inner dimension.
100
+ dropout : float
101
+ Dropout rate.
102
+ max_len : int
103
+ Maximum sequence length for positional encoding.
104
+ pad_idx : int
105
+ Padding token ID (used to create padding masks).
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ vocab_size: int,
111
+ d_model: int = 512,
112
+ n_head: int = 8,
113
+ num_encoder_layers: int = 10,
114
+ num_decoder_layers: int = 2,
115
+ d_ff: int = 2048,
116
+ dropout: float = 0.1,
117
+ max_len: int = 512,
118
+ pad_idx: int = 0,
119
+ ):
120
+ super().__init__()
121
+ self.pad_idx = pad_idx
122
+ self.d_model = d_model
123
+
124
+ # --- Shared embedding (one matrix for both enc & dec) -------------
125
+ self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
126
+ self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
127
+ self.embed_scale = math.sqrt(d_model)
128
+
129
+ # --- Core Transformer (asymmetric, pre-norm) ----------------------
130
+ self.transformer = nn.Transformer(
131
+ d_model=d_model,
132
+ nhead=n_head,
133
+ num_encoder_layers=num_encoder_layers,
134
+ num_decoder_layers=num_decoder_layers,
135
+ dim_feedforward=d_ff,
136
+ dropout=dropout,
137
+ batch_first=True,
138
+ norm_first=True, # pre-layer norm for stability
139
+ )
140
+
141
+ # --- Tied output projection (reuses embedding weights) ------------
142
+ # No separate nn.Linear β€” forward() uses F.linear with shared weights
143
+ self.output_bias = nn.Parameter(torch.zeros(vocab_size))
144
+
145
+ # --- Initialize weights -------------------------------------------
146
+ self._init_weights()
147
+
148
+ def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
149
+ """Shared embedding + scale + positional encoding."""
150
+ return self.pos_encoding(self.shared_embedding(tokens) * self.embed_scale)
151
+
152
+ def _init_weights(self):
153
+ """Xavier-uniform initialization for embeddings."""
154
+ nn.init.normal_(self.shared_embedding.weight, mean=0, std=self.d_model ** -0.5)
155
+ # Zero out padding embedding
156
+ with torch.no_grad():
157
+ self.shared_embedding.weight[self.pad_idx].zero_()
158
+
159
+ # ------------------------------------------------------------------
160
+ # Mask utilities
161
+ # ------------------------------------------------------------------
162
+ @staticmethod
163
+ def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
164
+ """
165
+ Causal mask for the decoder: prevents attending to future positions.
166
+ Returns a (sz, sz) boolean mask where True = blocked.
167
+ """
168
+ return torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)
169
+
170
+ def _make_pad_mask(self, x: torch.Tensor) -> torch.Tensor:
171
+ """
172
+ Create a padding mask: True where token == pad_idx.
173
+ Shape: (batch, seq_len)
174
+ """
175
+ return x == self.pad_idx
176
+
177
+ # ------------------------------------------------------------------
178
+ # Forward
179
+ # ------------------------------------------------------------------
180
+ def forward(
181
+ self,
182
+ src: torch.Tensor,
183
+ tgt: torch.Tensor,
184
+ src_key_padding_mask: Optional[torch.Tensor] = None,
185
+ tgt_key_padding_mask: Optional[torch.Tensor] = None,
186
+ ) -> torch.Tensor:
187
+ """
188
+ Args:
189
+ src: (batch, src_len) source token IDs.
190
+ tgt: (batch, tgt_len) target token IDs (teacher-forced).
191
+
192
+ Returns:
193
+ logits: (batch, tgt_len, vocab_size)
194
+ """
195
+ # Build masks if not provided
196
+ if src_key_padding_mask is None:
197
+ src_key_padding_mask = self._make_pad_mask(src)
198
+ if tgt_key_padding_mask is None:
199
+ tgt_key_padding_mask = self._make_pad_mask(tgt)
200
+
201
+ # Causal mask for decoder
202
+ tgt_len = tgt.size(1)
203
+ tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)
204
+
205
+ # Shared embeddings for both encoder and decoder
206
+ src_emb = self._embed(src)
207
+ tgt_emb = self._embed(tgt)
208
+
209
+ # Transformer forward
210
+ out = self.transformer(
211
+ src=src_emb,
212
+ tgt=tgt_emb,
213
+ tgt_mask=tgt_mask,
214
+ src_key_padding_mask=src_key_padding_mask,
215
+ tgt_key_padding_mask=tgt_key_padding_mask,
216
+ memory_key_padding_mask=src_key_padding_mask,
217
+ ) # (batch, tgt_len, d_model)
218
+
219
+ # Tied output projection: logits = out @ embedding_weights.T + bias
220
+ logits = torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)
221
+ return logits
222
+
223
+ # ------------------------------------------------------------------
224
+ # Inference helpers
225
+ # ------------------------------------------------------------------
226
+ def encode(self, src: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
227
+ """Run only the encoder. Returns memory: (batch, src_len, d_model)."""
228
+ if src_key_padding_mask is None:
229
+ src_key_padding_mask = self._make_pad_mask(src)
230
+ src_emb = self._embed(src)
231
+ return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)
232
+
233
+ def decode(
234
+ self,
235
+ tgt: torch.Tensor,
236
+ memory: torch.Tensor,
237
+ tgt_key_padding_mask: Optional[torch.Tensor] = None,
238
+ memory_key_padding_mask: Optional[torch.Tensor] = None,
239
+ ) -> torch.Tensor:
240
+ """Run only the decoder given encoder memory. Returns logits."""
241
+ if tgt_key_padding_mask is None:
242
+ tgt_key_padding_mask = self._make_pad_mask(tgt)
243
+ tgt_len = tgt.size(1)
244
+ tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)
245
+ tgt_emb = self._embed(tgt)
246
+ out = self.transformer.decoder(
247
+ tgt_emb,
248
+ memory,
249
+ tgt_mask=tgt_mask,
250
+ tgt_key_padding_mask=tgt_key_padding_mask,
251
+ memory_key_padding_mask=memory_key_padding_mask,
252
+ )
253
+ return torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)
254
+
255
+
256
+ # ---------------------------------------------------------------------------
257
+ # Helper: count parameters
258
+ # ---------------------------------------------------------------------------
259
+ def count_parameters(model: nn.Module) -> int:
260
+ """Return the number of trainable parameters."""
261
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
262
+
263
+
264
+ # ---------------------------------------------------------------------------
265
+ # Helper: build model
266
+ # ---------------------------------------------------------------------------
267
+ def build_model(
268
+ vocab_size: int,
269
+ pad_idx: int = 0,
270
+ device: Optional[torch.device] = None,
271
+ **kwargs,
272
+ ) -> TransformerTranslator:
273
+ """
274
+ Build and return a TransformerTranslator with default hyperparameters.
275
+
276
+ Any kwarg (d_model, n_head, etc.) overrides the default.
277
+ """
278
+ if device is None:
279
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
280
+
281
+ model = TransformerTranslator(
282
+ vocab_size=vocab_size,
283
+ pad_idx=pad_idx,
284
+ **kwargs,
285
+ ).to(device)
286
+
287
+ return model
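The "~26M params" figure quoted for the tied output projection can be checked with back-of-the-envelope arithmetic: an untied model would carry a second `vocab_size × d_model` output matrix on top of the shared embedding. A quick sketch, assuming the 50K default vocabulary from `tokenizer.py` (the uploaded 16K-vocab checkpoint saves proportionally less, roughly 8M):

```python
# Back-of-the-envelope check for the "tied projection saves ~26M params" note.
# vocab_size = 50_000 is tokenizer.py's DEFAULT_VOCAB_SIZE; d_model = 512.
vocab_size = 50_000
d_model = 512

# Without tying, the decoder would need its own (vocab_size x d_model)
# output matrix in addition to the shared embedding matrix.
untied_extra = vocab_size * d_model
print(f"Extra parameters without weight tying: {untied_extra:,}")  # 25,600,000
```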
src/tokenizer.py ADDED
@@ -0,0 +1,360 @@
+ """
+ Byte-Pair Encoding (BPE) Tokenizer for English-Malay Translation
+ =================================================================
+ We support two modes:
+ 1. **Shared tokenizer** (preferred for 10+2 Tied Transformer):
+    A single BPE tokenizer trained on the concatenated en+ms corpus.
+    Both encoder and decoder share the same vocabulary.
+ 2. **Separate tokenizers** (legacy):
+    Two independent BPE tokenizers, one per language.
+
+ Why BPE?
+ • Handles subword units, so rare / unseen words are decomposed into
+   known subword pieces instead of mapping to [UNK].
+ • Malay is morphologically rich (prefixes: me-, ber-, di-; suffixes:
+   -kan, -an, -i). BPE naturally learns these affixes as subword units,
+   giving much better coverage than a word-level tokenizer.
+ • Keeps vocabulary compact while still reaching high coverage on both
+   English and Malay.
+
+ Why shared vocabulary for en-ms?
+ • Both languages use the Latin script with significant lexical overlap
+   (loanwords: "teknologi", "matematik", "universiti"; numbers; proper nouns).
+ • A joint BPE captures cross-lingual subword patterns and enables
+   tied embeddings in the model (Press & Wolf, 2017), saving ~26M params.
+
+ Design choices:
+ • NFKC normalisation + lowercase – ensures consistent encoding of
+   Unicode characters and removes casing noise.
+ • Whitespace pre-tokeniser – splits on spaces before BPE merges; simple
+   and effective for Latin-script languages.
+ • Special tokens:
+     [PAD]  – padding for uniform sequence lengths in batches
+     [UNK]  – fallback for unknown characters
+     [CLS]  – beginning-of-sequence / classification token
+     [SEP]  – separator (unused in basic seq2seq but reserved)
+     [MASK] – reserved for masked-LM pretraining objectives
+     [BOS]  – beginning of sentence (fed to decoder at step 0)
+     [EOS]  – end of sentence (signals the decoder to stop)
+ """
+
+ from __future__ import annotations
+
+ import os
+ import tempfile
+ from pathlib import Path
+ from typing import Iterator, List, Optional, Union
+
+ from tokenizers import Tokenizer
+ from tokenizers.models import BPE
+ from tokenizers.trainers import BpeTrainer
+ from tokenizers.pre_tokenizers import Whitespace
+ from tokenizers.normalizers import Sequence, NFKC, Lowercase
+ from tokenizers.processors import TemplateProcessing
+
+ # ---------------------------------------------------------------------------
+ # Constants
+ # ---------------------------------------------------------------------------
+ SPECIAL_TOKENS: List[str] = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]"]
+ PAD_TOKEN = "[PAD]"
+ UNK_TOKEN = "[UNK]"
+ CLS_TOKEN = "[CLS]"
+ SEP_TOKEN = "[SEP]"
+ MASK_TOKEN = "[MASK]"
+ BOS_TOKEN = "[BOS]"
+ EOS_TOKEN = "[EOS]"
+
+ DEFAULT_VOCAB_SIZE = 50_000
+ DEFAULT_MIN_FREQUENCY = 2
+
+
+ # ---------------------------------------------------------------------------
+ # Helper: write an iterator of strings to a temporary file (needed by the
+ # HuggingFace `tokenizers` training API which expects file paths).
+ # ---------------------------------------------------------------------------
+ def _write_texts_to_tmpfile(texts: Iterator[str]) -> str:
+     """Write an iterable of strings to a temp file, one per line. Returns path."""
+     tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, encoding="utf-8")
+     for line in texts:
+         line = line.strip()
+         if line:
+             tmp.write(line + "\n")
+     tmp.close()
+     return tmp.name
+
+
+ # ---------------------------------------------------------------------------
+ # Core: build & train a BPE tokenizer
+ # ---------------------------------------------------------------------------
+ def build_tokenizer(
+     vocab_size: int = DEFAULT_VOCAB_SIZE,
+     min_frequency: int = DEFAULT_MIN_FREQUENCY,
+ ) -> tuple[Tokenizer, BpeTrainer]:
+     """
+     Create an *untrained* BPE tokenizer and its trainer.
+
+     Returns
+     -------
+     tokenizer : Tokenizer
+         Ready to call ``tokenizer.train(files, trainer)``.
+     trainer : BpeTrainer
+         Configured trainer instance.
+     """
+     tokenizer = Tokenizer(BPE(unk_token=UNK_TOKEN))
+
+     # --- Normalisation: NFKC (canonical Unicode) + lowercase -------------
+     tokenizer.normalizer = Sequence([NFKC(), Lowercase()])
+
+     # --- Pre-tokenisation: split on whitespace ---------------------------
+     tokenizer.pre_tokenizer = Whitespace()
+
+     # --- Trainer ---------------------------------------------------------
+     trainer = BpeTrainer(
+         vocab_size=vocab_size,
+         min_frequency=min_frequency,
+         special_tokens=SPECIAL_TOKENS,
+         show_progress=True,
+     )
+
+     return tokenizer, trainer
+
+
+ def train_tokenizer(
+     texts: Union[List[str], Iterator[str]],
+     vocab_size: int = DEFAULT_VOCAB_SIZE,
+     min_frequency: int = DEFAULT_MIN_FREQUENCY,
+     files: Optional[List[str]] = None,
+ ) -> Tokenizer:
+     """
+     Train a BPE tokenizer on the given texts **or** files.
+
+     Parameters
+     ----------
+     texts : list[str] or iterator of str
+         Raw sentences. Ignored when *files* is provided.
+     vocab_size : int
+         Target vocabulary size (default 50 000).
+     min_frequency : int
+         Minimum frequency for a pair to be merged.
+     files : list[str], optional
+         Paths to plain-text files (one sentence per line).
+
+     Returns
+     -------
+     Tokenizer
+         Trained tokenizer ready for encoding / decoding.
+     """
+     tokenizer, trainer = build_tokenizer(vocab_size, min_frequency)
+
+     if files is not None:
+         tokenizer.train(files, trainer)
+     else:
+         # Write texts to a temporary file so we can use the fast Rust trainer
+         tmp_path = _write_texts_to_tmpfile(iter(texts))
+         try:
+             tokenizer.train([tmp_path], trainer)
+         finally:
+             os.remove(tmp_path)
+
+     # --- Post-processing: wrap every encoded sequence with [BOS] … [EOS] -
+     bos_id = tokenizer.token_to_id(BOS_TOKEN)
+     eos_id = tokenizer.token_to_id(EOS_TOKEN)
+     tokenizer.post_processor = TemplateProcessing(
+         single="[BOS]:0 $A:0 [EOS]:0",
+         pair="[BOS]:0 $A:0 [EOS]:0 [BOS]:1 $B:1 [EOS]:1",
+         special_tokens=[
+             ("[BOS]", bos_id),
+             ("[EOS]", eos_id),
+         ],
+     )
+
+     return tokenizer
+
+
+ # ---------------------------------------------------------------------------
+ # Convenience wrappers for saving / loading
+ # ---------------------------------------------------------------------------
+ def save_tokenizer(tokenizer: Tokenizer, path: Union[str, Path]) -> None:
+     """Save a trained tokenizer to a JSON file."""
+     path = Path(path)
+     path.parent.mkdir(parents=True, exist_ok=True)
+     tokenizer.save(str(path))
+     print(f"[✓] Tokenizer saved → {path}")
+
+
+ def load_tokenizer(path: Union[str, Path]) -> Tokenizer:
+     """Load a previously saved tokenizer from a JSON file."""
+     tokenizer = Tokenizer.from_file(str(path))
+     print(f"[✓] Tokenizer loaded ← {path}")
+     return tokenizer
+
+
+ # ---------------------------------------------------------------------------
+ # Encoding / decoding helpers
+ # ---------------------------------------------------------------------------
+ def encode(tokenizer: Tokenizer, text: str) -> List[int]:
+     """Encode a single string and return token IDs (includes [BOS]/[EOS])."""
+     return tokenizer.encode(text).ids
+
+
+ def decode(tokenizer: Tokenizer, ids: List[int]) -> str:
+     """Decode token IDs back to a string, skipping special tokens."""
+     return tokenizer.decode(ids, skip_special_tokens=True)
+
+
+ def get_vocab_size(tokenizer: Tokenizer) -> int:
+     """Return the size of the tokenizer's vocabulary."""
+     return tokenizer.get_vocab_size()
+
+
+ def token_to_id(tokenizer: Tokenizer, token: str) -> Optional[int]:
+     """Look up the integer ID for a single token string."""
+     return tokenizer.token_to_id(token)
+
+
+ def id_to_token(tokenizer: Tokenizer, token_id: int) -> Optional[str]:
+     """Look up the token string for a single integer ID."""
+     return tokenizer.id_to_token(token_id)
+
+
+ # ---------------------------------------------------------------------------
+ # High-level: train a SHARED tokenizer on both languages (for tied embeddings)
+ # ---------------------------------------------------------------------------
+ def train_shared_tokenizer_from_dataset(
+     dataset,
+     src_lang: str = "en",
+     tgt_lang: str = "ms",
+     vocab_size: int = DEFAULT_VOCAB_SIZE,
+     save_dir: Union[str, Path] = "tokenizer",
+ ) -> Tokenizer:
+     """
+     Train a single shared BPE tokenizer on the concatenated en+ms corpus.
+
+     This is used with the 10+2 Tied Transformer architecture, where both
+     encoder and decoder share the same vocabulary and embedding matrix.
+
+     Parameters
+     ----------
+     dataset : datasets.Dataset
+         A HuggingFace dataset split where each example has a ``'translation'``
+         dict with keys for each language code.
+     src_lang : str
+         Source language code (default ``'en'``).
+     tgt_lang : str
+         Target language code (default ``'ms'``).
+     vocab_size : int
+         Vocabulary size for the shared tokenizer.
+     save_dir : str or Path
+         Directory to save the trained tokenizer JSON file.
+
+     Returns
+     -------
+     Tokenizer
+         A single shared tokenizer for both languages.
+     """
+     save_dir = Path(save_dir)
+
+     # Concatenate all source and target sentences into one corpus
+     src_texts = [example["translation"][src_lang] for example in dataset]
+     tgt_texts = [example["translation"][tgt_lang] for example in dataset]
+     all_texts = src_texts + tgt_texts
+
+     print(f"Training shared BPE tokenizer on {len(all_texts):,} sentences "
+           f"({len(src_texts):,} {src_lang} + {len(tgt_texts):,} {tgt_lang}) …")
+     shared_tokenizer = train_tokenizer(all_texts, vocab_size=vocab_size)
+     save_tokenizer(shared_tokenizer, save_dir / "tokenizer_shared.json")
+
+     # Sanity check
+     for name, sample in [(src_lang, src_texts[0]), (tgt_lang, tgt_texts[0])]:
+         enc = shared_tokenizer.encode(sample)
+         print(f"\n[{name}] Sample: {sample[:80]}…")
+         print(f"  Tokens : {enc.tokens[:15]}…")
+         print(f"  IDs    : {enc.ids[:15]}…")
+         print(f"  Decoded: {shared_tokenizer.decode(enc.ids, skip_special_tokens=True)[:80]}…")
+
+     print(f"\n[✓] Shared tokenizer trained and saved to {save_dir}/tokenizer_shared.json")
+     return shared_tokenizer
+
+
+ # ---------------------------------------------------------------------------
+ # High-level: train source (English) & target (Malay) tokenizers from a
+ # HuggingFace dataset split.
+ # ---------------------------------------------------------------------------
+ def train_tokenizers_from_dataset(
+     dataset,
+     src_lang: str = "en",
+     tgt_lang: str = "ms",
+     vocab_size: int = DEFAULT_VOCAB_SIZE,
+     save_dir: Union[str, Path] = "tokenizer",
+ ) -> tuple[Tokenizer, Tokenizer]:
+     """
+     Train separate BPE tokenizers for source and target languages.
+
+     Parameters
+     ----------
+     dataset : datasets.Dataset
+         A HuggingFace dataset split (e.g. ``dataset['train']``) where each
+         example has a ``'translation'`` dict with keys for each language code.
+     src_lang : str
+         Source language code (default ``'en'``).
+     tgt_lang : str
+         Target language code (default ``'ms'``).
+     vocab_size : int
+         Vocabulary size for each tokenizer.
+     save_dir : str or Path
+         Directory to save the trained tokenizer JSON files.
+
+     Returns
+     -------
+     (src_tokenizer, tgt_tokenizer)
+     """
+     save_dir = Path(save_dir)
+
+     # Extract raw sentences from the dataset
+     src_texts = [example["translation"][src_lang] for example in dataset]
+     tgt_texts = [example["translation"][tgt_lang] for example in dataset]
+
+     print(f"Training source tokenizer ({src_lang}) on {len(src_texts):,} sentences …")
+     src_tokenizer = train_tokenizer(src_texts, vocab_size=vocab_size)
+     save_tokenizer(src_tokenizer, save_dir / f"tokenizer_{src_lang}.json")
+
+     print(f"Training target tokenizer ({tgt_lang}) on {len(tgt_texts):,} sentences …")
+     tgt_tokenizer = train_tokenizer(tgt_texts, vocab_size=vocab_size)
+     save_tokenizer(tgt_tokenizer, save_dir / f"tokenizer_{tgt_lang}.json")
+
+     # Quick sanity check
+     for name, tok, sample in [
+         (src_lang, src_tokenizer, src_texts[0]),
+         (tgt_lang, tgt_tokenizer, tgt_texts[0]),
+     ]:
+         enc = tok.encode(sample)
+         print(f"\n[{name}] Sample: {sample[:80]}…")
+         print(f"  Tokens : {enc.tokens[:15]}…")
+         print(f"  IDs    : {enc.ids[:15]}…")
+         print(f"  Decoded: {tok.decode(enc.ids, skip_special_tokens=True)[:80]}…")
+
+     print(f"\n[✓] Both tokenizers trained and saved to {save_dir}/")
+     return src_tokenizer, tgt_tokenizer
+
+
+ # ---------------------------------------------------------------------------
+ # Standalone usage
+ # ---------------------------------------------------------------------------
+ if __name__ == "__main__":
+     from datasets import load_from_disk
+
+     print("Loading TED Talks IWSLT dataset (en ↔ ms, 2016) …")
+     ds = load_from_disk("dataset/en_ms_2016")
+
+     src_tok, tgt_tok = train_tokenizers_from_dataset(
+         ds,
+         src_lang="en",
+         tgt_lang="ms",
+         vocab_size=DEFAULT_VOCAB_SIZE,
+         save_dir="tokenizer",
+     )
+
+     print(f"\nEnglish vocab size : {get_vocab_size(src_tok):,}")
+     print(f"Malay vocab size   : {get_vocab_size(tgt_tok):,}")
+     print(f"[PAD] id (en)      : {token_to_id(src_tok, PAD_TOKEN)}")
+     print(f"[EOS] id (ms)      : {token_to_id(tgt_tok, EOS_TOKEN)}")
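The merge rule that `BpeTrainer` applies can be illustrated with a self-contained toy (an illustrative re-implementation, not the library's Rust code): repeatedly fuse the most frequent adjacent symbol pair. On a tiny Malay-flavoured corpus, the shared prefix "me-" is merged first, which is exactly how the affixes mentioned in the module docstring become single subword units. The corpus and frequencies below are made up for the example.

```python
from collections import Counter

def most_frequent_pair(words):
    """Count adjacent symbol pairs across a corpus of (symbol-tuple -> freq)."""
    pairs = Counter()
    for symbols, freq in words.items():
        for a, b in zip(symbols, symbols[1:]):
            pairs[(a, b)] += freq
    return pairs.most_common(1)[0][0]

def merge_pair(words, pair):
    """Apply a single BPE merge to every word in the corpus."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])  # fuse the pair
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        merged[tuple(out)] = freq
    return merged

# Toy corpus: the prefix "me-" recurs, so ('m', 'e') wins the first merge.
corpus = {tuple("makan"): 5, tuple("memakan"): 3, tuple("melihat"): 4, tuple("me"): 2}
pair = most_frequent_pair(corpus)
print(pair)                        # ('m', 'e')
print(merge_pair(corpus, pair))    # "melihat" becomes ('me', 'l', 'i', 'h', 'a', 't')
```

The real trainer iterates this until `vocab_size` is reached and stores the merge order; this toy shows only a single step.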
src/training.py ADDED
@@ -0,0 +1,370 @@
1
+ """
2
+ Training loop for the Transformer translator.
3
+ ===============================================
4
+ Provides:
5
+ β€’ ``TranslationDataset`` – a PyTorch Dataset that tokenises and pads
6
+ source/target sentence pairs.
7
+ β€’ ``create_dataloaders`` – builds train / validation DataLoaders with
8
+ an 90/10 split.
9
+ β€’ ``train_one_epoch`` – one full pass over the training set.
10
+ β€’ ``evaluate_loss`` – average loss on the validation set.
11
+ β€’ ``train`` – full training driver with logging, LR
12
+ scheduling, checkpointing, and early stopping.
13
+
14
+ Design choices:
15
+ β€’ Label-smoothed cross-entropy (smoothing = 0.1) for better
16
+ generalisation.
17
+ β€’ AdamW with a linear-warmup + cosine-decay schedule (stable for
18
+ small datasets).
19
+ β€’ Mixed-precision (AMP) with ``torch.amp`` for speed / memory on T4.
20
+ β€’ Gradient clipping at max_norm = 1.0 to avoid exploding gradients.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import math
26
+ import os
27
+ import time
28
+ from dataclasses import dataclass, field
29
+ from pathlib import Path
30
+ from typing import List, Optional, Tuple
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+ from torch.utils.data import Dataset, DataLoader, random_split
35
+ from tokenizers import Tokenizer
36
+
37
+
38
+ # ──────────────────────────────────────────────────────────────────────
39
+ # 1. Translation Dataset
40
+ # ──────────────────────────────────────────────────────────────────────
41
+ class TranslationDataset(Dataset):
42
+ """
43
+ Wraps a HuggingFace dataset of translation pairs into a PyTorch
44
+ Dataset that returns padded token-ID tensors.
45
+
46
+ Each ``__getitem__`` returns::
47
+ {
48
+ "src": LongTensor[max_len], # source token IDs (padded)
49
+ "tgt": LongTensor[max_len], # target input (with [BOS], no final [EOS])
50
+ "label": LongTensor[max_len], # target labels (no [BOS], with [EOS])
51
+ }
52
+
53
+ The *tgt* / *label* split implements **teacher forcing**: the decoder
54
+ receives ``[BOS] w1 w2 …`` and must predict ``w1 w2 … [EOS]``.
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ hf_dataset,
60
+ src_tokenizer: Tokenizer,
61
+ tgt_tokenizer: Tokenizer,
62
+ src_lang: str = "en",
63
+ tgt_lang: str = "ms",
64
+ max_len: int = 128,
65
+ pad_id: int = 0,
66
+ ):
67
+ self.data = hf_dataset
68
+ self.src_tok = src_tokenizer
69
+ self.tgt_tok = tgt_tokenizer
70
+ self.src_lang = src_lang
71
+ self.tgt_lang = tgt_lang
72
+ self.max_len = max_len
73
+ self.pad_id = pad_id
74
+
75
+ def __len__(self) -> int:
76
+ return len(self.data)
77
+
78
+ def _pad(self, ids: List[int]) -> List[int]:
79
+ """Truncate to max_len, then right-pad with pad_id."""
80
+ ids = ids[: self.max_len]
81
+ return ids + [self.pad_id] * (self.max_len - len(ids))
82
+
83
+ def __getitem__(self, idx: int) -> dict:
84
+ pair = self.data[idx]["translation"]
85
+
86
+ # Encode (includes [BOS] … [EOS] from post-processor)
87
+ src_ids = self.src_tok.encode(pair[self.src_lang]).ids
88
+ tgt_ids = self.tgt_tok.encode(pair[self.tgt_lang]).ids
89
+
90
+ # Teacher-forcing split:
91
+ # tgt_input = [BOS] w1 w2 … wN (drop last token)
92
+ # tgt_label = w1 w2 … wN [EOS] (drop first token)
93
+ tgt_input = tgt_ids[:-1]
94
+ tgt_label = tgt_ids[1:]
95
+
96
+ return {
97
+ "src": torch.tensor(self._pad(src_ids), dtype=torch.long),
98
+ "tgt": torch.tensor(self._pad(tgt_input), dtype=torch.long),
99
+ "label": torch.tensor(self._pad(tgt_label), dtype=torch.long),
100
+ }
101
+
102
+
103
+ # ──────────────────────────────────────────────────────────────────────
104
+ # 2. DataLoader factory
105
+ # ──────────────────────────────────────────────────────────────────────
106
+ def create_dataloaders(
107
+ hf_dataset,
108
+ src_tokenizer: Tokenizer,
109
+ tgt_tokenizer: Tokenizer,
110
+ src_lang: str = "en",
111
+ tgt_lang: str = "ms",
112
+ max_len: int = 128,
113
+ batch_size: int = 32,
114
+ val_ratio: float = 0.1,
115
+ pad_id: int = 0,
116
+ seed: int = 42,
117
+ ) -> Tuple[DataLoader, DataLoader, TranslationDataset]:
118
+ """
119
+ Build training and validation DataLoaders from a HuggingFace dataset.
120
+
121
+ Returns
122
+ -------
123
+ train_loader, val_loader, full_dataset
124
+ """
125
+ full_ds = TranslationDataset(
126
+ hf_dataset, src_tokenizer, tgt_tokenizer,
127
+ src_lang, tgt_lang, max_len, pad_id,
128
+ )
129
+
+     val_size = max(1, int(len(full_ds) * val_ratio))
+     train_size = len(full_ds) - val_size
+
+     generator = torch.Generator().manual_seed(seed)
+     train_ds, val_ds = random_split(full_ds, [train_size, val_size], generator=generator)
+
+     train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
+     val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False)
+
+     print(f"Train: {train_size} | Val: {val_size} | Batch size: {batch_size}")
+     return train_loader, val_loader, full_ds
+
+
+ # ──────────────────────────────────────────────────────────────────────
+ # 3. Training configuration dataclass
+ # ──────────────────────────────────────────────────────────────────────
+ @dataclass
+ class TrainConfig:
+     """All tuneable knobs in one place."""
+     epochs: int = 50
+     batch_size: int = 32
+     max_len: int = 128
+     lr: float = 5e-4
+     warmup_steps: int = 200
+     label_smoothing: float = 0.1
+     grad_clip: float = 1.0
+     use_amp: bool = True
+     val_ratio: float = 0.1
+     checkpoint_dir: str = "training/checkpoints"
+     log_every: int = 10   # print loss every N steps
+     patience: int = 10    # early-stopping patience (epochs)
+     seed: int = 42
+
+
+ # ──────────────────────────────────────────────────────────────────────
+ # 4. LR scheduler with linear warmup + cosine decay
+ # ──────────────────────────────────────────────────────────────────────
+ def _build_scheduler(optimizer, warmup_steps: int, total_steps: int):
+     """Linear warmup for `warmup_steps`, then cosine decay to 0."""
+
+     def lr_lambda(step):
+         if step < warmup_steps:
+             return step / max(1, warmup_steps)
+         progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
+         return 0.5 * (1.0 + math.cos(math.pi * progress))
+
+     return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+
+
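The warmup-plus-cosine shape of `lr_lambda` can be sanity-checked in isolation. A minimal sketch; the step counts below are illustrative, not the model's actual settings:

```python
import math

def lr_lambda(step, warmup_steps=200, total_steps=10_000):
    # Multiplier applied to the base LR: linear ramp over `warmup_steps`,
    # then cosine decay down to 0 at `total_steps`.
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

print(lr_lambda(0))    # 0.0 (start of warmup)
print(lr_lambda(100))  # 0.5 (halfway through warmup)
print(lr_lambda(200))  # 1.0 (warmup done, peak LR)
print(round(lr_lambda(10_000), 6))  # 0.0 (fully decayed)
```

Because `LambdaLR` multiplies the optimiser's base LR by this factor, the peak learning rate is exactly `cfg.lr`, reached at step `warmup_steps`.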
+ # ──────────────────────────────────────────────────────────────────────
+ # 5. Single-epoch training
+ # ──────────────────────────────────────────────────────────────────────
+ def train_one_epoch(
+     model: nn.Module,
+     loader: DataLoader,
+     optimizer: torch.optim.Optimizer,
+     scheduler,
+     criterion: nn.Module,
+     device: torch.device,
+     scaler: Optional[torch.amp.GradScaler],
+     grad_clip: float = 1.0,
+     log_every: int = 10,
+     epoch: int = 0,
+ ) -> float:
+     """Train for one epoch. Returns average loss."""
+     model.train()
+     total_loss = 0.0
+     n_tokens = 0
+
+     for step, batch in enumerate(loader):
+         src = batch["src"].to(device)
+         tgt = batch["tgt"].to(device)
+         label = batch["label"].to(device)
+
+         optimizer.zero_grad()
+
+         amp_enabled = scaler is not None
+         with torch.amp.autocast("cuda", enabled=amp_enabled):
+             logits = model(src, tgt)  # (B, T, V)
+             loss = criterion(logits.reshape(-1, logits.size(-1)), label.reshape(-1))
+
+         if scaler is not None:
+             scaler.scale(loss).backward()
+             scaler.unscale_(optimizer)
+             nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+             scaler.step(optimizer)
+             scaler.update()
+         else:
+             loss.backward()
+             nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
+             optimizer.step()
+
+         scheduler.step()
+
+         # Accumulate loss (ignore padding contribution)
+         non_pad = (label != model.pad_idx).sum().item()
+         total_loss += loss.item() * non_pad
+         n_tokens += non_pad
+
+         if (step + 1) % log_every == 0:
+             avg = total_loss / max(n_tokens, 1)
+             lr = scheduler.get_last_lr()[0]
+             print(f"  Epoch {epoch+1} | Step {step+1}/{len(loader)} | Loss {avg:.4f} | LR {lr:.2e}")
+
+     return total_loss / max(n_tokens, 1)
+
+
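Note that `train_one_epoch` weights each batch by its non-PAD token count rather than averaging batch means directly. A short sketch with made-up loss/token values shows the difference:

```python
# (mean loss over tokens, non-PAD token count) per batch — hypothetical values.
batches = [(2.0, 100), (4.0, 50)]

total_loss = sum(loss * n for loss, n in batches)  # 2.0*100 + 4.0*50 = 400.0
n_tokens = sum(n for _, n in batches)              # 150

token_avg = total_loss / max(n_tokens, 1)              # per-token average
batch_avg = sum(l for l, _ in batches) / len(batches)  # naive mean of batch means

print(token_avg)  # ~2.667 — the short high-loss batch counts for less
print(batch_avg)  # 3.0
```

The token-weighted form is the standard choice for NMT, since variable-length batches contain very different numbers of real (non-padding) targets.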
237
+ # ──────────────────────────────────────────────────────────────────────
238
+ # 6. Validation loss
239
+ # ──────────────────────────────────────────────────────────────────────
240
+ @torch.no_grad()
241
+ def evaluate_loss(
242
+ model: nn.Module,
243
+ loader: DataLoader,
244
+ criterion: nn.Module,
245
+ device: torch.device,
246
+ use_amp: bool = False,
247
+ ) -> float:
248
+ """Compute average loss over a validation set (with AMP to match training)."""
249
+ model.eval()
250
+ total_loss = 0.0
251
+ n_tokens = 0
252
+ n_batches = len(loader)
253
+
254
+ for step, batch in enumerate(loader):
255
+ src = batch["src"].to(device)
256
+ tgt = batch["tgt"].to(device)
257
+ label = batch["label"].to(device)
258
+
259
+ with torch.amp.autocast("cuda", enabled=use_amp):
260
+ logits = model(src, tgt)
261
+ loss = criterion(logits.reshape(-1, logits.size(-1)), label.reshape(-1))
262
+
263
+ non_pad = (label != model.pad_idx).sum().item()
264
+ total_loss += loss.item() * non_pad
265
+ n_tokens += non_pad
266
+
267
+ if (step + 1) % max(1, n_batches // 4) == 0 or (step + 1) == n_batches:
268
+ print(f" Val {step+1}/{n_batches}", end="\r")
269
+
270
+ return total_loss / max(n_tokens, 1)
271
+
272
+
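The value returned by `evaluate_loss` is a token-averaged cross-entropy in nats, so it maps to perplexity via `exp()`. With label smoothing enabled the reported loss, and hence this perplexity, is slightly inflated relative to plain cross-entropy. The value below is illustrative, not from the actual run:

```python
import math

val_loss = 2.0                 # hypothetical token-averaged CE in nats
perplexity = math.exp(val_loss)
print(round(perplexity, 2))    # 7.39
```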
+ # ──────────────────────────────────────────────────────────────────────
+ # 7. Full training driver
+ # ──────────────────────────────────────────────────────────────────────
+ def train(
+     model: nn.Module,
+     train_loader: DataLoader,
+     val_loader: DataLoader,
+     cfg: TrainConfig,
+     device: torch.device,
+     trial=None,
+ ) -> dict:
+     """
+     Full training loop with logging, checkpointing, and early stopping.
+
+     Parameters
+     ----------
+     trial : optuna.trial.Trial, optional
+         If provided, reports val_loss after each epoch for ASHA pruning.
+
+     Returns
+     -------
+     history : dict
+         ``{"train_loss": [...], "val_loss": [...], "lr": [...]}``
+     """
+     # --- Loss function (label-smoothed CE, ignoring PAD) ---------------
+     criterion = nn.CrossEntropyLoss(
+         ignore_index=model.pad_idx,
+         label_smoothing=cfg.label_smoothing,
+     )
+
+     # --- Optimiser -----------------------------------------------------
+     optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=(0.9, 0.98), eps=1e-9)
+
+     # --- LR schedule ---------------------------------------------------
+     total_steps = cfg.epochs * len(train_loader)
+     scheduler = _build_scheduler(optimizer, cfg.warmup_steps, total_steps)
+
+     # --- AMP scaler ----------------------------------------------------
+     scaler = torch.amp.GradScaler("cuda") if (cfg.use_amp and device.type == "cuda") else None
+
+     # --- Checkpoint dir ------------------------------------------------
+     ckpt_dir = Path(cfg.checkpoint_dir)
+     ckpt_dir.mkdir(parents=True, exist_ok=True)
+
+     history: dict = {"train_loss": [], "val_loss": [], "lr": []}
+     best_val = float("inf")
+     patience_ctr = 0
+
+     print(f"\n{'='*60}")
+     print(f"Starting training: {cfg.epochs} epochs, lr={cfg.lr}, AMP={cfg.use_amp}")
+     print(f"{'='*60}\n")
+
+     for epoch in range(cfg.epochs):
+         t0 = time.time()
+
+         train_loss = train_one_epoch(
+             model, train_loader, optimizer, scheduler, criterion,
+             device, scaler, cfg.grad_clip, cfg.log_every, epoch,
+         )
+         use_amp = cfg.use_amp and device.type == "cuda"
+         val_loss = evaluate_loss(model, val_loader, criterion, device, use_amp=use_amp)
+         lr = scheduler.get_last_lr()[0]
+
+         elapsed = time.time() - t0
+         history["train_loss"].append(train_loss)
+         history["val_loss"].append(val_loss)
+         history["lr"].append(lr)
+
+         print(
+             f"Epoch {epoch+1}/{cfg.epochs} | "
+             f"Train {train_loss:.4f} | Val {val_loss:.4f} | "
+             f"LR {lr:.2e} | {elapsed:.1f}s"
+         )
+
+         # --- Optuna ASHA pruning (if trial provided) ------------------
+         if trial is not None:
+             import optuna
+             trial.report(val_loss, epoch)
+             if trial.should_prune():
+                 print(f"\n✂ Optuna pruned this trial at epoch {epoch+1}.")
+                 raise optuna.TrialPruned()
+
+         # --- Checkpoint best model ------------------------------------
+         if val_loss < best_val:
+             best_val = val_loss
+             patience_ctr = 0
+             torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
+             print("  ↳ New best val loss — checkpoint saved.")
+         else:
+             patience_ctr += 1
+             if patience_ctr >= cfg.patience:
+                 print(f"\n⏹ Early stopping after {cfg.patience} epochs without improvement.")
+                 break
+
+     # Load best checkpoint
+     model.load_state_dict(torch.load(ckpt_dir / "best_model.pt", map_location=device, weights_only=True))
+     print(f"\n✓ Training complete. Best val loss: {best_val:.4f}")
+     return history
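The checkpoint/patience logic in `train` can be exercised without a model. A pure-Python mock of the early-stopping branch, run on a hypothetical loss curve:

```python
def stopping_epoch(val_losses, patience=10):
    # Mirrors the patience counter in train(): reset on improvement,
    # stop once `patience` consecutive epochs fail to beat best_val.
    best_val = float("inf")
    patience_ctr = 0
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_val:
            best_val = val_loss
            patience_ctr = 0          # a checkpoint would also be saved here
        else:
            patience_ctr += 1
            if patience_ctr >= patience:
                return epoch + 1      # stopped early (1-based epoch)
    return len(val_losses)            # ran the full schedule

# Improves for 3 epochs, then plateaus: stops 3 epochs after the best one,
# even though the curve would have improved again later.
print(stopping_epoch([3.0, 2.5, 2.2, 2.3, 2.3, 2.4, 2.1, 2.0], patience=3))  # 6
```

Since the best checkpoint is reloaded after the loop, a premature stop still returns the best weights seen so far.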
tokenizer_shared_16k.json ADDED
The diff for this file is too large to render. See raw diff