Upload en-ms Transformer (6+2 Tied, 16K BPE, chrF 45.62)
- README.md +292 -0
- best_model.pt +3 -0
- config.json +36 -0
- src/eval.py +315 -0
- src/model.py +287 -0
- src/tokenizer.py +360 -0
- src/training.py +370 -0
- tokenizer_shared_16k.json +0 -0
README.md
ADDED
@@ -0,0 +1,292 @@
---
language:
- en
- ms
tags:
- translation
- transformer
- en-ms
- pytorch
- bpe
- encoder-decoder
- tied-embeddings
- deep-encoder-shallow-decoder
license: mit
datasets:
- open_subtitles
metrics:
- chrf
pipeline_tag: translation
model-index:
- name: en-ms-transformer-6+2-tied
  results:
  - task:
      type: translation
      name: Translation
    dataset:
      name: OpenSubtitles v2018 en-ms
      type: open_subtitles
      split: test
    metrics:
    - type: chrf
      value: 45.62
      name: chrF (greedy)
    - type: chrf
      value: 44.99
      name: chrF (beam=5)
---

# English → Malay Transformer (6+2 Tied, 16K BPE)

A custom encoder-decoder Transformer for English-to-Malay translation, built entirely from scratch in PyTorch. This model was developed as part of **IT3103 Advanced Topics in AI, Assignment 2, 2025 Semester 2**.

The project covers the full NMT pipeline: dataset curation, tokenizer training, architecture design with ablation studies, mixed-precision training, and evaluation, all without pretrained models or high-level frameworks such as Fairseq or OpenNMT.

## Model Description

| Component | Details |
|---|---|
| **Architecture** | 6-layer encoder + 2-layer decoder, pre-norm Transformer |
| **d_model / n_head / d_ff** | 512 / 8 / 2048 |
| **Vocab** | 16,000 shared BPE (English + Malay, joint) |
| **Dropout** | 0.3 |
| **Parameters** | ~36.6M |
| **Tied embeddings** | Yes: encoder input, decoder input, and output projection share the same weight matrix (Press & Wolf, 2017) |
| **Normalisation** | Pre-norm (LayerNorm before attention/FFN, not after) |

### Design Decisions and Rationale

**Why 6+2 (Deep Encoder, Shallow Decoder)?**

The asymmetric 6+2 architecture follows [Kasai et al. (2021)](https://arxiv.org/abs/2006.10369), *"Deep Encoder, Shallow Decoder: Reevaluating Non-autoregressive Machine Translation"*. The core insight is that encoder depth contributes most to translation quality (richer source representations), while the decoder can be kept shallow without significant degradation. Our own **Ablation Sweep 1** (see below) supports this: encoder depths of 2, 4, 6, and 8 all produced similar chrF scores (22–25 range), so the model hits diminishing returns quickly. We chose 6 as a safe operating point.

The shallow 2-layer decoder also gives a practical speed advantage: roughly 2× faster inference than a symmetric 6+6 model, since autoregressive decoding must run the decoder once per output token.

**Why 16K shared vocabulary?**

We initially trained with a 50K vocabulary but found it too large for 500K training sentences: most tokens appeared only rarely, leaving their embeddings under-trained. Reducing to a 16K shared BPE vocabulary produced denser, better-trained embeddings and a 2.5× speedup per epoch (7.8 min vs. ~20 min estimated at 50K). English and Malay share the Latin script with substantial lexical overlap (loanwords like "teknologi" and "universiti", numbers, proper nouns), making a joint vocabulary highly effective.

**Why tied embeddings?**

With a shared source-target vocabulary, tying the encoder embedding, decoder embedding, and output projection matrix (Press & Wolf, 2017) removes two redundant 16,000 × 512 matrices, cutting the parameter count by ~16M, while acting as a strong regulariser. The model learns a single semantic space for both languages. A minimal sketch of the tying is shown below.

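The following toy module is a minimal sketch of the idea (not the repository's `TransformerTranslator`, which lives in `src/model.py`): one embedding matrix serves as token embedding and, via `F.linear`, as the output projection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedProjectionDemo(nn.Module):
    """Toy module: one matrix used both for token embedding and for output logits."""
    def __init__(self, vocab_size: int = 16_000, d_model: int = 512, pad_idx: int = 0):
        super().__init__()
        # Single shared matrix of shape (vocab_size, d_model)
        self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))

    def embed(self, token_ids: torch.Tensor) -> torch.Tensor:
        return self.shared_embedding(token_ids)            # (batch, len, d_model)

    def project(self, hidden: torch.Tensor) -> torch.Tensor:
        # logits = hidden @ embedding.T + bias -- the "tied" output projection
        return F.linear(hidden, self.shared_embedding.weight, self.output_bias)

demo = TiedProjectionDemo()
hidden = torch.randn(1, 4, 512)
print(demo.project(hidden).shape)  # torch.Size([1, 4, 16000])
```
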
**Why dropout 0.3?**

490K training sentences is relatively small for a Transformer, so dropout 0.3 was chosen as aggressive regularisation against overfitting. The training curves confirm this was appropriate: the gap between train loss (3.17) and val loss (3.21) remained small throughout training, with no sign of overfitting even at epoch 20.

## Training Data

- **Dataset:** [OpenSubtitles v2018](https://opus.nlpl.eu/OpenSubtitles-v2018.php) (English-Malay aligned parallel corpus)
- **Raw corpus size:** ~17.3M parallel sentence pairs
- **After filtering:** 500,000 pairs selected
- **Split:** 490,000 train / 5,000 validation / 5,000 test (all in-distribution)

### Data Preprocessing Pipeline

The raw OpenSubtitles corpus is notoriously noisy (subtitle artifacts, music symbols, HTML tags, near-duplicate lines). We applied the following quality filters (a sketch of the filter appears after the list):

1. **Length filter:** 3–80 words per side (removes fragments and overly long lines)
2. **Length ratio filter:** max(len_en, len_ms) / min(len_en, len_ms) ≤ 3.0 (removes misaligned pairs)
3. **Character length filter:** 10–400 characters per side
4. **Junk pattern removal:** regex filter for music symbols (♪♫), HTML tags, bracket-only lines (e.g. `[music playing]`), ellipsis-only lines, and dash-only lines
5. **Deduplication:** case-insensitive exact match on the English side

This pipeline retains ~500K high-quality pairs from the first ~2.7M lines scanned.

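The actual preprocessing script is not part of this upload; the sketch below only illustrates the five filters listed above, and the `JUNK_RE` pattern is illustrative rather than the exact regex used.

```python
import re

JUNK_RE = re.compile(r'[♪♫]|<[^>]+>|^\[.*\]$|^\.{3,}$|^-+$')  # illustrative junk patterns

def keep_pair(en: str, ms: str) -> bool:
    """Return True if an (en, ms) pair passes the quality filters described above."""
    en, ms = en.strip(), ms.strip()
    len_en, len_ms = len(en.split()), len(ms.split())
    if not (3 <= len_en <= 80 and 3 <= len_ms <= 80):            # 1. word-length filter
        return False
    if max(len_en, len_ms) / min(len_en, len_ms) > 3.0:          # 2. length-ratio filter
        return False
    if not (10 <= len(en) <= 400 and 10 <= len(ms) <= 400):      # 3. character-length filter
        return False
    if JUNK_RE.search(en) or JUNK_RE.search(ms):                 # 4. junk-pattern filter
        return False
    return True

def filter_corpus(pairs):
    """Apply keep_pair plus case-insensitive deduplication on the English side (filter 5)."""
    seen = set()
    for en, ms in pairs:
        key = en.strip().lower()
        if key in seen or not keep_pair(en, ms):
            continue
        seen.add(key)
        yield en, ms
```
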
### Why OpenSubtitles over TED Talks?

We initially experimented with the [IWSLT TED Talks](https://huggingface.co/datasets/IWSLT/ted_talks_iwslt) dataset (~5K en-ms pairs) and achieved a chrF of only **6.76**; the dataset was simply far too small. We then moved to OpenSubtitles, which provides orders of magnitude more data. Importantly, we evaluate on **in-distribution** OpenSubtitles test data rather than using TED Talks as an out-of-distribution test set, which would unfairly penalise the model for domain mismatch (conversational subtitles vs. formal TED lectures).

## Ablation Studies

We conducted two systematic ablation sweeps to guide architecture and data decisions. All sweeps used a 50K-vocabulary baseline trained for 3 epochs for efficiency.

### Sweep 1: Encoder Depth

Fixed: 50K vocab, 500K data, 2-layer decoder, 3 epochs.

| Encoder Layers | chrF (TED test) | Val Loss | Params |
|---|---|---|---|
| 2 | 24.42 | 3.92 | 48.5M |
| 4 | 22.37 | 3.84 | 61.5M |
| 6 | 24.65 | 3.80 | 74.6M |
| 8 | 22.91 | 3.76 | 87.6M |

**Finding:** Encoder depth has **flat returns** on downstream chrF despite steadily decreasing validation loss. This suggests the out-of-distribution TED Talks test set, not model capacity, was the bottleneck (confirmed later). We selected 6 layers as the sweet spot: validation loss close to the 8-layer model at noticeably lower parameter cost, and well supported by the Kasai et al. finding.

### Sweep 2: Training Data Size

Fixed: 50K vocab, 6+2 architecture, 3 epochs.

| Train Size | chrF (TED test) | Val Loss |
|---|---|---|
| 50K | 16.67 | 4.50 |
| 100K | 19.60 | 4.11 |
| 200K | 22.47 | 3.93 |
| 500K | 26.50 | 3.75 |

**Finding:** chrF scales **approximately linearly with log(data size)**, roughly +3 chrF per doubling of the training data. This confirmed that **data volume is the dominant factor** for translation quality at this scale, far more impactful than architectural changes, and motivated using the maximum feasible data (490K after filtering) for the final model.

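As a rough sanity check on the "per doubling" figure, the back-of-the-envelope calculation below reproduces the slope from the two sweep endpoints and then naively applies it to the final in-distribution score; it is an extrapolation, not a fitted scaling law.

```python
import math

# Sweep 2 endpoints: 50K pairs -> 16.67 chrF, 500K pairs -> 26.50 chrF (TED test, 3 epochs)
slope = (26.50 - 16.67) / math.log2(500_000 / 50_000)
print(f"slope ≈ {slope:.2f} chrF per doubling")        # ≈ 2.96

# Applying the same slope to the final in-distribution score (45.62 at 490K pairs)
projected = 45.62 + slope * math.log2(2_000_000 / 490_000)
print(f"projected chrF at 2M pairs ≈ {projected:.1f}")  # ≈ 51.6
```
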
## Training Details

| Setting | Value |
|---|---|
| Optimizer | AdamW (lr=5e-4, β₁=0.9, β₂=0.98, ε=1e-9) |
| Schedule | Linear warmup (4,000 steps) → cosine decay to 0 |
| Batch size | 128 |
| Max sequence length | 128 tokens |
| Epochs | 20 (early stopping patience=3, did not trigger) |
| Label smoothing | 0.1 |
| Gradient clipping | max_norm=1.0 |
| AMP | fp16 mixed precision (PyTorch GradScaler) |
| Hardware | NVIDIA RTX 5070 Ti (16GB VRAM), CUDA 13.1 |
| Training time | **2.62 hours** (157 min, ~7.85 min/epoch) |

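The full training loop lives in `src/training.py` (not shown in this excerpt). The sketch below only illustrates the optimizer and the warmup-then-cosine schedule from the table, implemented with a standard `LambdaLR`; the helper name is ours, not the repository's.

```python
import math
import torch

def make_optimizer_and_scheduler(model, total_steps: int, warmup_steps: int = 4000, lr: float = 5e-4):
    """AdamW + linear warmup to `lr`, then cosine decay to 0 over the remaining steps."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-9)

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)                       # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # cosine decay to 0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler

# Usage: after every optimizer.step(), call scheduler.step()
```
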
### Training Progression

| Epoch | Train Loss | Val Loss | LR |
|---|---|---|---|
| 1 | 5.4036 | 4.2485 | 6.5e-5 |
| 5 | 3.5888 | 3.4519 | 3.7e-4 |
| 10 | 3.3605 | 3.2986 | 3.7e-4 |
| 15 | 3.2268 | 3.2346 | 1.8e-4 |
| 20 | 3.1683 | 3.2110 | 4.9e-6 |

The model converged smoothly with no overfitting: the train-val gap remained under 0.05 throughout training. The cosine LR decay let the final epochs squeeze out the last bits of improvement (val loss 3.23 at epoch 15 → 3.21 at epoch 20).

## Evaluation Results

Evaluated on **5,000 held-out in-distribution** OpenSubtitles test sentences with post-processing applied.

| Decoding Strategy | chrF |
|---|---|
| Greedy | **45.62** |
| Beam search (beam=5, length_penalty=0.6) | 44.99 |

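chrF is computed with `sacrebleu`, exactly as in `compute_chrf` in `src/eval.py`; the hypothesis and reference lists below are stand-ins to show the call.

```python
import sacrebleu

hypotheses = ["Skywalker baru mendarat, tuan."]          # model outputs (stand-ins)
references = ["Skywalker baru sahaja mendarat, tuan."]   # gold Malay references

# corpus_chrf takes a list of hypotheses and a list of reference streams
chrf = sacrebleu.corpus_chrf(hypotheses, [references])
print(f"chrF = {chrf.score:.2f}")
```
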
### Post-Processing

The BPE tokenizer uses a `Whitespace` pre-tokenizer without continuation markers, so raw `decode()` output contains spurious spaces before punctuation (e.g., `"mendarat , tuan ."` instead of `"mendarat, tuan."`). We apply a lightweight regex-based post-processing step (see the usage example below) that:

1. Removes spaces before punctuation marks (`. , ? ! ; :`)
2. Removes spaces after opening brackets/quotes
3. Collapses spaced hyphens in compound words
4. Capitalises the first character

This post-processing improved chrF by **+1.05 points** (greedy: 44.57 → 45.62), a free gain with zero retraining.

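The cleanup function ships as `postprocess_translation` in `src/eval.py`, for example:

```python
from src.eval import postprocess_translation

raw = "skywalker baru mendarat , tuan ."
print(postprocess_translation(raw))  # "Skywalker baru mendarat, tuan."
```
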
### Why Greedy > Beam Search?

Interestingly, greedy decoding outperforms beam search here. This is a known phenomenon in NMT: beam search with a length penalty can produce outputs that are slightly too long or too short relative to the references, which hurts chrF's character n-gram matching. Greedy decoding produces outputs whose lengths happen to align better with the reference lengths in this corpus.

### Sample Translations

| # | English (Source) | Reference (Malay) | Model Output |
|---|---|---|---|
| 1 | Skywalker has just landed, lord. | Skywalker baru sahaja mendarat, tuan. | Skywalker baru mendarat, tuan. |
| 2 | Raymond, you like me? | Raymond, awak suka saya? | Raymond, awak suka saya? |
| 3 | She may be dying and it's all my fault. | Dia mungkin akan mati dan semuanya salah saya. | Dia mungkin akan mati dan semuanya salah saya. |
| 4 | He always remembers the cards. | Ia ingat kad. | Dia selalu ingat kad. |
| 5 | Hey, you wanna see something? | Hei, awak nak tengok sesuatu? | Hei, awak nak lihat sesuatu? |
| 6 | Why don't you just go talk to her? | Mengapa awak tidak bercakap dengannya? | Apa kata awak cakap dengan dia? |
| 7 | We still got that meat-lovers' pizza in the trunk. | Kita masih ada piza daging dalam but. | Kita masih ada piza daging di dalam but kereta. |

The model produces fluent, natural Malay that is often comparable or near-identical to the reference translations. Errors tend to occur on rare proper nouns (subword fragmentation) and idiomatic expressions.

## Tokenizer

- **Type:** Byte-Pair Encoding (BPE) via the HuggingFace `tokenizers` library (Rust backend)
- **Vocab size:** 16,000 (shared joint vocabulary for both English and Malay)
- **Normalization:** NFKC Unicode normalisation + lowercasing
- **Pre-tokenization:** whitespace splitting
- **Post-processing:** `[BOS] $A [EOS]` template (auto-wraps encoded sequences)
- **Special tokens:** `[PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3, [MASK]=4, [BOS]=5, [EOS]=6`
- **Trained on:** the 490K training pairs only (980K sentences total), so no data leakage from val/test

A quick way to inspect this behaviour is shown below.

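Loading the shipped tokenizer and encoding a sentence shows the lowercasing and the automatic `[BOS] … [EOS]` wrapping; the exact subword split depends on the trained merges, so the token list in the comment is only indicative.

```python
from tokenizers import Tokenizer

tok = Tokenizer.from_file("tokenizer_shared_16k.json")

enc = tok.encode("Hello, how are you?")
print(enc.tokens)  # e.g. ['[BOS]', 'hello', ',', 'how', 'are', 'you', '?', '[EOS]']
print(enc.ids)     # integer IDs; [BOS]=5 first, [EOS]=6 last

# Round-trip: decode drops special tokens but keeps spaces around punctuation,
# which is why the post-processing step above exists.
print(tok.decode(enc.ids))
```
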
### Why Shared BPE for en-ms?

English and Malay both use the Latin script with significant lexical overlap (loanwords: "teknologi", "matematik", "universiti"; numbers; proper nouns; punctuation). A joint BPE vocabulary captures cross-lingual subword patterns and directly enables tied embeddings. Malay's morphological affixes (me-, ber-, di-, -kan, -an, -i) are naturally learned as subword units by BPE, providing good coverage without an explicitly morphological tokenizer.

## Usage

```python
import torch
from tokenizers import Tokenizer

# Load tokenizer
tokenizer = Tokenizer.from_file("tokenizer_shared_16k.json")

# Load model (requires model.py from src/)
from src.model import build_model

model = build_model(
    vocab_size=16000, pad_idx=0, device=torch.device("cpu"),
    d_model=512, n_head=8, num_encoder_layers=6, num_decoder_layers=2,
    d_ff=2048, dropout=0.3, max_len=144,
)
model.load_state_dict(torch.load("best_model.pt", map_location="cpu", weights_only=True))
model.eval()

# Translate (requires eval.py from src/)
from src.eval import translate

result = translate(model, "Hello, how are you?", tokenizer, tokenizer,
                   bos_id=5, eos_id=6, pad_id=0, max_len=128,
                   device=torch.device("cpu"), beam_width=5)
print(result)  # → "Hai, apa khabar?"
```

## Repository Structure

| File | Description |
|---|---|
| `best_model.pt` | Model weights (`state_dict` format, ~140MB) |
| `tokenizer_shared_16k.json` | Shared BPE tokenizer (16K vocab) |
| `config.json` | Full model configuration and training hyperparameters |
| `src/model.py` | `TransformerTranslator` – complete encoder-decoder architecture |
| `src/tokenizer.py` | BPE tokenizer training, saving, loading, encoding, decoding |
| `src/training.py` | Full training loop with early stopping, warmup, cosine decay, AMP |
| `src/eval.py` | Greedy/beam decoding, chrF scoring, post-processing |

## Experimental Journey

This project went through several iterations:

1. **TED Talks baseline** – IWSLT TED Talks en-ms (~5K pairs). chrF **6.76**. The dataset was far too small.
2. **OPUS-100 pivot** – Switched to OPUS-100 en-ms. chrF **26.39** with a 10+2 architecture. A significant improvement, but still limited by data quality.
3. **OpenSubtitles pivot** – Moved to OpenSubtitles v2018 (17.3M raw pairs) and developed the quality-filtering pipeline.
4. **Ablation sweeps** – Systematically tested encoder depth (2/4/6/8) and data size (50K/100K/200K/500K); found data size to be the dominant factor.
5. **Final model** – 6+2 tied Transformer, 16K BPE, 490K pairs, dropout 0.3. chrF **44.57** (greedy, no post-processing).
6. **Post-processing fix** – Added punctuation cleanup. chrF **45.62** (greedy), a free +1.05 improvement.

## Limitations and Future Work

### Current Limitations
- **Domain specificity:** Trained exclusively on movie/TV subtitles, so performance degrades significantly on formal, academic, or technical text (e.g., the TED Talks test set gave chrF of roughly 6–26 depending on configuration).
- **Subword fragmentation:** Rare proper nouns and domain-specific terms get split into character-level fragments (e.g., "Burgundy" → "bur gun dy", "android" → "dan ro id"). A larger vocabulary or byte-level fallback could mitigate this.
- **16K vocab trade-off:** The compact vocabulary provides dense embeddings but over-segments rare words. A 32K vocabulary might be a better balance.
- **No backtranslation or data augmentation:** The model trains on natural parallel data only.

### Future Improvements
- **Scale data to 2M+:** Our sweep shows roughly +3 chrF per data doubling, so 2M sentences could push chrF to ~50+.
- **Reduce dropout to 0.1:** With more data, the aggressive 0.3 dropout likely over-regularises.
- **Byte-level fallback:** Handle rare words more gracefully.
- **Ensemble decoding:** Combine checkpoints from different training stages.

## References

- Vaswani, A. et al. (2017). [Attention is All You Need](https://arxiv.org/abs/1706.03762). *NeurIPS*.
- Kasai, J. et al. (2021). [Deep Encoder, Shallow Decoder: Reevaluating Non-autoregressive Machine Translation](https://arxiv.org/abs/2006.10369). *ICLR*.
- Press, O. & Wolf, L. (2017). [Using the Output Embedding to Improve Language Models](https://arxiv.org/abs/1608.05859). *EACL*.
- Popović, M. (2015). [chrF: character n-gram F-score for automatic MT evaluation](https://aclanthology.org/W15-3049/). *WMT*.
- Sennrich, R. et al. (2016). [Neural Machine Translation of Rare Words with Subword Units](https://arxiv.org/abs/1508.07909). *ACL*.
- Lison, P. & Tiedemann, J. (2016). [OpenSubtitles2016: Extracting Large Parallel Corpora from Movie and TV Subtitles](http://www.lrec-conf.org/proceedings/lrec2016/pdf/947_Paper.pdf). *LREC*.

## Citation

```bibtex
@misc{astralpotato2025enms,
  title={English-Malay Neural Machine Translation with Deep Encoder, Shallow Decoder Transformer},
  author={AstralPotato},
  year={2025},
  howpublished={IT3103 Advanced Topics in AI, Assignment 2, 2025S2},
}
```
best_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:053df05b5a8c77434507d745eee6fff4c52cddfc72abf27330d71f1e8688c3e3
size 142469700
config.json
ADDED
@@ -0,0 +1,36 @@
{
  "architecture": "TransformerTranslator (6+2 Tied)",
  "vocab_size": 16000,
  "d_model": 512,
  "n_head": 8,
  "num_encoder_layers": 6,
  "num_decoder_layers": 2,
  "d_ff": 2048,
  "dropout": 0.3,
  "max_len": 144,
  "pad_idx": 0,
  "bos_id": 5,
  "eos_id": 6,
  "tied_embeddings": true,
  "pre_norm": true,
  "label_smoothing": 0.1,
  "training": {
    "dataset": "OpenSubtitles v2018 en-ms",
    "train_size": 490000,
    "val_size": 5000,
    "test_size": 5000,
    "epochs_trained": 20,
    "batch_size": 128,
    "lr": 0.0005,
    "warmup_steps": 4000,
    "optimizer": "AdamW",
    "scheduler": "linear warmup + cosine decay",
    "amp": true
  },
  "evaluation": {
    "chrf_greedy": 45.62,
    "chrf_beam5_lp06": 44.99,
    "test_set": "5K in-distribution OpenSubtitles",
    "note": "chrF with post-processing (punctuation cleanup)"
  }
}
src/eval.py
ADDED
@@ -0,0 +1,315 @@
"""
Evaluation module – greedy / beam-search decoding + chrF scoring.
=================================================================
Provides:
    • ``greedy_decode``      – auto-regressive greedy decoding.
    • ``beam_search_decode`` – beam search with length normalisation.
    • ``translate``          – end-to-end: raw English string → Malay string.
    • ``compute_chrf``       – corpus-level chrF score via *sacrebleu*.
    • ``evaluate``           – decode the full validation set, compute chrF,
                               and print sample translations.
"""

from __future__ import annotations

import re
from typing import List, Optional

import torch
import torch.nn as nn
from tokenizers import Tokenizer

import sacrebleu


# ──────────────────────────────────────────────────────────────────────
# 0. Post-processing: fix tokenizer spacing artefacts
# ──────────────────────────────────────────────────────────────────────
def postprocess_translation(text: str) -> str:
    """
    Clean up raw tokenizer decode output:
      1. Remove spaces before punctuation   (", tuan ." → ", tuan.")
      2. Remove spaces after opening brackets/quotes
      3. Remove spaces before closing brackets/quotes
      4. Capitalise the first letter
      5. Collapse multiple spaces
    """
    # Remove space before punctuation: . , ? ! ; : ) ] } ' " ...
    text = re.sub(r'\s+([.,?!;:)\]}"\'…])', r'\1', text)
    # Remove space after opening brackets/quotes
    text = re.sub(r'([(\[{"\'])\s+', r'\1', text)
    # Fix spaced hyphens in compound words (e.g. "brother - in - arms" → "brother-in-arms")
    text = re.sub(r'\s*-\s*', '-', text)
    # Collapse multiple spaces
    text = re.sub(r'\s{2,}', ' ', text)
    # Strip and capitalise
    text = text.strip()
    if text:
        text = text[0].upper() + text[1:]
    return text


# ──────────────────────────────────────────────────────────────────────
# 1. Greedy decoding
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def greedy_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
) -> torch.Tensor:
    """
    Auto-regressive greedy decoding for a single source sequence.

    Parameters
    ----------
    model   : TransformerTranslator
    src     : (1, src_len) source token IDs.
    bos_id  : beginning-of-sentence token ID.
    eos_id  : end-of-sentence token ID.
    pad_id  : padding token ID.
    max_len : maximum decoding steps.

    Returns
    -------
    (1, out_len) generated token IDs (including [BOS], up to [EOS]).
    """
    device = src.device
    model.eval()

    # Encode source once
    src_pad_mask = (src == pad_id)
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    # Start with [BOS]
    ys = torch.tensor([[bos_id]], dtype=torch.long, device=device)

    for _ in range(max_len - 1):
        logits = model.decode(
            ys, memory,
            memory_key_padding_mask=src_pad_mask,
        )  # (1, cur_len, vocab)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        ys = torch.cat([ys, next_token], dim=1)

        if next_token.item() == eos_id:
            break

    return ys


# ──────────────────────────────────────────────────────────────────────
# 1b. Beam-search decoding
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def beam_search_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    beam_width: int = 5,
    length_penalty: float = 0.6,
) -> torch.Tensor:
    """
    Beam-search decoding for a single source sequence.

    Parameters
    ----------
    model                  : TransformerTranslator
    src                    : (1, src_len) source token IDs.
    bos_id, eos_id, pad_id : special token IDs.
    max_len                : maximum decoding steps.
    beam_width             : number of beams to keep at each step.
    length_penalty         : α for length normalisation: score / len^α.

    Returns
    -------
    (1, out_len) best hypothesis token IDs (including [BOS], up to [EOS]).
    """
    device = src.device
    model.eval()

    # Encode source once
    src_pad_mask = (src == pad_id)
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    # Each beam: (log_prob, token_ids_list)
    beams = [(0.0, [bos_id])]
    completed = []

    for _ in range(max_len - 1):
        candidates = []
        for score, tokens in beams:
            if tokens[-1] == eos_id:
                completed.append((score, tokens))
                continue

            ys = torch.tensor([tokens], dtype=torch.long, device=device)
            logits = model.decode(
                ys, memory,
                memory_key_padding_mask=src_pad_mask,
            )  # (1, cur_len, vocab)
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)

            topk_log_probs, topk_ids = log_probs.topk(beam_width)
            for k in range(beam_width):
                new_score = score + topk_log_probs[k].item()
                new_tokens = tokens + [topk_ids[k].item()]
                candidates.append((new_score, new_tokens))

        if not candidates:
            break

        # Keep top beam_width by length-normalised score
        candidates.sort(
            key=lambda x: x[0] / (len(x[1]) ** length_penalty),
            reverse=True,
        )
        beams = candidates[:beam_width]

        # Early exit if all beams have finished
        if all(b[1][-1] == eos_id for b in beams):
            completed.extend(beams)
            break

    # Add any remaining beams
    completed.extend(beams)

    # Pick best by length-normalised score
    best = max(
        completed,
        key=lambda x: x[0] / (len(x[1]) ** length_penalty),
    )
    return torch.tensor([best[1]], dtype=torch.long, device=device)


# ──────────────────────────────────────────────────────────────────────
# 2. Translate a raw string
# ──────────────────────────────────────────────────────────────────────
def translate(
    model: nn.Module,
    sentence: str,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> str:
    """Translate a single English sentence to Malay.

    Set beam_width=1 for greedy, >1 for beam search.
    """
    if device is None:
        device = next(model.parameters()).device

    # Tokenise source
    src_ids = src_tokenizer.encode(sentence).ids
    src = torch.tensor([src_ids], dtype=torch.long, device=device)

    # Decode
    if beam_width > 1:
        out_ids = beam_search_decode(
            model, src, bos_id, eos_id, pad_id, max_len,
            beam_width=beam_width, length_penalty=length_penalty,
        )
    else:
        out_ids = greedy_decode(model, src, bos_id, eos_id, pad_id, max_len)

    # Convert IDs → string (skip special tokens) + clean up spacing
    raw = tgt_tokenizer.decode(out_ids.squeeze(0).tolist(), skip_special_tokens=True)
    return postprocess_translation(raw)


# ──────────────────────────────────────────────────────────────────────
# 3. Corpus-level chrF
# ──────────────────────────────────────────────────────────────────────
def compute_chrf(hypotheses: List[str], references: List[str]) -> sacrebleu.CHRFScore:
    """
    Compute corpus-level chrF score.

    Parameters
    ----------
    hypotheses : list[str]
        System outputs (decoded translations).
    references : list[str]
        Gold reference translations.

    Returns
    -------
    sacrebleu.CHRFScore – has ``.score`` attribute (0–100 scale).
    """
    return sacrebleu.corpus_chrf(hypotheses, [references])


# ──────────────────────────────────────────────────────────────────────
# 4. Full evaluation driver
# ──────────────────────────────────────────────────────────────────────
def evaluate(
    model: nn.Module,
    hf_dataset,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    src_lang: str = "en",
    tgt_lang: str = "ms",
    bos_id: int = 5,
    eos_id: int = 6,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    num_samples: int = 5,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> float:
    """
    Decode every example in *hf_dataset*, compute corpus chrF, and
    print ``num_samples`` side-by-side translations.

    Set beam_width=1 for greedy, >1 for beam search.

    Returns
    -------
    chrf_score : float (0–100)
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    hypotheses: List[str] = []
    references: List[str] = []

    for i, example in enumerate(hf_dataset):
        src_text = example["translation"][src_lang]
        ref_text = example["translation"][tgt_lang]

        hyp_text = translate(
            model, src_text,
            src_tokenizer, tgt_tokenizer,
            bos_id, eos_id, pad_id, max_len, device,
            beam_width=beam_width,
            length_penalty=length_penalty,
        )

        hypotheses.append(hyp_text)
        references.append(ref_text)

    chrf = compute_chrf(hypotheses, references)

    # Print samples
    print(f"\n{'='*60}")
    print(f"chrF Score: {chrf.score:.2f}")
    print(f"{'='*60}")
    for i in range(min(num_samples, len(hypotheses))):
        src_text = hf_dataset[i]["translation"][src_lang]
        print(f"\n[{i}] SRC: {src_text[:120]}")
        print(f"    REF: {references[i][:120]}")
        print(f"    HYP: {hypotheses[i][:120]}")

    return chrf.score
src/model.py
ADDED
@@ -0,0 +1,287 @@
"""
10+2 Tied Transformer for English → Malay Translation
=======================================================
An asymmetric encoder-decoder Transformer built on ``torch.nn.Transformer``.

Architecture (redesigned for efficient T4 GPU training & inference):
    d_model            = 512   (embedding dimension, head_dim = 64)
    n_head             = 8     (attention heads)
    encoder layers     = 10    (deep encoder for source understanding)
    decoder layers     = 2     (shallow decoder for fast generation)
    d_ff               = 2048  (feed-forward inner dimension)
    dropout            = 0.1
    norm_first         = True  (pre-norm for training stability)
    shared embeddings  = True  (single vocab, en+ms share Latin script)
    tied output proj.  = True  (output reuses embedding weights)

Key design choices (see architecture_report.md for full rationale):
  • **Asymmetric depth (Kasai et al., 2021):** Encoder depth drives
    translation quality; decoder depth can be aggressively reduced
    with minimal quality loss and ~3× faster inference.
  • **Shared vocabulary:** English and Malay both use Latin script with
    significant lexical overlap (loanwords, numbers, proper nouns).
    A joint BPE naturally captures cross-lingual subword patterns.
  • **Tied output projection (Press & Wolf, 2017):** The decoder's output
    linear layer reuses the shared embedding matrix, saving ~26M params
    and acting as a regulariser.
  • **Pre-layer normalisation (Xiong et al., 2020):** Essential for stable
    training of a 10-layer encoder. Places LayerNorm before each sublayer.
  • Uses PyTorch's native ``nn.Transformer`` to keep FlashAttention /
    SDPA fused kernels active (PyTorch 2.0+).
"""

from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn as nn


# ---------------------------------------------------------------------------
# Positional Encoding (sinusoidal, from "Attention Is All You Need")
# ---------------------------------------------------------------------------
class PositionalEncoding(nn.Module):
    """
    Inject positional information via fixed sinusoidal signals.

        PE(pos, 2i)   = sin(pos / 10000^{2i / d_model})
        PE(pos, 2i+1) = cos(pos / 10000^{2i / d_model})
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)                        # (max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )  # (d_model/2,)

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            (batch, seq_len, d_model) with positional encoding added.
        """
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


# ---------------------------------------------------------------------------
# Full Transformer Model (10+2 Tied)
# ---------------------------------------------------------------------------
class TransformerTranslator(nn.Module):
    """
    Asymmetric encoder-decoder Transformer with shared/tied embeddings.

    Parameters
    ----------
    vocab_size : int
        Size of the shared source+target vocabulary.
    d_model : int
        Embedding / hidden dimension.
    n_head : int
        Number of attention heads.
    num_encoder_layers : int
        Number of encoder blocks (default 10).
    num_decoder_layers : int
        Number of decoder blocks (default 2).
    d_ff : int
        Feed-forward inner dimension.
    dropout : float
        Dropout rate.
    max_len : int
        Maximum sequence length for positional encoding.
    pad_idx : int
        Padding token ID (used to create padding masks).
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_head: int = 8,
        num_encoder_layers: int = 10,
        num_decoder_layers: int = 2,
        d_ff: int = 2048,
        dropout: float = 0.1,
        max_len: int = 512,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.d_model = d_model

        # --- Shared embedding (one matrix for both enc & dec) -------------
        self.shared_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)
        self.embed_scale = math.sqrt(d_model)

        # --- Core Transformer (asymmetric, pre-norm) ----------------------
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=n_head,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=d_ff,
            dropout=dropout,
            batch_first=True,
            norm_first=True,  # pre-layer norm for stability
        )

        # --- Tied output projection (reuses embedding weights) ------------
        # No separate nn.Linear – forward() uses F.linear with shared weights
        self.output_bias = nn.Parameter(torch.zeros(vocab_size))

        # --- Initialize weights -------------------------------------------
        self._init_weights()

    def _embed(self, tokens: torch.Tensor) -> torch.Tensor:
        """Shared embedding + scale + positional encoding."""
        return self.pos_encoding(self.shared_embedding(tokens) * self.embed_scale)

    def _init_weights(self):
        """Scaled-normal initialization (std = d_model ** -0.5) for the shared embedding."""
        nn.init.normal_(self.shared_embedding.weight, mean=0, std=self.d_model ** -0.5)
        # Zero out padding embedding
        with torch.no_grad():
            self.shared_embedding.weight[self.pad_idx].zero_()

    # ------------------------------------------------------------------
    # Mask utilities
    # ------------------------------------------------------------------
    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        """
        Causal mask for the decoder: prevents attending to future positions.
        Returns a (sz, sz) boolean mask where True = blocked.
        """
        return torch.triu(torch.ones(sz, sz, device=device, dtype=torch.bool), diagonal=1)

    def _make_pad_mask(self, x: torch.Tensor) -> torch.Tensor:
        """
        Create a padding mask: True where token == pad_idx.
        Shape: (batch, seq_len)
        """
        return x == self.pad_idx

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            src: (batch, src_len) source token IDs.
            tgt: (batch, tgt_len) target token IDs (teacher-forced).

        Returns:
            logits: (batch, tgt_len, vocab_size)
        """
        # Build masks if not provided
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)

        # Causal mask for decoder
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)

        # Shared embeddings for both encoder and decoder
        src_emb = self._embed(src)
        tgt_emb = self._embed(tgt)

        # Transformer forward
        out = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )  # (batch, tgt_len, d_model)

        # Tied output projection: logits = out @ embedding_weights.T + bias
        logits = torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)
        return logits

    # ------------------------------------------------------------------
    # Inference helpers
    # ------------------------------------------------------------------
    def encode(self, src: torch.Tensor, src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run only the encoder. Returns memory: (batch, src_len, d_model)."""
        if src_key_padding_mask is None:
            src_key_padding_mask = self._make_pad_mask(src)
        src_emb = self._embed(src)
        return self.transformer.encoder(src_emb, src_key_padding_mask=src_key_padding_mask)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run only the decoder given encoder memory. Returns logits."""
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = self._make_pad_mask(tgt)
        tgt_len = tgt.size(1)
        tgt_mask = self.generate_square_subsequent_mask(tgt_len, tgt.device)
        tgt_emb = self._embed(tgt)
        out = self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return torch.nn.functional.linear(out, self.shared_embedding.weight, self.output_bias)


# ---------------------------------------------------------------------------
# Helper: count parameters
# ---------------------------------------------------------------------------
def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# ---------------------------------------------------------------------------
# Helper: build model
# ---------------------------------------------------------------------------
def build_model(
    vocab_size: int,
    pad_idx: int = 0,
    device: Optional[torch.device] = None,
    **kwargs,
) -> TransformerTranslator:
    """
    Build and return a TransformerTranslator with default hyperparameters.

    Any kwarg (d_model, n_head, etc.) overrides the default.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = TransformerTranslator(
        vocab_size=vocab_size,
        pad_idx=pad_idx,
        **kwargs,
    ).to(device)

    return model
src/tokenizer.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Byte-Pair Encoding (BPE) Tokenizer for English-Malay Translation
|
| 3 |
+
=================================================================
|
| 4 |
+
We support two modes:
|
| 5 |
+
1. **Shared tokenizer** (preferred for 10+2 Tied Transformer):
|
| 6 |
+
A single BPE tokenizer trained on the concatenated en+ms corpus.
|
| 7 |
+
Both encoder and decoder share the same vocabulary.
|
| 8 |
+
2. **Separate tokenizers** (legacy):
|
| 9 |
+
Two independent BPE tokenizers, one per language.
|
| 10 |
+
|
| 11 |
+
Why BPE?
|
| 12 |
+
β’ Handles subword units, so rare / unseen words are decomposed into
|
| 13 |
+
known subword pieces instead of mapping to [UNK].
|
| 14 |
+
β’ Malay is morphologically rich (prefixes: me-, ber-, di-; suffixes:
|
| 15 |
+
-kan, -an, -i). BPE naturally learns these affixes as subword units,
|
| 16 |
+
giving much better coverage than a word-level tokenizer.
|
| 17 |
+
β’ Keeps vocabulary compact while still reaching high coverage on both
|
| 18 |
+
English and Malay.
|
| 19 |
+
|
| 20 |
+
Why shared vocabulary for en-ms?
|
| 21 |
+
β’ Both languages use the Latin script with significant lexical overlap
|
| 22 |
+
(loanwords: "teknologi", "matematik", "universiti"; numbers; proper nouns).
|
| 23 |
+
β’ A joint BPE captures cross-lingual subword patterns and enables
|
| 24 |
+
tied embeddings in the model (Press & Wolf, 2017), saving ~26M params.
|
| 25 |
+
|
| 26 |
+
Design choices:
|
| 27 |
+
β’ NFKC normalisation + lowercase β ensures consistent encoding of
|
| 28 |
+
Unicode characters and removes casing noise.
|
| 29 |
+
β’ Whitespace pre-tokeniser β splits on spaces before BPE merges; simple
|
| 30 |
+
and effective for Latin-script languages.
|
| 31 |
+
β’ Special tokens:
|
| 32 |
+
[PAD] β padding for uniform sequence lengths in batches
|
| 33 |
+
[UNK] β fallback for unknown characters
|
| 34 |
+
[CLS] β beginning-of-sequence / classification token
|
| 35 |
+
[SEP] β separator (unused in basic seq2seq but reserved)
|
| 36 |
+
[MASK] β reserved for masked-LM pretraining objectives
|
| 37 |
+
[BOS] β beginning of sentence (fed to decoder at step 0)
|
| 38 |
+
[EOS] β end of sentence (signals the decoder to stop)
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
from __future__ import annotations
|
| 42 |
+
|
| 43 |
+
import os
|
| 44 |
+
import tempfile
|
| 45 |
+
from pathlib import Path
|
| 46 |
+
from typing import Iterator, List, Optional, Union
|
| 47 |
+
|
| 48 |
+
from tokenizers import Tokenizer
|
| 49 |
+
from tokenizers.models import BPE
|
| 50 |
+
from tokenizers.trainers import BpeTrainer
|
| 51 |
+
from tokenizers.pre_tokenizers import Whitespace
|
| 52 |
+
from tokenizers.normalizers import Sequence, NFKC, Lowercase
|
| 53 |
+
from tokenizers.processors import TemplateProcessing
|
| 54 |
+
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
# Constants
|
| 57 |
+
# ---------------------------------------------------------------------------
|
| 58 |
+
SPECIAL_TOKENS: List[str] = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "[BOS]", "[EOS]"]
|
| 59 |
+
PAD_TOKEN = "[PAD]"
|
| 60 |
+
UNK_TOKEN = "[UNK]"
|
| 61 |
+
CLS_TOKEN = "[CLS]"
|
| 62 |
+
SEP_TOKEN = "[SEP]"
|
| 63 |
+
MASK_TOKEN = "[MASK]"
|
| 64 |
+
BOS_TOKEN = "[BOS]"
|
| 65 |
+
EOS_TOKEN = "[EOS]"
|
| 66 |
+
|
| 67 |
+
DEFAULT_VOCAB_SIZE = 50_000
|
| 68 |
+
DEFAULT_MIN_FREQUENCY = 2
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ---------------------------------------------------------------------------
|
| 72 |
+
# Helper: write an iterator of strings to a temporary file (needed by the
|
| 73 |
+
# HuggingFace `tokenizers` training API which expects file paths).
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
def _write_texts_to_tmpfile(texts: Iterator[str]) -> str:
|
| 76 |
+
"""Write an iterable of strings to a temp file, one per line. Returns path."""
|
| 77 |
+
tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, encoding="utf-8")
|
| 78 |
+
for line in texts:
|
| 79 |
+
line = line.strip()
|
| 80 |
+
if line:
|
| 81 |
+
tmp.write(line + "\n")
|
| 82 |
+
tmp.close()
|
| 83 |
+
return tmp.name
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
# Core: build & train a BPE tokenizer
|
| 88 |
+
# ---------------------------------------------------------------------------
|
| 89 |
+
def build_tokenizer(
|
| 90 |
+
vocab_size: int = DEFAULT_VOCAB_SIZE,
|
| 91 |
+
min_frequency: int = DEFAULT_MIN_FREQUENCY,
|
| 92 |
+
) -> tuple[Tokenizer, BpeTrainer]:
|
| 93 |
+
"""
|
| 94 |
+
Create an *untrained* BPE tokenizer and its trainer.
|
| 95 |
+
|
| 96 |
+
Returns
|
| 97 |
+
-------
|
| 98 |
+
tokenizer : Tokenizer
|
| 99 |
+
Ready to call ``tokenizer.train(files, trainer)``.
|
| 100 |
+
trainer : BpeTrainer
|
| 101 |
+
Configured trainer instance.
|
| 102 |
+
"""
|
| 103 |
+
tokenizer = Tokenizer(BPE(unk_token=UNK_TOKEN))
|
| 104 |
+
|
| 105 |
+
# --- Normalisation: NFKC (canonical Unicode) + lowercase -------------
|
| 106 |
+
tokenizer.normalizer = Sequence([NFKC(), Lowercase()])
|
| 107 |
+
|
| 108 |
+
# --- Pre-tokenisation: split on whitespace ---------------------------
|
| 109 |
+
tokenizer.pre_tokenizer = Whitespace()
|
| 110 |
+
|
| 111 |
+
# --- Trainer ---------------------------------------------------------
|
| 112 |
+
trainer = BpeTrainer(
|
| 113 |
+
vocab_size=vocab_size,
|
| 114 |
+
min_frequency=min_frequency,
|
| 115 |
+
special_tokens=SPECIAL_TOKENS,
|
| 116 |
+
show_progress=True,
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
return tokenizer, trainer
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def train_tokenizer(
|
| 123 |
+
texts: Union[List[str], Iterator[str]],
|
| 124 |
+
vocab_size: int = DEFAULT_VOCAB_SIZE,
|
| 125 |
+
min_frequency: int = DEFAULT_MIN_FREQUENCY,
|
| 126 |
+
files: Optional[List[str]] = None,
|
| 127 |
+
) -> Tokenizer:
|
| 128 |
+
"""
|
| 129 |
+
Train a BPE tokenizer on the given texts **or** files.
|
| 130 |
+
|
| 131 |
+
Parameters
|
| 132 |
+
----------
|
| 133 |
+
texts : list[str] or iterator of str
|
| 134 |
+
Raw sentences. Ignored when *files* is provided.
|
| 135 |
+
vocab_size : int
|
| 136 |
+
Target vocabulary size (default ``DEFAULT_VOCAB_SIZE`` = 50 000).
|
| 137 |
+
min_frequency : int
|
| 138 |
+
Minimum frequency for a pair to be merged.
|
| 139 |
+
files : list[str], optional
|
| 140 |
+
Paths to plain-text files (one sentence per line).
|
| 141 |
+
|
| 142 |
+
Returns
|
| 143 |
+
-------
|
| 144 |
+
Tokenizer
|
| 145 |
+
Trained tokenizer ready for encoding / decoding.
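
Example
-------
Illustrative sketch; the exact subword splits depend on the learned merges::

    tok = train_tokenizer(["hello world", "selamat pagi dunia"], vocab_size=300)
    enc = tok.encode("selamat pagi")
    assert enc.tokens[0] == "[BOS]" and enc.tokens[-1] == "[EOS]"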
|
| 146 |
+
"""
|
| 147 |
+
tokenizer, trainer = build_tokenizer(vocab_size, min_frequency)
|
| 148 |
+
|
| 149 |
+
if files is not None:
|
| 150 |
+
tokenizer.train(files, trainer)
|
| 151 |
+
else:
|
| 152 |
+
# Write texts to a temporary file so we can use the fast Rust trainer
|
| 153 |
+
tmp_path = _write_texts_to_tmpfile(iter(texts))
|
| 154 |
+
try:
|
| 155 |
+
tokenizer.train([tmp_path], trainer)
|
| 156 |
+
finally:
|
| 157 |
+
os.remove(tmp_path)
|
| 158 |
+
|
| 159 |
+
# --- Post-processing: wrap every encoded sequence with [BOS] … [EOS] -
|
| 160 |
+
bos_id = tokenizer.token_to_id(BOS_TOKEN)
|
| 161 |
+
eos_id = tokenizer.token_to_id(EOS_TOKEN)
|
| 162 |
+
tokenizer.post_processor = TemplateProcessing(
|
| 163 |
+
single="[BOS]:0 $A:0 [EOS]:0",
|
| 164 |
+
pair="[BOS]:0 $A:0 [EOS]:0 [BOS]:1 $B:1 [EOS]:1",
|
| 165 |
+
special_tokens=[
|
| 166 |
+
("[BOS]", bos_id),
|
| 167 |
+
("[EOS]", eos_id),
|
| 168 |
+
],
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
return tokenizer
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ---------------------------------------------------------------------------
|
| 175 |
+
# Convenience wrappers for saving / loading
|
| 176 |
+
# ---------------------------------------------------------------------------
|
| 177 |
+
def save_tokenizer(tokenizer: Tokenizer, path: Union[str, Path]) -> None:
|
| 178 |
+
"""Save a trained tokenizer to a JSON file."""
|
| 179 |
+
path = Path(path)
|
| 180 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 181 |
+
tokenizer.save(str(path))
|
| 182 |
+
print(f"[β] Tokenizer saved β {path}")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def load_tokenizer(path: Union[str, Path]) -> Tokenizer:
|
| 186 |
+
"""Load a previously saved tokenizer from a JSON file."""
|
| 187 |
+
tokenizer = Tokenizer.from_file(str(path))
|
| 188 |
+
print(f"[β] Tokenizer loaded β {path}")
|
| 189 |
+
return tokenizer
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# ---------------------------------------------------------------------------
|
| 193 |
+
# Encoding / decoding helpers
|
| 194 |
+
# ---------------------------------------------------------------------------
|
| 195 |
+
def encode(tokenizer: Tokenizer, text: str) -> List[int]:
|
| 196 |
+
"""Encode a single string and return token IDs (includes [BOS]/[EOS])."""
|
| 197 |
+
return tokenizer.encode(text).ids
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def decode(tokenizer: Tokenizer, ids: List[int]) -> str:
|
| 201 |
+
"""Decode token IDs back to a string, skipping special tokens."""
|
| 202 |
+
return tokenizer.decode(ids, skip_special_tokens=True)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def get_vocab_size(tokenizer: Tokenizer) -> int:
|
| 206 |
+
"""Return the size of the tokenizer's vocabulary."""
|
| 207 |
+
return tokenizer.get_vocab_size()
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def token_to_id(tokenizer: Tokenizer, token: str) -> Optional[int]:
|
| 211 |
+
"""Look up the integer ID for a single token string."""
|
| 212 |
+
return tokenizer.token_to_id(token)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def id_to_token(tokenizer: Tokenizer, id: int) -> Optional[str]:
|
| 216 |
+
"""Look up the token string for a single integer ID."""
|
| 217 |
+
return tokenizer.id_to_token(id)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# ---------------------------------------------------------------------------
|
| 221 |
+
# High-level: train a SHARED tokenizer on both languages (for tied embeddings)
|
| 222 |
+
# ---------------------------------------------------------------------------
|
| 223 |
+
def train_shared_tokenizer_from_dataset(
|
| 224 |
+
dataset,
|
| 225 |
+
src_lang: str = "en",
|
| 226 |
+
tgt_lang: str = "ms",
|
| 227 |
+
vocab_size: int = DEFAULT_VOCAB_SIZE,
|
| 228 |
+
save_dir: Union[str, Path] = "tokenizer",
|
| 229 |
+
) -> Tokenizer:
|
| 230 |
+
"""
|
| 231 |
+
Train a single shared BPE tokenizer on the concatenated en+ms corpus.
|
| 232 |
+
|
| 233 |
+
This is used with the 6+2 Tied Transformer architecture, where both
|
| 234 |
+
encoder and decoder share the same vocabulary and embedding matrix.
|
| 235 |
+
|
| 236 |
+
Parameters
|
| 237 |
+
----------
|
| 238 |
+
dataset : datasets.Dataset
|
| 239 |
+
A HuggingFace dataset split where each example has a ``'translation'``
|
| 240 |
+
dict with keys for each language code.
|
| 241 |
+
src_lang : str
|
| 242 |
+
Source language code (default ``'en'``).
|
| 243 |
+
tgt_lang : str
|
| 244 |
+
Target language code (default ``'ms'``).
|
| 245 |
+
vocab_size : int
|
| 246 |
+
Vocabulary size for the shared tokenizer.
|
| 247 |
+
save_dir : str or Path
|
| 248 |
+
Directory to save the trained tokenizer JSON file.
|
| 249 |
+
|
| 250 |
+
Returns
|
| 251 |
+
-------
|
| 252 |
+
Tokenizer
|
| 253 |
+
A single shared tokenizer for both languages.
|
| 254 |
+
"""
|
| 255 |
+
save_dir = Path(save_dir)
|
| 256 |
+
|
| 257 |
+
# Concatenate all source and target sentences into one corpus
|
| 258 |
+
src_texts = [example["translation"][src_lang] for example in dataset]
|
| 259 |
+
tgt_texts = [example["translation"][tgt_lang] for example in dataset]
|
| 260 |
+
all_texts = src_texts + tgt_texts
|
| 261 |
+
|
| 262 |
+
print(f"Training shared BPE tokenizer on {len(all_texts):,} sentences "
|
| 263 |
+
f"({len(src_texts):,} {src_lang} + {len(tgt_texts):,} {tgt_lang}) β¦")
|
| 264 |
+
shared_tokenizer = train_tokenizer(all_texts, vocab_size=vocab_size)
|
| 265 |
+
save_tokenizer(shared_tokenizer, save_dir / "tokenizer_shared.json")
|
| 266 |
+
|
| 267 |
+
# Sanity check
|
| 268 |
+
for name, sample in [(src_lang, src_texts[0]), (tgt_lang, tgt_texts[0])]:
|
| 269 |
+
enc = shared_tokenizer.encode(sample)
|
| 270 |
+
print(f"\n[{name}] Sample: {sample[:80]}β¦")
|
| 271 |
+
print(f" Tokens : {enc.tokens[:15]}β¦")
|
| 272 |
+
print(f" IDs : {enc.ids[:15]}β¦")
|
| 273 |
+
print(f" Decoded: {shared_tokenizer.decode(enc.ids, skip_special_tokens=True)[:80]}β¦")
|
| 274 |
+
|
| 275 |
+
print(f"\n[β] Shared tokenizer trained and saved to {save_dir}/tokenizer_shared.json")
|
| 276 |
+
return shared_tokenizer
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# ---------------------------------------------------------------------------
|
| 280 |
+
# High-level: train source (English) & target (Malay) tokenizers from a
|
| 281 |
+
# HuggingFace dataset split.
|
| 282 |
+
# ---------------------------------------------------------------------------
|
| 283 |
+
def train_tokenizers_from_dataset(
|
| 284 |
+
dataset,
|
| 285 |
+
src_lang: str = "en",
|
| 286 |
+
tgt_lang: str = "ms",
|
| 287 |
+
vocab_size: int = DEFAULT_VOCAB_SIZE,
|
| 288 |
+
save_dir: Union[str, Path] = "tokenizer",
|
| 289 |
+
) -> tuple[Tokenizer, Tokenizer]:
|
| 290 |
+
"""
|
| 291 |
+
Train separate BPE tokenizers for source and target languages.
|
| 292 |
+
|
| 293 |
+
Parameters
|
| 294 |
+
----------
|
| 295 |
+
dataset : datasets.Dataset
|
| 296 |
+
A HuggingFace dataset split (e.g. ``dataset['train']``) where each
|
| 297 |
+
example has a ``'translation'`` dict with keys for each language code.
|
| 298 |
+
src_lang : str
|
| 299 |
+
Source language code (default ``'en'``).
|
| 300 |
+
tgt_lang : str
|
| 301 |
+
Target language code (default ``'ms'``).
|
| 302 |
+
vocab_size : int
|
| 303 |
+
Vocabulary size for each tokenizer.
|
| 304 |
+
save_dir : str or Path
|
| 305 |
+
Directory to save the trained tokenizer JSON files.
|
| 306 |
+
|
| 307 |
+
Returns
|
| 308 |
+
-------
|
| 309 |
+
(src_tokenizer, tgt_tokenizer)
|
| 310 |
+
"""
|
| 311 |
+
save_dir = Path(save_dir)
|
| 312 |
+
|
| 313 |
+
# Extract raw sentences from the dataset
|
| 314 |
+
src_texts = [example["translation"][src_lang] for example in dataset]
|
| 315 |
+
tgt_texts = [example["translation"][tgt_lang] for example in dataset]
|
| 316 |
+
|
| 317 |
+
print(f"Training source tokenizer ({src_lang}) on {len(src_texts):,} sentences β¦")
|
| 318 |
+
src_tokenizer = train_tokenizer(src_texts, vocab_size=vocab_size)
|
| 319 |
+
save_tokenizer(src_tokenizer, save_dir / f"tokenizer_{src_lang}.json")
|
| 320 |
+
|
| 321 |
+
print(f"Training target tokenizer ({tgt_lang}) on {len(tgt_texts):,} sentences β¦")
|
| 322 |
+
tgt_tokenizer = train_tokenizer(tgt_texts, vocab_size=vocab_size)
|
| 323 |
+
save_tokenizer(tgt_tokenizer, save_dir / f"tokenizer_{tgt_lang}.json")
|
| 324 |
+
|
| 325 |
+
# Quick sanity check
|
| 326 |
+
for name, tok, sample in [
|
| 327 |
+
(src_lang, src_tokenizer, src_texts[0]),
|
| 328 |
+
(tgt_lang, tgt_tokenizer, tgt_texts[0]),
|
| 329 |
+
]:
|
| 330 |
+
enc = tok.encode(sample)
|
| 331 |
+
print(f"\n[{name}] Sample: {sample[:80]}β¦")
|
| 332 |
+
print(f" Tokens : {enc.tokens[:15]}β¦")
|
| 333 |
+
print(f" IDs : {enc.ids[:15]}β¦")
|
| 334 |
+
print(f" Decoded: {tok.decode(enc.ids, skip_special_tokens=True)[:80]}β¦")
|
| 335 |
+
|
| 336 |
+
print(f"\n[β] Both tokenizers trained and saved to {save_dir}/")
|
| 337 |
+
return src_tokenizer, tgt_tokenizer
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# ---------------------------------------------------------------------------
|
| 341 |
+
# Standalone usage
|
| 342 |
+
# ---------------------------------------------------------------------------
|
| 343 |
+
if __name__ == "__main__":
|
| 344 |
+
from datasets import load_from_disk
|
| 345 |
+
|
| 346 |
+
print("Loading TED Talks IWSLT dataset (en β ms, 2016) β¦")
|
| 347 |
+
ds = load_from_disk("dataset/en_ms_2016")
|
| 348 |
+
|
| 349 |
+
src_tok, tgt_tok = train_tokenizers_from_dataset(
|
| 350 |
+
ds,
|
| 351 |
+
src_lang="en",
|
| 352 |
+
tgt_lang="ms",
|
| 353 |
+
vocab_size=DEFAULT_VOCAB_SIZE,
|
| 354 |
+
save_dir="tokenizer",
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
print(f"\nEnglish vocab size : {get_vocab_size(src_tok):,}")
|
| 358 |
+
print(f"Malay vocab size : {get_vocab_size(tgt_tok):,}")
|
| 359 |
+
print(f"[PAD] id (en) : {token_to_id(src_tok, PAD_TOKEN)}")
|
| 360 |
+
print(f"[EOS] id (ms) : {token_to_id(tgt_tok, EOS_TOKEN)}")
|
src/training.py
ADDED
|
@@ -0,0 +1,370 @@
|
| 1 |
+
"""
|
| 2 |
+
Training loop for the Transformer translator.
|
| 3 |
+
===============================================
|
| 4 |
+
Provides:
|
| 5 |
+
• ``TranslationDataset`` – a PyTorch Dataset that tokenises and pads
|
| 6 |
+
source/target sentence pairs.
|
| 7 |
+
• ``create_dataloaders`` – builds train / validation DataLoaders with
|
| 8 |
+
a 90/10 split.
|
| 9 |
+
• ``train_one_epoch`` – one full pass over the training set.
|
| 10 |
+
• ``evaluate_loss`` – average loss on the validation set.
|
| 11 |
+
• ``train`` – full training driver with logging, LR
|
| 12 |
+
scheduling, checkpointing, and early stopping.
|
| 13 |
+
|
| 14 |
+
Design choices:
|
| 15 |
+
• Label-smoothed cross-entropy (smoothing = 0.1) for better
|
| 16 |
+
generalisation.
|
| 17 |
+
• AdamW with a linear-warmup + cosine-decay schedule (stable for
|
| 18 |
+
small datasets).
|
| 19 |
+
• Mixed-precision (AMP) with ``torch.amp`` for speed / memory on T4.
|
| 20 |
+
• Gradient clipping at max_norm = 1.0 to avoid exploding gradients.
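(With smoothing ε = 0.1 and vocabulary size V, the per-position target
distribution places 1 − ε + ε/V on the gold token and ε/V on every other
token; this is what ``nn.CrossEntropyLoss(label_smoothing=0.1)`` optimises.)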
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import math
|
| 26 |
+
import os
|
| 27 |
+
import time
|
| 28 |
+
from dataclasses import dataclass, field
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from typing import List, Optional, Tuple
|
| 31 |
+
|
| 32 |
+
import torch
|
| 33 |
+
import torch.nn as nn
|
| 34 |
+
from torch.utils.data import Dataset, DataLoader, random_split
|
| 35 |
+
from tokenizers import Tokenizer
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 39 |
+
# 1. Translation Dataset
|
| 40 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 41 |
+
class TranslationDataset(Dataset):
|
| 42 |
+
"""
|
| 43 |
+
Wraps a HuggingFace dataset of translation pairs into a PyTorch
|
| 44 |
+
Dataset that returns padded token-ID tensors.
|
| 45 |
+
|
| 46 |
+
Each ``__getitem__`` returns::
|
| 47 |
+
{
|
| 48 |
+
"src": LongTensor[max_len], # source token IDs (padded)
|
| 49 |
+
"tgt": LongTensor[max_len], # target input (with [BOS], no final [EOS])
|
| 50 |
+
"label": LongTensor[max_len], # target labels (no [BOS], with [EOS])
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
The *tgt* / *label* split implements **teacher forcing**: the decoder
|
| 54 |
+
receives ``[BOS] w1 w2 …`` and must predict ``w1 w2 … [EOS]``.
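
For illustration, with hypothetical IDs [BOS]=5 and [EOS]=6 and an encoded
target ``[5, 17, 42, 9, 6]``::

    tgt   = [5, 17, 42, 9]    # decoder input  (final [EOS] dropped)
    label = [17, 42, 9, 6]    # prediction targets (leading [BOS] dropped)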
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
hf_dataset,
|
| 60 |
+
src_tokenizer: Tokenizer,
|
| 61 |
+
tgt_tokenizer: Tokenizer,
|
| 62 |
+
src_lang: str = "en",
|
| 63 |
+
tgt_lang: str = "ms",
|
| 64 |
+
max_len: int = 128,
|
| 65 |
+
pad_id: int = 0,
|
| 66 |
+
):
|
| 67 |
+
self.data = hf_dataset
|
| 68 |
+
self.src_tok = src_tokenizer
|
| 69 |
+
self.tgt_tok = tgt_tokenizer
|
| 70 |
+
self.src_lang = src_lang
|
| 71 |
+
self.tgt_lang = tgt_lang
|
| 72 |
+
self.max_len = max_len
|
| 73 |
+
self.pad_id = pad_id
|
| 74 |
+
|
| 75 |
+
def __len__(self) -> int:
|
| 76 |
+
return len(self.data)
|
| 77 |
+
|
| 78 |
+
def _pad(self, ids: List[int]) -> List[int]:
|
| 79 |
+
"""Truncate to max_len, then right-pad with pad_id."""
|
| 80 |
+
ids = ids[: self.max_len]
|
| 81 |
+
return ids + [self.pad_id] * (self.max_len - len(ids))
|
| 82 |
+
|
| 83 |
+
def __getitem__(self, idx: int) -> dict:
|
| 84 |
+
pair = self.data[idx]["translation"]
|
| 85 |
+
|
| 86 |
+
# Encode (includes [BOS] … [EOS] from post-processor)
|
| 87 |
+
src_ids = self.src_tok.encode(pair[self.src_lang]).ids
|
| 88 |
+
tgt_ids = self.tgt_tok.encode(pair[self.tgt_lang]).ids
|
| 89 |
+
|
| 90 |
+
# Teacher-forcing split:
|
| 91 |
+
# tgt_input = [BOS] w1 w2 … wN (drop last token)
|
| 92 |
+
# tgt_label = w1 w2 … wN [EOS] (drop first token)
|
| 93 |
+
tgt_input = tgt_ids[:-1]
|
| 94 |
+
tgt_label = tgt_ids[1:]
|
| 95 |
+
|
| 96 |
+
return {
|
| 97 |
+
"src": torch.tensor(self._pad(src_ids), dtype=torch.long),
|
| 98 |
+
"tgt": torch.tensor(self._pad(tgt_input), dtype=torch.long),
|
| 99 |
+
"label": torch.tensor(self._pad(tgt_label), dtype=torch.long),
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 104 |
+
# 2. DataLoader factory
|
| 105 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 106 |
+
def create_dataloaders(
|
| 107 |
+
hf_dataset,
|
| 108 |
+
src_tokenizer: Tokenizer,
|
| 109 |
+
tgt_tokenizer: Tokenizer,
|
| 110 |
+
src_lang: str = "en",
|
| 111 |
+
tgt_lang: str = "ms",
|
| 112 |
+
max_len: int = 128,
|
| 113 |
+
batch_size: int = 32,
|
| 114 |
+
val_ratio: float = 0.1,
|
| 115 |
+
pad_id: int = 0,
|
| 116 |
+
seed: int = 42,
|
| 117 |
+
) -> Tuple[DataLoader, DataLoader, TranslationDataset]:
|
| 118 |
+
"""
|
| 119 |
+
Build training and validation DataLoaders from a HuggingFace dataset.
|
| 120 |
+
|
| 121 |
+
Returns
|
| 122 |
+
-------
|
| 123 |
+
train_loader, val_loader, full_dataset
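
Example (sketch; assumes ``ds`` is a loaded HF split and ``shared_tok`` a
trained tokenizer)::

    train_loader, val_loader, full_ds = create_dataloaders(
        ds, shared_tok, shared_tok, batch_size=64, max_len=128,
    )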
|
| 124 |
+
"""
|
| 125 |
+
full_ds = TranslationDataset(
|
| 126 |
+
hf_dataset, src_tokenizer, tgt_tokenizer,
|
| 127 |
+
src_lang, tgt_lang, max_len, pad_id,
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
val_size = max(1, int(len(full_ds) * val_ratio))
|
| 131 |
+
train_size = len(full_ds) - val_size
|
| 132 |
+
|
| 133 |
+
generator = torch.Generator().manual_seed(seed)
|
| 134 |
+
train_ds, val_ds = random_split(full_ds, [train_size, val_size], generator=generator)
|
| 135 |
+
|
| 136 |
+
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
|
| 137 |
+
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False)
|
| 138 |
+
|
| 139 |
+
print(f"Train: {train_size} | Val: {val_size} | Batch size: {batch_size}")
|
| 140 |
+
return train_loader, val_loader, full_ds
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 144 |
+
# 3. Training configuration dataclass
|
| 145 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 146 |
+
@dataclass
|
| 147 |
+
class TrainConfig:
|
| 148 |
+
"""All tuneable knobs in one place."""
|
| 149 |
+
epochs: int = 50
|
| 150 |
+
batch_size: int = 32
|
| 151 |
+
max_len: int = 128
|
| 152 |
+
lr: float = 5e-4
|
| 153 |
+
warmup_steps: int = 200
|
| 154 |
+
label_smoothing: float = 0.1
|
| 155 |
+
grad_clip: float = 1.0
|
| 156 |
+
use_amp: bool = True
|
| 157 |
+
val_ratio: float = 0.1
|
| 158 |
+
checkpoint_dir: str = "training/checkpoints"
|
| 159 |
+
log_every: int = 10 # print loss every N steps
|
| 160 |
+
patience: int = 10 # early-stopping patience (epochs)
|
| 161 |
+
seed: int = 42
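# Example (sketch): override only the knobs you need, e.g.
#   cfg = TrainConfig(epochs=30, batch_size=64, lr=3e-4)
# every other field keeps the default declared above.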
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 165 |
+
# 4. LR scheduler with linear warmup + cosine decay
|
| 166 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 167 |
+
def _build_scheduler(optimizer, warmup_steps: int, total_steps: int):
|
| 168 |
+
"""Linear warmup for `warmup_steps`, then cosine decay to 0."""
|
| 169 |
+
|
| 170 |
+
def lr_lambda(step):
|
| 171 |
+
if step < warmup_steps:
|
| 172 |
+
return step / max(1, warmup_steps)
|
| 173 |
+
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
|
| 174 |
+
return 0.5 * (1.0 + math.cos(math.pi * progress))
|
| 175 |
+
|
| 176 |
+
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 180 |
+
# 5. Single-epoch training
|
| 181 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 182 |
+
def train_one_epoch(
|
| 183 |
+
model: nn.Module,
|
| 184 |
+
loader: DataLoader,
|
| 185 |
+
optimizer: torch.optim.Optimizer,
|
| 186 |
+
scheduler,
|
| 187 |
+
criterion: nn.Module,
|
| 188 |
+
device: torch.device,
|
| 189 |
+
scaler: Optional[torch.amp.GradScaler],
|
| 190 |
+
grad_clip: float = 1.0,
|
| 191 |
+
log_every: int = 10,
|
| 192 |
+
epoch: int = 0,
|
| 193 |
+
) -> float:
|
| 194 |
+
"""Train for one epoch. Returns average loss."""
|
| 195 |
+
model.train()
|
| 196 |
+
total_loss = 0.0
|
| 197 |
+
n_tokens = 0
|
| 198 |
+
|
| 199 |
+
for step, batch in enumerate(loader):
|
| 200 |
+
src = batch["src"].to(device)
|
| 201 |
+
tgt = batch["tgt"].to(device)
|
| 202 |
+
label = batch["label"].to(device)
|
| 203 |
+
|
| 204 |
+
optimizer.zero_grad()
|
| 205 |
+
|
| 206 |
+
amp_enabled = scaler is not None
|
| 207 |
+
with torch.amp.autocast("cuda", enabled=amp_enabled):
|
| 208 |
+
logits = model(src, tgt) # (B, T, V)
|
| 209 |
+
loss = criterion(logits.reshape(-1, logits.size(-1)), label.reshape(-1))
|
| 210 |
+
|
| 211 |
+
if scaler is not None:
|
| 212 |
+
scaler.scale(loss).backward()
|
| 213 |
+
scaler.unscale_(optimizer)
|
| 214 |
+
nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
| 215 |
+
scaler.step(optimizer)
|
| 216 |
+
scaler.update()
|
| 217 |
+
else:
|
| 218 |
+
loss.backward()
|
| 219 |
+
nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
|
| 220 |
+
optimizer.step()
|
| 221 |
+
|
| 222 |
+
scheduler.step()
|
| 223 |
+
|
| 224 |
+
# Accumulate loss (ignore padding contribution)
|
| 225 |
+
non_pad = (label != model.pad_idx).sum().item()
|
| 226 |
+
total_loss += loss.item() * non_pad
|
| 227 |
+
n_tokens += non_pad
|
| 228 |
+
|
| 229 |
+
if (step + 1) % log_every == 0:
|
| 230 |
+
avg = total_loss / max(n_tokens, 1)
|
| 231 |
+
lr = scheduler.get_last_lr()[0]
|
| 232 |
+
print(f" Epoch {epoch+1} | Step {step+1}/{len(loader)} | Loss {avg:.4f} | LR {lr:.2e}")
|
| 233 |
+
|
| 234 |
+
return total_loss / max(n_tokens, 1)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 238 |
+
# 6. Validation loss
|
| 239 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 240 |
+
@torch.no_grad()
|
| 241 |
+
def evaluate_loss(
|
| 242 |
+
model: nn.Module,
|
| 243 |
+
loader: DataLoader,
|
| 244 |
+
criterion: nn.Module,
|
| 245 |
+
device: torch.device,
|
| 246 |
+
use_amp: bool = False,
|
| 247 |
+
) -> float:
|
| 248 |
+
"""Compute average loss over a validation set (with AMP to match training)."""
|
| 249 |
+
model.eval()
|
| 250 |
+
total_loss = 0.0
|
| 251 |
+
n_tokens = 0
|
| 252 |
+
n_batches = len(loader)
|
| 253 |
+
|
| 254 |
+
for step, batch in enumerate(loader):
|
| 255 |
+
src = batch["src"].to(device)
|
| 256 |
+
tgt = batch["tgt"].to(device)
|
| 257 |
+
label = batch["label"].to(device)
|
| 258 |
+
|
| 259 |
+
with torch.amp.autocast("cuda", enabled=use_amp):
|
| 260 |
+
logits = model(src, tgt)
|
| 261 |
+
loss = criterion(logits.reshape(-1, logits.size(-1)), label.reshape(-1))
|
| 262 |
+
|
| 263 |
+
non_pad = (label != model.pad_idx).sum().item()
|
| 264 |
+
total_loss += loss.item() * non_pad
|
| 265 |
+
n_tokens += non_pad
|
| 266 |
+
|
| 267 |
+
if (step + 1) % max(1, n_batches // 4) == 0 or (step + 1) == n_batches:
|
| 268 |
+
print(f" Val {step+1}/{n_batches}", end="\r")
|
| 269 |
+
|
| 270 |
+
return total_loss / max(n_tokens, 1)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 274 |
+
# 7. Full training driver
|
| 275 |
+
# ──────────────────────────────────────────────────────────────────────
|
| 276 |
+
def train(
|
| 277 |
+
model: nn.Module,
|
| 278 |
+
train_loader: DataLoader,
|
| 279 |
+
val_loader: DataLoader,
|
| 280 |
+
cfg: TrainConfig,
|
| 281 |
+
device: torch.device,
|
| 282 |
+
trial=None,
|
| 283 |
+
) -> dict:
|
| 284 |
+
"""
|
| 285 |
+
Full training loop with logging, checkpointing, and early stopping.
|
| 286 |
+
|
| 287 |
+
Parameters
|
| 288 |
+
----------
|
| 289 |
+
trial : optuna.trial.Trial, optional
|
| 290 |
+
If provided, reports val_loss after each epoch for ASHA pruning.
|
| 291 |
+
|
| 292 |
+
Returns
|
| 293 |
+
-------
|
| 294 |
+
history : dict
|
| 295 |
+
``{"train_loss": [...], "val_loss": [...], "lr": [...]}``
|
| 296 |
+
"""
|
| 297 |
+
# --- Loss function (label-smoothed CE, ignoring PAD) ---------------
|
| 298 |
+
criterion = nn.CrossEntropyLoss(
|
| 299 |
+
ignore_index=model.pad_idx,
|
| 300 |
+
label_smoothing=cfg.label_smoothing,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
# --- Optimiser ------------------------------------------------------
|
| 304 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=(0.9, 0.98), eps=1e-9)
|
| 305 |
+
|
| 306 |
+
# --- LR schedule ---------------------------------------------------
|
| 307 |
+
total_steps = cfg.epochs * len(train_loader)
|
| 308 |
+
scheduler = _build_scheduler(optimizer, cfg.warmup_steps, total_steps)
|
| 309 |
+
|
| 310 |
+
# --- AMP scaler ----------------------------------------------------
|
| 311 |
+
scaler = torch.amp.GradScaler("cuda") if (cfg.use_amp and device.type == "cuda") else None
|
| 312 |
+
|
| 313 |
+
# --- Checkpoint dir ------------------------------------------------
|
| 314 |
+
ckpt_dir = Path(cfg.checkpoint_dir)
|
| 315 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 316 |
+
|
| 317 |
+
history: dict = {"train_loss": [], "val_loss": [], "lr": []}
|
| 318 |
+
best_val = float("inf")
|
| 319 |
+
patience_ctr = 0
|
| 320 |
+
|
| 321 |
+
print(f"\n{'='*60}")
|
| 322 |
+
print(f"Starting training: {cfg.epochs} epochs, lr={cfg.lr}, AMP={cfg.use_amp}")
|
| 323 |
+
print(f"{'='*60}\n")
|
| 324 |
+
|
| 325 |
+
for epoch in range(cfg.epochs):
|
| 326 |
+
t0 = time.time()
|
| 327 |
+
|
| 328 |
+
train_loss = train_one_epoch(
|
| 329 |
+
model, train_loader, optimizer, scheduler, criterion,
|
| 330 |
+
device, scaler, cfg.grad_clip, cfg.log_every, epoch,
|
| 331 |
+
)
|
| 332 |
+
use_amp = cfg.use_amp and device.type == "cuda"
|
| 333 |
+
val_loss = evaluate_loss(model, val_loader, criterion, device, use_amp=use_amp)
|
| 334 |
+
lr = scheduler.get_last_lr()[0]
|
| 335 |
+
|
| 336 |
+
elapsed = time.time() - t0
|
| 337 |
+
history["train_loss"].append(train_loss)
|
| 338 |
+
history["val_loss"].append(val_loss)
|
| 339 |
+
history["lr"].append(lr)
|
| 340 |
+
|
| 341 |
+
print(
|
| 342 |
+
f"Epoch {epoch+1}/{cfg.epochs} | "
|
| 343 |
+
f"Train {train_loss:.4f} | Val {val_loss:.4f} | "
|
| 344 |
+
f"LR {lr:.2e} | {elapsed:.1f}s"
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
# --- Optuna ASHA pruning (if trial provided) ------------------
|
| 348 |
+
if trial is not None:
|
| 349 |
+
import optuna
|
| 350 |
+
trial.report(val_loss, epoch)
|
| 351 |
+
if trial.should_prune():
|
| 352 |
+
print(f"\nβ Optuna pruned this trial at epoch {epoch+1}.")
|
| 353 |
+
raise optuna.TrialPruned()
|
| 354 |
+
|
| 355 |
+
# --- Checkpoint best model ------------------------------------
|
| 356 |
+
if val_loss < best_val:
|
| 357 |
+
best_val = val_loss
|
| 358 |
+
patience_ctr = 0
|
| 359 |
+
torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
|
| 360 |
+
print(f" β³ New best val loss β checkpoint saved.")
|
| 361 |
+
else:
|
| 362 |
+
patience_ctr += 1
|
| 363 |
+
if patience_ctr >= cfg.patience:
|
| 364 |
+
print(f"\nβΉ Early stopping after {cfg.patience} epochs without improvement.")
|
| 365 |
+
break
|
| 366 |
+
|
| 367 |
+
# Load best checkpoint
|
| 368 |
+
model.load_state_dict(torch.load(ckpt_dir / "best_model.pt", map_location=device, weights_only=True))
|
| 369 |
+
print(f"\nβ Training complete. Best val loss: {best_val:.4f}")
|
| 370 |
+
return history
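

# ──────────────────────────────────────────────────────────────────────
# Putting it together (sketch only; the model class lives in src/model.py
# and ``build_model`` is a hypothetical factory, not a real import):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   cfg = TrainConfig(epochs=50, batch_size=32)
#   train_loader, val_loader, _ = create_dataloaders(ds, tok, tok, batch_size=cfg.batch_size)
#   model = build_model(...).to(device)
#   history = train(model, train_loader, val_loader, cfg, device)
# ──────────────────────────────────────────────────────────────────────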
|
tokenizer_shared_16k.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|