AlainDeLong commited on
Commit
e27ab6a
·
1 Parent(s): 1a0fc46

Create translate app

Browse files
Files changed (14) hide show
  1. .gitignore +227 -0
  2. Dockerfile +9 -1
  3. requirements.txt +4 -1
  4. src/callbacks.py +24 -0
  5. src/config.py +70 -0
  6. src/dataset.py +280 -0
  7. src/embedding.py +105 -0
  8. src/engine.py +278 -0
  9. src/layers.py +186 -0
  10. src/model.py +207 -0
  11. src/modules.py +323 -0
  12. src/streamlit_app.py +176 -38
  13. src/tokenizer.py +156 -0
  14. src/utils.py +375 -0
.gitignore ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
208
+
209
+ # # Image
210
+ # images/
211
+
212
+ # Dataset
213
+ data/en-vi.txt/
214
+ data/IWSLT'15 en-vi/
215
+ notebooks/processed_data/
216
+ notebooks/IWSLT-15-en-vi/
217
+
218
+ # MLflow
219
+ mlruns/
220
+
221
+ # Temp Files
222
+ scratch/
223
+ notebooks/
224
+ test_push_to_hub.ipynb
225
+
226
+ # Weights & Biases
227
+ wandb/
Dockerfile CHANGED
@@ -13,8 +13,16 @@ COPY src/ ./src/
13
 
14
  RUN pip3 install -r requirements.txt
15
 
 
 
 
 
 
 
 
16
  EXPOSE 8501
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
 
13
 
14
  RUN pip3 install -r requirements.txt
15
 
16
+ RUN mkdir -p /app/hf_cache
17
+
18
+ ENV HF_HOME="/app/hf_cache"
19
+
20
+ RUN chmod -R 777 /app
21
+
22
+
23
  EXPOSE 8501
24
 
25
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
26
 
27
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
28
+ # ENTRYPOINT ["streamlit", "run", "streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  altair
2
  pandas
3
- streamlit
 
 
 
 
1
  altair
2
  pandas
3
+ streamlit
4
+ torch==2.6.0
5
+ transformers==4.52.4
6
+ jaxtyping
src/callbacks.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class EarlyStopping:
    """Stop training when validation loss stops improving.

    Tracks the best validation loss seen so far; once no improvement
    greater than ``min_delta`` has occurred for ``patience`` consecutive
    calls to :meth:`step`, the ``should_stop`` flag is set.
    """

    def __init__(self, patience=5, min_delta=1e-4, verbose=True):
        self.patience = patience      # epochs to tolerate without improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.verbose = verbose
        self.best_loss = float("inf")  # best (lowest) validation loss observed
        self.counter = 0               # consecutive epochs without improvement
        self.should_stop = False

    def step(self, val_loss):
        """Record one epoch's validation loss and update the stop flag."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.should_stop = True
            if self.verbose:
                print(
                    f"[EarlyStopping] No improvement for {self.patience} epochs → stopping."
                )
src/config.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pathlib import Path

import torch

# --- Path Configuration ---
# BUG FIX: paths were built from raw Windows-backslash strings such as
# Path(r"data\IWSLT-15-en-vi"). On the Linux Docker image the backslash is
# a literal filename character, so every path broke. pathlib "/" joining
# resolves correctly on both Windows and Linux and yields the same path
# on Windows as before.
DATA_PATH = Path("data") / "IWSLT-15-en-vi"

# TOKENIZER_NAME = "iwslt_en-vi_tokenizer_16k.json"
TOKENIZER_NAME = "iwslt_en-vi_tokenizer_32k.json"
TOKENIZER_PATH = Path("artifacts") / "tokenizers" / TOKENIZER_NAME

MODEL_DIR = Path("artifacts") / "models"

# MODEL_NAME = "transformer_en_vi_iwslt_1.pt"
MODEL_NAME = "transformer_en_vi_iwslt_1.safetensors"

# MODEL_SAVE_PATH = MODEL_DIR / MODEL_NAME
MODEL_SAVE_PATH = MODEL_DIR / "transformer_en_vi_iwslt_kaggle_1.safetensors"

CHECKPOINT_PATH = Path("artifacts") / "checkpoints" / MODEL_NAME

# HuggingFace cache directory; empty string means "use the library default".
CACHE_DIR = ""


# --- Hardware & Data Config ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_WORKERS: int = 4

VOCAB_SIZE: int = 32_000

SPECIAL_TOKENS: list[str] = ["[PAD]", "[UNK]", "[SOS]", "[EOS]"]

# Training subset size; raise (e.g. to 1_000_000) for a full run.
NUM_SAMPLES_TO_USE: int = 1000
# NUM_SAMPLES_TO_USE: int = 1_000_000


# --- Tokenizer Constants ---
# IDs must match the ordering of SPECIAL_TOKENS above.
PAD_TOKEN_ID: int = 0
UNK_TOKEN_ID: int = 1
SOS_TOKEN_ID: int = 2
EOS_TOKEN_ID: int = 3


# --- Model Hyperparameters ---
# D_MODEL: int = 256
D_MODEL: int = 512  # model/embedding dimension
N_LAYERS: int = 6  # N=6 in "Attention Is All You Need"
N_HEADS: int = 8  # h=8 in paper
# D_FF: int = 1024
D_FF: int = 2048  # feed-forward inner dimension (4 * d_model)
DROPOUT: float = 0.1  # dropout = 0.1 in paper
MAX_SEQ_LEN: int = 150  # max length for positional encoding


# --- Training Configuration ---
# LEARNING_RATE: float = 1e-4
LEARNING_RATE: float = 5e-4
BATCH_SIZE: int = 32
EPOCHS: int = 5
# EPOCHS: int = 50

# --- HuggingFace Hub ---
REPO_ID: str = "AlainDeLong/transformer-en-vi-base"
FILENAME: str = "transformer_en_vi_iwslt_kaggle_1.safetensors"

if __name__ == "__main__":
    print(f"Using device: {DEVICE}")
src/dataset.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from datasets import Dataset as ArrowDataset
6
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
7
+
8
+ import config
9
+ from src import utils
10
+
11
+
12
class TranslationDataset(Dataset):
    """Lazily tokenizes parallel sentence pairs for seq2seq training.

    Tokenization happens on demand in ``__getitem__`` (no precomputation),
    using the high-level PreTrainedTokenizerFast wrapper. Each item yields
    source token IDs (no special tokens) and target token IDs wrapped in
    [SOS] ... [EOS].
    """

    def __init__(
        self,
        dataset: "ArrowDataset",
        tokenizer: "PreTrainedTokenizerFast",
        max_len_src: int,
        max_len_tgt: int,
        src_lang: str = "en",
        tgt_lang: str = "vi",
    ):
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_len_src = max_len_src
        self.max_len_tgt = max_len_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> dict[str, list[int]]:
        pair = self.dataset[index]["translation"]
        src_text = pair[self.src_lang]
        tgt_text = pair[self.tgt_lang]

        # add_special_tokens=False: SOS/EOS are handled manually below.
        encoded_src = self.tokenizer(
            src_text,
            truncation=True,
            max_length=self.max_len_src,
            add_special_tokens=False,  # source carries no SOS/EOS
        )
        encoded_tgt = self.tokenizer(
            tgt_text,
            truncation=True,
            max_length=self.max_len_tgt - 2,  # reserve 2 slots for SOS/EOS
            add_special_tokens=False,
        )

        # Wrap the target in SOS ... EOS.
        target_ids = [config.SOS_TOKEN_ID]
        target_ids.extend(encoded_tgt["input_ids"])
        target_ids.append(config.EOS_TOKEN_ID)

        return {"src_ids": encoded_src["input_ids"], "tgt_ids": target_ids}
67
+
68
+
69
class DataCollator:
    """Custom ``collate_fn`` for the translation DataLoader.

    Given a list of ``{"src_ids", "tgt_ids"}`` examples (targets already
    carry SOS/EOS from the Dataset), it:
      * builds teacher-forcing decoder inputs ([SOS, w1, ...]) and labels
        ([w1, ..., EOS]) by shifting the target sequence,
      * dynamically pads every sequence to the longest in the batch,
      * builds the source padding mask and the combined target
        padding + look-ahead mask,
      * returns a single dict of tensors.
    """

    def __init__(self, pad_token_id: int):
        self.pad_token_id = pad_token_id

    def _pad(self, sequences: list[list[int]]) -> Tensor:
        # Dynamic padding to this batch's maximum length; (B, T) layout.
        return nn.utils.rnn.pad_sequence(
            [torch.tensor(seq) for seq in sequences],
            batch_first=True,
            padding_value=self.pad_token_id,
        )

    def __call__(self, batch: list[dict[str, list[int]]]) -> dict[str, Tensor]:
        sources = [example["src_ids"] for example in batch]
        targets = [example["tgt_ids"] for example in batch]  # already SOS...EOS

        # Teacher forcing: decoder sees [SOS, w1..wn], must predict [w1..wn, EOS].
        decoder_inputs = [seq[:-1] for seq in targets]
        labels = [seq[1:] for seq in targets]

        src_ids_padded = self._pad(sources)
        dec_input_ids_padded = self._pad(decoder_inputs)
        labels_padded = self._pad(labels)  # loss ignores pad_token_id

        tgt_len = dec_input_ids_padded.size(1)

        # Masks are built on CPU; moved to device with the batch later.
        # (Mask 1) Source padding mask for encoder MHA & cross-attention:
        # shape (B, 1, 1, T_src).
        src_mask = utils.create_padding_mask(src_ids_padded, self.pad_token_id)

        # (Mask 2) Target padding mask, shape (B, 1, 1, T_tgt), combined with
        # (Mask 3) the causal look-ahead mask, shape (1, 1, T_tgt, T_tgt),
        # giving (Mask 4) the decoder mask of shape (B, 1, T_tgt, T_tgt).
        tgt_padding_mask = utils.create_padding_mask(
            dec_input_ids_padded, self.pad_token_id
        )
        tgt_mask = tgt_padding_mask & utils.create_look_ahead_mask(tgt_len)

        return {
            "src_ids": src_ids_padded,          # (B, T_src)
            "tgt_input_ids": dec_input_ids_padded,  # (B, T_tgt)
            "labels": labels_padded,            # (B, T_tgt)
            "src_mask": src_mask,               # (B, 1, 1, T_src)
            "tgt_mask": tgt_mask,               # (B, 1, T_tgt, T_tgt)
        }
147
+
148
+
149
def get_translation_datasets(
    tokenizer: "PreTrainedTokenizerFast",
) -> "tuple[TranslationDataset, TranslationDataset, TranslationDataset]":
    """
    Factory that builds the train/val/test PyTorch datasets.

    Steps:
      1. Load and clean the raw data splits (via src.utils).
      2. Subsample the training split to at most NUM_SAMPLES_TO_USE rows.
      3. Wrap each split in a TranslationDataset.

    Args:
        tokenizer: The trained tokenizer.

    Returns:
        Tuple containing (train_ds, val_ds, test_ds).
    """
    # 1. Load raw cleaned data (keeps raw-data handling out of train.py).
    train_data, val_data, test_data = utils.get_raw_data(
        config.DATA_PATH, num_workers=config.NUM_WORKERS
    )

    # BUG FIX: clamp the subsample size so Dataset.select never receives an
    # out-of-range index when the split has fewer rows than the configured cap.
    num_samples = min(config.NUM_SAMPLES_TO_USE, len(train_data))
    train_data = train_data.select(range(num_samples))

    print("Building PyTorch Datasets...")

    # 2-4. Wrap each split; all share the same max sequence lengths.
    train_ds = TranslationDataset(
        dataset=train_data,
        tokenizer=tokenizer,
        max_len_src=config.MAX_SEQ_LEN,
        max_len_tgt=config.MAX_SEQ_LEN,
    )

    val_ds = TranslationDataset(
        dataset=val_data,
        tokenizer=tokenizer,
        max_len_src=config.MAX_SEQ_LEN,
        max_len_tgt=config.MAX_SEQ_LEN,
    )

    test_ds = TranslationDataset(
        dataset=test_data,
        tokenizer=tokenizer,
        max_len_src=config.MAX_SEQ_LEN,
        max_len_tgt=config.MAX_SEQ_LEN,
    )

    print(
        f"Datasets created: Train={len(train_ds)}, Val={len(val_ds)}, Test={len(test_ds)}"
    )

    return train_ds, val_ds, test_ds
206
+
207
+
208
def get_dataloaders(
    tokenizer: "PreTrainedTokenizerFast",
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """
    High-level factory that creates the train/val/test DataLoaders.

    Abstracts the full data pipeline: loading/cleaning raw data, building
    PyTorch Datasets, instantiating the DataCollator (dynamic padding +
    attention masks), and wiring up the DataLoaders with the configured
    batch size and worker counts.

    Args:
        tokenizer: The trained tokenizer.

    Returns:
        Tuple containing (train_loader, val_loader, test_loader).
    """
    # 1. Build the three datasets.
    train_ds, val_ds, test_ds = get_translation_datasets(tokenizer)

    # 2. Shared collator (needs the PAD id to pad and to mark mask positions).
    collator = DataCollator(pad_token_id=config.PAD_TOKEN_ID)

    print(
        f"Building DataLoaders (Batch Size: {config.BATCH_SIZE}, Workers: {config.NUM_WORKERS})..."
    )

    # pin_memory only helps host->GPU transfers; decide once, explicitly on
    # the device *type* rather than comparing a torch.device to a string.
    pin_memory = config.DEVICE.type == "cuda"

    # 3. Train loader — shuffling is CRITICAL for training.
    train_loader = DataLoader(
        train_ds,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        collate_fn=collator,
        pin_memory=pin_memory,
        prefetch_factor=2,
        persistent_workers=True,
    )

    # 4. Validation loader — no shuffle (reproducible), larger batches
    # since no backward pass is run.
    val_loader = DataLoader(
        val_ds,
        batch_size=2 * config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        collate_fn=collator,
        pin_memory=pin_memory,
        prefetch_factor=2,
        persistent_workers=True,
    )

    # 5. Test loader — deliberately fewer workers than training.
    test_loader = DataLoader(
        test_ds,
        batch_size=2 * config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        collate_fn=collator,
        pin_memory=pin_memory,
        prefetch_factor=2,
    )

    print(f"DataLoader (train) created with {len(train_loader)} batches.")
    print(f"DataLoader (val) created with {len(val_loader)} batches.")
    print(f"DataLoader (test) created with {len(test_loader)} batches.")

    return train_loader, val_loader, test_loader
src/embedding.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor
3
+ import torch.nn as nn
4
+ from jaxtyping import Int, Float
5
+ import math
6
+
7
+
8
class InputEmbeddings(nn.Module):
    """Token-ID lookup embeddings scaled by sqrt(d_model).

    Converts a (B, T) tensor of token IDs into (B, T, D) embedding
    vectors, multiplied by sqrt(d_model) as prescribed in
    "Attention Is All You Need", Section 3.4.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Args:
            d_model: Dimension of each embedding vector (D).
            vocab_size: Number of entries in the vocabulary.
        """
        super().__init__()

        self.d_model: int = d_model
        self.vocab_size: int = vocab_size
        self.token_emb: nn.Embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x: "Int[Tensor, 'B T']") -> "Float[Tensor, 'B T D']":
        """Map token IDs (B, T) to scaled embeddings (B, T, D).

        Args:
            x: Integer tensor of token IDs, shape (B, T).

        Returns:
            Embedding vectors scaled by sqrt(d_model), shape (B, T, D).
        """
        scale = math.sqrt(self.d_model)
        # (B, T) -> (B, T, D), then apply the paper's sqrt(d_model) scaling.
        return self.token_emb(x) * scale
48
+
49
+
50
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding with dropout.

    (Ref: "Attention Is All You Need", Section 3.5.) A (1, T_max, D)
    table of sin/cos values is pre-computed once at construction;
    ``forward`` adds the first T rows to the input embeddings and then
    applies dropout to the sum.
    """

    def __init__(self, d_model: int, max_seq_len: int, dropout: float = 0.1) -> None:
        """
        Args:
            d_model: The dimension of the model (D).
            max_seq_len: Maximum sequence length (T_max) to pre-compute.
            dropout: Dropout probability applied after adding the encoding.
        """
        super().__init__()

        self.dropout: nn.Dropout = nn.Dropout(p=dropout)

        # positions: (T_max, 1) column vector of 0..T_max-1
        positions = torch.arange(max_seq_len).unsqueeze(1).float()

        # Geometric frequency terms 1/10000^(2i/D), one per even dimension.
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000) / d_model)
        )

        table = torch.zeros(max_seq_len, d_model)  # (T_max, D)
        table[:, 0::2] = torch.sin(positions * freqs)  # even dims -> sin
        table[:, 1::2] = torch.cos(positions * freqs)  # odd dims -> cos

        # Store as a (1, T_max, D) buffer: moves with the module across
        # devices but is not a trainable parameter.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x: "Float[Tensor, 'B T D']") -> "Float[Tensor, 'B T D']":
        """Add positional information to (B, T, D) embeddings, then dropout.

        Args:
            x: Token embeddings (already scaled), shape (B, T, D).

        Returns:
            Tensor with positional information and dropout, shape (B, T, D).
        """
        return self.dropout(x + self.pe[:, : x.size(1), :])
src/engine.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.data import DataLoader
4
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
5
+ from torchmetrics.text import BLEUScore, SacreBLEUScore
6
+ from tqdm.auto import tqdm
7
+ import config
8
+ from src import model, utils
9
+
10
+
11
# Class count used when flattening logits for CrossEntropyLoss; mirrors config.VOCAB_SIZE.
TGT_VOCAB_SIZE: int = config.VOCAB_SIZE
12
+
13
+
14
def train_one_epoch(
    model: model.Transformer,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    device: torch.device,
    logger=None,
) -> float:
    """
    Runs a single training epoch.

    Args:
        model: The Transformer model.
        dataloader: The training DataLoader.
        optimizer: The optimizer.
        criterion: The loss function (e.g., CrossEntropyLoss).
        scheduler: Per-batch LR scheduler; stepped once after every optimizer step.
        device: The device to run on (e.g., 'cuda').
        logger: Optional metrics logger exposing a ``log(dict)`` method
            (e.g. a wandb run); batch loss and current LR are logged
            every 100 batches when provided.

    Returns:
        The average training loss for the epoch.
    """

    # Set model to training mode (enables dropout, etc.).
    model.train()

    total_loss = 0.0

    # Use tqdm for a progress bar.
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    batch_idx: int = 0

    for batch in progress_bar:
        batch_idx += 1

        # 1. Move tensor entries of the batch to the device (non-tensors dropped).
        batch_gpu = {
            k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)
        }

        # 2. Zero gradients before the forward pass.
        optimizer.zero_grad()

        # 3. Forward pass (inputs as defined in Transformer.forward).
        logits = model(
            src=batch_gpu["src_ids"],
            tgt=batch_gpu["tgt_input_ids"],
            src_mask=batch_gpu["src_mask"],
            tgt_mask=batch_gpu["tgt_mask"],
        )  # Shape: (B, T_tgt, vocab_size)

        # 4. Calculate loss.
        # CrossEntropyLoss expects (N, C) and (N,), so flatten:
        # Logits: (B, T_tgt, C) -> (B * T_tgt, C)
        # Labels: (B, T_tgt) -> (B * T_tgt)
        loss = criterion(logits.view(-1, TGT_VOCAB_SIZE), batch_gpu["labels"].view(-1))

        # 5. Backward pass (compute gradients).
        loss.backward()

        # 6. Gradient clipping at max_norm 1.0 to prevent exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # 7. Update weights.
        optimizer.step()

        # 8. Step the learning-rate scheduler once per batch.
        scheduler.step()

        # 9. Update stats.
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

        # 10. Log metrics every 100 batches, if a logger was provided.
        if logger and batch_idx % 100 == 0:
            logger.log(
                {
                    "train/batch_loss": loss.item(),
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                }
            )

    # Return average loss for the epoch.
    return total_loss / len(dataloader)
103
+
104
+
105
def validate_one_epoch(
    model: model.Transformer,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """
    Runs a single validation epoch (no gradient computation or updates).

    Args:
        model: The Transformer model.
        dataloader: The validation DataLoader.
        criterion: The loss function (e.g., CrossEntropyLoss).
        device: The device to run on (e.g., 'cuda').

    Returns:
        The average validation loss for the epoch.
    """

    # Evaluation mode disables dropout.
    model.eval()

    running_loss = 0.0
    progress_bar = tqdm(dataloader, desc="Validating", leave=False)

    # no_grad saves VRAM and speeds up the pass.
    with torch.no_grad():
        for batch in progress_bar:
            # Move tensor entries to the device; drop anything else.
            batch_gpu = {
                key: value.to(device)
                for key, value in batch.items()
                if isinstance(value, torch.Tensor)
            }

            # Forward pass -> (B, T_tgt, vocab_size)
            logits = model(
                src=batch_gpu["src_ids"],
                tgt=batch_gpu["tgt_input_ids"],
                src_mask=batch_gpu["src_mask"],
                tgt_mask=batch_gpu["tgt_mask"],
            )

            # Flatten to (B*T, C) vs (B*T,), matching the training loop.
            batch_loss = criterion(
                logits.view(-1, TGT_VOCAB_SIZE), batch_gpu["labels"].view(-1)
            )

            running_loss += batch_loss.item()
            progress_bar.set_postfix(loss=batch_loss.item())

    return running_loss / len(dataloader)
162
+
163
+
164
def evaluate_model(
    model: model.Transformer,
    dataloader: DataLoader,
    tokenizer: PreTrainedTokenizerFast,
    device: torch.device,
    table=None,
) -> tuple[float, float]:
    """
    Runs final evaluation using greedy decoding and computes corpus-level
    BLEU and SacreBLEU scores.

    (Docstring fix: the original said "Beam Search", but the code calls
    utils.greedy_decode_sentence.)

    Args:
        model: The Transformer model.
        dataloader: The evaluation DataLoader.
        tokenizer: Tokenizer used to map IDs back to token strings.
        device: The device to run on.
        table: Optional wandb-style table with an ``add_data(exp, pred)``
            method; up to 5 example translations are appended when provided.

    Returns:
        Tuple of (BLEU * 100, SacreBLEU * 100).
    """
    print("\n--- Starting Evaluation (BLEU + SacreBLEU) ---")

    # Evaluation mode disables dropout.
    model.eval()

    all_predicted_strings = []
    all_expected_strings = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):

            batch_gpu = {
                k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)
            }

            src_ids = batch_gpu["src_ids"]
            src_mask = batch_gpu["src_mask"]
            expected_ids = batch_gpu["labels"]  # (B, T_tgt) on device

            B = src_ids.size(0)

            # --- Reference strings: 2D device tensor -> CPU lists -> text ---
            batch_expected_strings = []
            expected_id_lists = expected_ids.cpu().tolist()
            for id_list in expected_id_lists:
                token_list = tokenizer.convert_ids_to_tokens(id_list)
                batch_expected_strings.append(
                    utils.filter_and_detokenize(token_list, skip_special=True)
                )

            # --- Greedy-decode one sentence at a time ---
            batch_predicted_strings = []
            for i in tqdm(range(B), desc="Decoding Batch", leave=False):
                src_sentence = src_ids[i].unsqueeze(0)
                src_sentence_mask = src_mask[i].unsqueeze(0)

                # predicted_ids: 1D tensor (T_out) on the device
                predicted_ids = utils.greedy_decode_sentence(
                    model,
                    src_sentence,
                    src_sentence_mask,
                    max_len=config.MAX_SEQ_LEN,
                    sos_token_id=config.SOS_TOKEN_ID,
                    eos_token_id=config.EOS_TOKEN_ID,
                    device=device,
                )

                predicted_id_list = predicted_ids.cpu().tolist()
                predicted_token_list = tokenizer.convert_ids_to_tokens(
                    predicted_id_list
                )
                decoded_str = utils.filter_and_detokenize(
                    predicted_token_list, skip_special=True
                )
                batch_predicted_strings.append(decoded_str)

            # --- Accumulate for corpus-level metric calculation ---
            all_predicted_strings.extend(batch_predicted_strings)
            # torchmetrics expects one list of references per prediction.
            all_expected_strings.extend([[s] for s in batch_expected_strings])

    bleu_metric = BLEUScore(n_gram=4, smooth=True).to(config.DEVICE)
    sacrebleu_metric = SacreBLEUScore(
        n_gram=4, smooth=True, tokenize="intl", lowercase=False
    ).to(config.DEVICE)

    # --- Calculate final scores ---
    print("\nCalculating final BLEU score...")
    final_bleu = bleu_metric(all_predicted_strings, all_expected_strings)

    print("\nCalculating final SacreBLEU score...")
    final_sacrebleu = sacrebleu_metric(all_predicted_strings, all_expected_strings)

    # --- Show some examples ---
    print("\n--- Translation Examples (Pred vs Exp) ---")
    for i in range(min(5, len(all_predicted_strings))):
        print(f" PRED: {all_predicted_strings[i]}")
        print(f" EXP: {all_expected_strings[i][0]}")
        print(" ---")

        # BUG FIX: the original called table.add_data unconditionally,
        # raising AttributeError whenever table was left as its default None.
        if table is not None:
            table.add_data(all_expected_strings[i][0], all_predicted_strings[i])

    return final_bleu.item() * 100, final_sacrebleu.item() * 100
src/layers.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor
2
+ import torch.nn as nn
3
+ from jaxtyping import Bool, Float
4
+ import math
5
+
6
+
7
+ class MultiHeadAttention(nn.Module):
8
+ """
9
+ Terminology (jaxtyping):
10
+ B: batch_size
11
+ T_q: target sequence length (query)
12
+ T_k: source sequence length (key/value)
13
+ D: d_model (model dimension)
14
+ H: n_heads (number of heads)
15
+ d_k: dimension of each head (d_model / n_heads)
16
+ """
17
+
18
+ def __init__(self, d_model: int, n_heads: int) -> None:
19
+ super().__init__()
20
+ assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
21
+
22
+ self.d_model: int = d_model
23
+ self.n_heads: int = n_heads
24
+ self.d_k: int = d_model // n_heads
25
+
26
+ self.w_q: nn.Linear = nn.Linear(d_model, d_model, bias=False)
27
+ self.w_k: nn.Linear = nn.Linear(d_model, d_model, bias=False)
28
+ self.w_v: nn.Linear = nn.Linear(d_model, d_model, bias=False)
29
+ self.w_o: nn.Linear = nn.Linear(d_model, d_model, bias=False)
30
+
31
+ self.attention_weights: Tensor | None = None
32
+
33
+ @staticmethod
34
+ def attention(
35
+ query: Float[Tensor, "B H T_q d_k"],
36
+ key: Float[Tensor, "B H T_k d_k"],
37
+ value: Float[Tensor, "B H T_k d_k"],
38
+ mask: Bool[Tensor, "... 1 T_q T_k"] | None,
39
+ ) -> tuple[Float[Tensor, "B H T_q d_k"], Float[Tensor, "B H T_q T_k"]]:
40
+ """
41
+ Static method for Scaled Dot-Product Attention calculation.
42
+ This is pure, stateless logic, making it easy to test.
43
+ (Ref: "Attention Is All You Need", Equation 1)
44
+
45
+ Args:
46
+ query (Tensor): Query tensor
47
+ key (Tensor): Key tensor
48
+ value (Tensor): Value tensor
49
+ mask (Tensor | None): Optional mask (for padding or look-ahead).
50
+
51
+ Returns:
52
+ tuple[Tensor, Tensor]:
53
+ - context_vector: The output of the attention mechanism.
54
+ - attention_weights: The softmax-normalized attention weights.
55
+ """
56
+
57
+ d_k: int = query.shape[-1]
58
+
59
+ # (B, H, T_q, d_k) @ (B, H, d_k, T_k) -> (B, H, T_q, T_k)
60
+ attention_scores: Tensor = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
61
+
62
+ if mask is not None:
63
+ attention_scores = attention_scores.masked_fill(
64
+ mask == 0, value=float("-inf")
65
+ )
66
+
67
+ attention_weights: Tensor = attention_scores.softmax(dim=-1)
68
+
69
+ # (B, H, T_q, T_k) @ (B, H, T_k, d_k) -> (B, H, T_q, d_k)
70
+ context_vector: Tensor = attention_weights @ value
71
+
72
+ return context_vector, attention_weights
73
+
74
+ def forward(
75
+ self,
76
+ q: Float[Tensor, "B T_q D"],
77
+ k: Float[Tensor, "B T_k D"],
78
+ v: Float[Tensor, "B T_k D"],
79
+ mask: Bool[Tensor, "... 1 T_q T_k"] | None = None, # Optional mask
80
+ ) -> Float[Tensor, "B T_q D"]:
81
+ """
82
+ Forward pass for Multi-Head Attention.
83
+
84
+ In Self-Attention (Encoder), q, k, and v are all the same tensor.
85
+ In Cross-Attention (Decoder), q comes from the Decoder, while k and v
86
+ come from the Encoder's output.
87
+
88
+ Args:
89
+ q: Query tensor
90
+ k: Key tensor
91
+ v: Value tensor
92
+ mask: Optional mask to apply (padding or look-ahead)
93
+
94
+ Returns:
95
+ The context vector after multi-head attention and output projection.
96
+ """
97
+
98
+ B, T_q, _ = q.shape
99
+ _, T_k, _ = k.shape # T_k == T_v
100
+
101
+ # (B, T, D) -> (B, T, D)
102
+ Q: Tensor = self.w_q(q)
103
+ K: Tensor = self.w_k(k)
104
+ V: Tensor = self.w_v(v)
105
+
106
+ # (B, T, D) -> (B, T, H, d_k) -> (B, H, T, d_k)
107
+ Q = Q.view(B, T_q, self.n_heads, self.d_k).transpose(1, 2)
108
+ K = K.view(B, T_k, self.n_heads, self.d_k).transpose(1, 2)
109
+ V = V.view(B, T_k, self.n_heads, self.d_k).transpose(1, 2)
110
+
111
+ context_vector, self.attention_weights = self.attention(Q, K, V, mask)
112
+
113
+ # (B, H, T_q, d_k) -> (B, T_q, H, d_k)
114
+ context_vector = context_vector.transpose(1, 2).contiguous()
115
+
116
+ # (B, T_q, H, d_k) -> (B, T_q, D)
117
+ context_vector = context_vector.view(B, T_q, self.d_model)
118
+
119
+ # (B, T_q, D) -> (B, T_q, D)
120
+ output: Tensor = self.w_o(context_vector)
121
+
122
+ return output
123
+
124
+
125
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward sublayer:

        FFN(x) = ReLU(x @ W1 + b1) @ W2 + b2

    A two-layer MLP applied independently at every sequence position.
    (Ref: "Attention Is All You Need", Section 3.3.)

    Shape glossary:
        B: batch size, T: sequence length,
        D: d_model, D_FF: inner feed-forward dimension.
    """

    def __init__(self, d_model: int, d_ff: int) -> None:
        """
        Args:
            d_model (int): Model dimension (e.g. 512).
            d_ff (int): Inner FFN dimension (e.g. 2048); the paper uses
                d_ff = 4 * d_model. Dropout is handled by the surrounding
                residual wrapper, not here.
        """
        super().__init__()

        # Expansion -> non-linearity -> contraction, created in this order
        # so parameter initialisation matches the reference implementation.
        self.linear_1: nn.Linear = nn.Linear(d_model, d_ff)  # (B,T,D) -> (B,T,D_FF)
        self.activation: nn.ReLU = nn.ReLU()
        self.linear_2: nn.Linear = nn.Linear(d_ff, d_model)  # (B,T,D_FF) -> (B,T,D)

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply the two linear maps with a ReLU in between.

        Args:
            x: Input of shape (B, T, D) from the previous sublayer.

        Returns:
            Tensor of the same shape (B, T, D).
        """
        return self.linear_2(self.activation(self.linear_1(x)))
src/model.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pathlib import Path

import torch
import torch.nn as nn
from torch import Tensor

from huggingface_hub import hf_hub_download
from jaxtyping import Bool, Int, Float
from safetensors.torch import load_model

import config
from embedding import InputEmbeddings, PositionalEncoding
from modules import Encoder, Decoder
10
+
11
+
12
class Generator(nn.Module):
    """
    Final projection from decoder hidden states to vocabulary logits.

    Maps the decoder output (B, T, D) onto the vocabulary space
    (B, T, vocab_size). Softmax is applied by the loss / sampling code,
    not here. The projection weight can be tied with the target embedding
    (handled by the enclosing Transformer class).
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Args:
            d_model (int): Model dimension (D).
            vocab_size (int): Target vocabulary size.
        """
        super().__init__()
        # bias=False keeps the layer compatible with embedding weight tying.
        self.proj: nn.Linear = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """
        Project decoder states to logits.

        Args:
            x: Final decoder output, shape (B, T_tgt, D).

        Returns:
            Logits over the vocabulary, shape (B, T_tgt, vocab_size).
        """
        # (B, T_tgt, D) -> (B, T_tgt, vocab_size)
        return self.proj(x)
52
+
53
+
54
class Transformer(nn.Module):
    """
    The main Transformer model architecture, combining the Encoder
    and Decoder stacks, as described in "Attention Is All You Need".

    This implementation follows modern best practices (Pre-LN) and
    is designed for a sequence-to-sequence task (e.g., translation).
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,  # N=6 in the paper
        d_ff: int,
        dropout: float = 0.1,
        max_seq_len: int = 512,  # Max length for positional encoding
    ) -> None:
        """
        Initializes the full Transformer model.

        Args:
            src_vocab_size (int): Vocabulary size for the source language.
            tgt_vocab_size (int): Vocabulary size for the target language.
            d_model (int): The dimension of the model (D).
            n_heads (int): The number of attention heads (H).
            n_layers (int): The number of Encoder/Decoder layers (N).
            d_ff (int): The inner dimension of the Feed-Forward Network (D_FF).
            dropout (float): The dropout rate.
            max_seq_len (int): The maximum sequence length for positional encoding.
        """
        super().__init__()

        self.d_model = d_model

        # --- 1. Source (Encoder) Embeddings ---
        # Separate embedding tables for source and target vocabularies.
        self.src_embed: InputEmbeddings = InputEmbeddings(d_model, src_vocab_size)

        # --- 2. Target (Decoder) Embeddings ---
        self.tgt_embed: InputEmbeddings = InputEmbeddings(d_model, tgt_vocab_size)

        # --- 3. Positional Encoding ---
        # A single PositionalEncoding module shared by source and target
        # (the encoding depends only on position, not on the language).
        self.pos_enc: PositionalEncoding = PositionalEncoding(
            d_model, max_seq_len, dropout
        )

        # --- 4. Encoder Stack ---
        self.encoder: Encoder = Encoder(d_model, n_heads, d_ff, n_layers, dropout)

        # --- 5. Decoder Stack ---
        self.decoder: Decoder = Decoder(d_model, n_heads, d_ff, n_layers, dropout)

        # --- 6. Final Output Projection (Generator) ---
        self.generator: Generator = Generator(d_model, tgt_vocab_size)

        # --- Weight Tying ---
        # Share one weight tensor between the target embedding and the output
        # projection; saves parameters and tends to improve performance.
        self.generator.proj.weight = self.tgt_embed.token_emb.weight

        # --- Initialize weights ---
        # NOTE(review): apply() runs *after* tying, so the shared tensor is
        # re-initialized by whichever module apply() visits last (Linear
        # xavier vs. Embedding normal) — confirm this matches training-time
        # initialization.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """
        Applies Xavier/Glorot uniform initialization to linear layers.
        This is a common and effective initialization strategy.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)

            # Linear layers created with bias=False have module.bias is None.
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        elif isinstance(module, nn.Embedding):
            # Initialize embeddings from a normal distribution scaled by
            # 1/sqrt(d_model), the usual Transformer embedding scale.
            nn.init.normal_(module.weight, mean=0, std=self.d_model**-0.5)

    def forward(
        self,
        src: Int[Tensor, "B T_src"],  # Source token IDs (e.g., English)
        tgt: Int[Tensor, "B T_tgt"],  # Target token IDs (e.g., Vietnamese)
        src_mask: Bool[Tensor, "B 1 1 T_src"],  # Source padding mask
        tgt_mask: Bool[Tensor, "B 1 T_tgt T_tgt"],  # Target combined mask
    ) -> Float[Tensor, "B T_tgt vocab_size"]:
        """
        Defines the main forward pass of the Transformer model.

        Args:
            src (Tensor): Source sequence token IDs.
            tgt (Tensor): Target sequence token IDs (shifted right).
            src_mask (Tensor): Padding mask for the source sequence.
            tgt_mask (Tensor): Combined padding and look-ahead mask
                               for the target sequence.

        Returns:
            Tensor: The output logits from the model (B, T_tgt, vocab_size).
        """
        # 1. Encode the source sequence
        # (B, T_src) -> (B, T_src, D)
        src_embeded = self.src_embed(src)
        src_with_pos = self.pos_enc(src_embeded)

        # (B, T_src, D) -> (B, T_src, D)
        # This 'memory' is consumed by every DecoderLayer's cross-attention.
        enc_output: Tensor = self.encoder(src_with_pos, src_mask)

        # 2. Decode the target sequence
        # (B, T_tgt) -> (B, T_tgt, D)
        tgt_embeded = self.tgt_embed(tgt)
        tgt_with_pos = self.pos_enc(tgt_embeded)

        # (B, T_tgt, D) -> (B, T_tgt, D)
        dec_output: Tensor = self.decoder(tgt_with_pos, enc_output, src_mask, tgt_mask)

        # 3. Generate final logits
        # (B, T_tgt, D) -> (B, T_tgt, vocab_size)
        logits: Tensor = self.generator(dec_output)

        return logits
180
+
181
+
182
def load_trained_model(
    config_obj, checkpoint_path, device: torch.device
) -> Transformer:
    """
    Build a Transformer from the hyper-parameters in `config_obj` and load
    trained safetensors weights into it.

    Weights come from `checkpoint_path` when that file exists locally;
    otherwise the checkpoint is downloaded from the Hugging Face Hub using
    `config_obj.REPO_ID` / `config_obj.FILENAME`. (Previously the
    `checkpoint_path` argument was accepted but silently ignored, and the
    Hub coordinates were read from the module-global `config` instead of
    the `config_obj` parameter.)

    Args:
        config_obj: Config namespace exposing VOCAB_SIZE, D_MODEL, N_HEADS,
            N_LAYERS, D_FF, DROPOUT, MAX_SEQ_LEN, REPO_ID and FILENAME.
        checkpoint_path: Optional local path to a .safetensors checkpoint;
            may be None or a non-existent path, in which case the Hub is used.
        device (torch.device): Device to place the model on.

    Returns:
        Transformer: Model with trained weights loaded (still in train mode;
        callers should switch to .eval() for inference).
    """
    # --- 1. Resolve the checkpoint: local file first, Hub as fallback ---
    if checkpoint_path and Path(checkpoint_path).is_file():
        model_path = str(checkpoint_path)
        print(f"Using local checkpoint: {model_path}")
    else:
        print("Downloading safetensors from Hub...")
        model_path = hf_hub_download(
            repo_id=config_obj.REPO_ID, filename=config_obj.FILENAME
        )

    # --- 2. Instantiate the architecture from the config ---
    print("Instantiating the Transformer model...")
    model = Transformer(
        src_vocab_size=config_obj.VOCAB_SIZE,
        tgt_vocab_size=config_obj.VOCAB_SIZE,
        d_model=config_obj.D_MODEL,
        n_heads=config_obj.N_HEADS,
        n_layers=config_obj.N_LAYERS,
        d_ff=config_obj.D_FF,
        dropout=config_obj.DROPOUT,
        max_seq_len=config_obj.MAX_SEQ_LEN,
    ).to(device)

    # --- 3. Load the trained weights in place ---
    print(f"Loading model from: {model_path}")
    load_model(model, filename=model_path)

    print(f"Successfully loaded trained weights from {model_path}")
    return model
src/modules.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor
2
+ import torch.nn as nn
3
+ from typing import Callable
4
+ from jaxtyping import Bool, Float
5
+ from layers import MultiHeadAttention, PositionwiseFeedForward
6
+
7
+
8
class ResidualConnection(nn.Module):
    """
    Pre-LN residual wrapper around a sublayer (attention or FFN):

        output = x + Dropout(sublayer(LayerNorm(x)))

    Normalising *before* the sublayer (as in GPT-2) trains more stably
    than the Post-LN arrangement of the original
    "Attention Is All You Need" paper.
    """

    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
        """
        Args:
            d_model (int): Model dimension (D), used by the LayerNorm.
            dropout (float): Dropout probability on the sublayer output.
        """
        super().__init__()
        self.dropout: nn.Dropout = nn.Dropout(dropout)
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(self, x: Tensor, sublayer: Callable[[Tensor], Tensor]) -> Tensor:
        """
        Apply `sublayer` to the normalised input and add the (dropped-out)
        result back onto the raw input.

        Args:
            x: Input of shape (B, T, D) from the previous layer.
            sublayer: Callable (e.g. MHA or FFN) taking and returning
                a (B, T, D) tensor.

        Returns:
            Tensor of shape (B, T, D).
        """
        return x + self.dropout(sublayer(self.norm(x)))
57
+
58
+
59
class EncoderLayer(nn.Module):
    """
    One encoder block of the Transformer.

    Two sublayers, each inside a Pre-LN residual wrapper
    (LayerNorm -> sublayer -> Dropout -> add):
        1. multi-head self-attention,
        2. position-wise feed-forward network.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1
    ) -> None:
        """
        Args:
            d_model (int): Model dimension (D).
            n_heads (int): Number of attention heads (H).
            d_ff (int): Inner FFN dimension (D_FF).
            dropout (float): Dropout rate inside the residual wrappers.
        """
        super().__init__()

        # Submodules are created in the same order as the reference
        # implementation so parameter initialisation is reproducible.
        self.self_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward: PositionwiseFeedForward = PositionwiseFeedForward(
            d_model, d_ff
        )
        self.residual_1: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_2: ResidualConnection = ResidualConnection(d_model, dropout)

    def forward(self, x: Tensor, src_mask: Tensor) -> Tensor:
        """
        Run one encoder block.

        Args:
            x: Activations from the previous layer or embedding stage, (B, T, D).
            src_mask: Source padding mask of shape (B, 1, 1, T_k); broadcasts
                over heads and query positions.

        Returns:
            Tensor of shape (B, T, D).
        """

        def apply_self_attention(normed: Tensor) -> Tensor:
            # Self-attention: queries, keys and values all come from `normed`.
            return self.self_attn(q=normed, k=normed, v=normed, mask=src_mask)

        x = self.residual_1(x, apply_self_attention)
        x = self.residual_2(x, self.feed_forward)
        return x
123
+
124
+
125
class Encoder(nn.Module):
    """
    The full Transformer encoder: a stack of N identical EncoderLayers.

    Takes embeddings + positional encodings and applies N rounds of
    self-attention and FFN. Because the layers use Pre-LN, one final
    LayerNorm is applied at the end of the stack before the output is
    handed to the decoder.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, n_layers: int, dropout: float = 0.1
    ) -> None:
        """
        Args:
            d_model (int): Model dimension (D).
            n_heads (int): Number of attention heads (H).
            d_ff (int): Inner FFN dimension (D_FF).
            n_layers (int): Number of stacked EncoderLayer blocks (N).
            dropout (float): Dropout rate inside the residual wrappers.
        """
        super().__init__()

        stacked = [EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        self.layers: nn.ModuleList = nn.ModuleList(stacked)

        # Final norm of the Pre-LN convention.
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(self, x: Tensor, src_mask: Tensor) -> Tensor:
        """
        Run the whole encoder stack.

        Args:
            x: Token embeddings + positional encodings, (B, T, D).
            src_mask: Source padding mask, (B, 1, 1, T).

        Returns:
            The encoder "memory" for the decoder, shape (B, T, D).
        """
        for block in self.layers:
            x = block(x, src_mask)
        return self.norm(x)
179
+
180
+
181
class DecoderLayer(nn.Module):
    """
    One decoder block of the Transformer.

    Three sublayers, each inside a Pre-LN residual wrapper:
        1. masked multi-head self-attention over the target sequence,
        2. multi-head cross-attention over the encoder output,
        3. position-wise feed-forward network.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1
    ) -> None:
        """
        Args:
            d_model (int): Model dimension (D).
            n_heads (int): Number of attention heads (H).
            d_ff (int): Inner FFN dimension (D_FF).
            dropout (float): Dropout rate inside the residual wrappers.
        """
        super().__init__()

        # Creation order matches the reference implementation for
        # reproducible parameter initialisation.
        self.self_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.cross_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward: PositionwiseFeedForward = PositionwiseFeedForward(
            d_model, d_ff
        )
        self.residual_1: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_2: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_3: ResidualConnection = ResidualConnection(d_model, dropout)

    def forward(
        self,
        x: Tensor,
        enc_output: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tensor:
        """
        Run one decoder block.

        Args:
            x: Input from the previous decoder layer, (B, T_tgt, D).
            enc_output: Encoder output used as K and V, (B, T_src, D).
            src_mask: Padding mask for the source (encoder) input.
            tgt_mask: Combined look-ahead + padding mask for the target.

        Returns:
            Tensor of shape (B, T_tgt, D).
        """

        def masked_self_attention(normed: Tensor) -> Tensor:
            # Target self-attention; tgt_mask prevents peeking at future tokens.
            return self.self_attn(q=normed, k=normed, v=normed, mask=tgt_mask)

        def encoder_cross_attention(normed: Tensor) -> Tensor:
            # Queries come from the decoder; keys/values from the encoder.
            return self.cross_attn(q=normed, k=enc_output, v=enc_output, mask=src_mask)

        x = self.residual_1(x, masked_self_attention)
        x = self.residual_2(x, encoder_cross_attention)
        x = self.residual_3(x, self.feed_forward)
        return x
261
+
262
+
263
class Decoder(nn.Module):
    """
    The full Transformer decoder: a stack of N identical DecoderLayers.

    Takes target embeddings + positional encodings and applies N rounds
    of masked self-attention, cross-attention and FFN. Because the layers
    use Pre-LN, one final LayerNorm is applied at the end of the stack
    before the output goes to the Generator projection.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, n_layers: int, dropout: float = 0.1
    ) -> None:
        """
        Args:
            d_model (int): Model dimension (D).
            n_heads (int): Number of attention heads (H).
            d_ff (int): Inner FFN dimension (D_FF).
            n_layers (int): Number of stacked DecoderLayer blocks (N).
            dropout (float): Dropout rate inside the residual wrappers.
        """
        super().__init__()

        stacked = [DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)]
        self.layers: nn.ModuleList = nn.ModuleList(stacked)

        # Final norm of the Pre-LN convention.
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Tensor,
        enc_output: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tensor:
        """
        Run the whole decoder stack.

        Args:
            x: Target embeddings + positional encodings, (B, T_tgt, D).
            enc_output: Encoder memory (K, V for cross-attention), (B, T_src, D).
            src_mask: Padding mask for the source sequence.
            tgt_mask: Combined mask for the target sequence.

        Returns:
            Tensor of shape (B, T_tgt, D), ready for the Generator.
        """
        for block in self.layers:
            x = block(x, enc_output, src_mask, tgt_mask)
        return self.norm(x)
src/streamlit_app.py CHANGED
@@ -1,40 +1,178 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import time
3
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
4
+ from huggingface_hub import hf_hub_download
5
+ import config
6
+ import model
7
+ import utils
8
 
9
+
10
+ # ==========================================
11
+ # 1. ASSUMPTIONS
12
+ # ==========================================
13
+
14
+
15
@st.cache_resource
def load_artifacts():
    """
    Load the tokenizer and trained Transformer once per Streamlit process.

    Downloads the tokenizer JSON from the Hub, then builds and loads the
    trained model. On any failure, prints a warning and returns partially
    (or fully) None results so the UI can fall back to demo output instead
    of crashing at startup.

    Returns:
        tuple: (transformer_model, tokenizer); either element may be None
        if loading failed.
    """
    tokenizer: PreTrainedTokenizerFast | None = None
    transformer_model: model.Transformer | None = None

    try:
        tok_path = hf_hub_download(
            repo_id=config.REPO_ID, filename="iwslt_en-vi_tokenizer_32k.json"
        )
        tokenizer = utils.load_tokenizer(tok_path)

        print("Loading model for inference...")
        transformer_model = model.load_trained_model(
            config, config.MODEL_SAVE_PATH, config.DEVICE
        )

    except Exception as e:
        # Broad catch is deliberate: degrade to demo mode rather than fail
        # the whole app when the Hub is unreachable or weights are missing.
        print(
            f"Warning: Could not load model. Using RANDOMLY initialized model. Error: {e}"
        )
        print(" (Translations will be gibberish)")

    return transformer_model, tokenizer
38
+
39
+
40
# ==========================================
# 2. UI CONFIGURATION
# ==========================================
st.set_page_config(
    page_title="En-Vi Translator | AttentionIsAllYouBuild",
    page_icon="🤖",
    layout="centered",
    # layout="wide",
)

# Customize CSS to create beautiful interface
st.markdown(
    """
    <style>
    .main {
        background-color: #f5f5f5;
    }
    .stTextArea textarea {
        font-size: 16px;
    }
    .stButton button {
        width: 100%;
        background-color: #FF4B4B;
        color: white;
        font-weight: bold;
        padding: 10px;
    }
    .result-box {
        background-color: #ffffff;
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
        border-left: 5px solid #FF4B4B;
    }
    .source-text {
        color: #666;
        font-style: italic;
        font-size: 14px;
        margin-bottom: 5px;
    }
    .translated-text {
        color: #333;
        font-size: 20px;
        font-weight: 600;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

# ==========================================
# 3. MAIN APP LAYOUT
# ==========================================

# Header
st.title("🤖 AI Translator: English → Vietnamese")
st.markdown("### Project: *Attention Is All You Build*")
st.markdown("---")

# Sidebar (additional model information; UI text is intentionally Vietnamese)
with st.sidebar:
    st.header("ℹ️ Thông tin Model")
    st.info(
        """
        Đây là mô hình **Transformer (Encoder-Decoder)** được xây dựng "from scratch" bằng PyTorch.
        
        - **Kiến trúc**: Pre-LN Transformer
        - **Tokenizer**: BPE (32k vocab)
        - **Inference**: Greedy
        """
    )
    st.write("Created by [Your Name]")

# Input Area
input_text = st.text_area(
    label="Nhập câu tiếng Anh:",
    placeholder="Example: Artificial intelligence is transforming the world...",
    height=150,
)

# ==========================================
# 4. INFERENCE LOGIC
# ==========================================

# Translate button
if st.button("Dịch sang Tiếng Việt (Translate)"):
    if not input_text.strip():
        st.warning("⚠️ Vui lòng nhập nội dung cần dịch!")
    else:
        # Display spinner while model is running
        with st.spinner("Wait a second... AI is thinking 🧠"):
            try:
                # Measure inference time (first run also includes the cached
                # model/tokenizer load below)
                start_time = time.time()

                # --- Call translate function ---
                transformer_model, tokenizer = load_artifacts()

                if utils and transformer_model and tokenizer:
                    translation = utils.translate(
                        transformer_model,
                        tokenizer,
                        sentence_en=input_text,
                        device=config.DEVICE,
                        max_len=config.MAX_SEQ_LEN,
                        sos_token_id=config.SOS_TOKEN_ID,
                        eos_token_id=config.EOS_TOKEN_ID,
                        pad_token_id=config.PAD_TOKEN_ID,
                    )

                else:
                    # Mockup output when the model could not be loaded
                    time.sleep(1)  # Simulate latency
                    translation = "[DEMO OUTPUT] Hệ thống chưa load model thực tế. Đây là kết quả mẫu."

                end_time = time.time()
                inference_time = end_time - start_time

                # --- Display Result ---
                st.success(f"✅ Hoàn tất trong {inference_time:.2f}s")

                st.markdown("### Kết quả:")
                st.markdown(
                    f"""
                    <div class="result-box">
                        <div class="source-text">Original: {input_text}</div>
                        <div class="translated-text">{translation}</div>
                    </div>
                    """,
                    unsafe_allow_html=True,
                )

            except Exception as e:
                st.error(f"❌ Đã xảy ra lỗi trong quá trình dịch: {str(e)}")

# Footer
st.markdown("---")
st.caption("Powered by PyTorch & Streamlit")
src/tokenizer.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from datasets import Dataset
3
+ from tokenizers import (
4
+ Tokenizer,
5
+ models,
6
+ normalizers,
7
+ pre_tokenizers,
8
+ decoders,
9
+ trainers,
10
+ )
11
+ from tqdm.auto import tqdm
12
+ import wandb
13
+ from utils import get_raw_data
14
+
15
+
16
# Local dataset location (Windows-style relative path; assumes running from src/).
DATA_PATH = Path(r"..\data\IWSLT-15-en-vi")
# TOKENIZER_NAME = "iwslt_en-vi_tokenizer_16k.json"
TOKENIZER_NAME = "iwslt_en-vi_tokenizer_32k.json"
TOKENIZER_SAVE_PATH = Path(r"..\artifacts\tokenizers") / TOKENIZER_NAME

# Target BPE vocabulary size (counts the special tokens below).
# VOCAB_SIZE: int = 16_000
VOCAB_SIZE: int = 32_000
SPECIAL_TOKENS: list[str] = ["[PAD]", "[UNK]", "[SOS]", "[EOS]"]

# Sentence pairs yielded per iterator step during tokenizer training.
BATCH_SIZE_FOR_TOKENIZER: int = 10000
# Parallel workers for dataset filtering.
NUM_WORKERS: int = 8
27
+
28
+
29
def get_training_corpus(dataset: Dataset, batch_size: int = 1000):
    """
    Yield alternating batches of English and Vietnamese sentences.

    Streams the dataset through its highly optimized, zero-copy Arrow
    iterator (``dataset.iter(batch_size=...)``). Each chunk arrives as
    ``{'translation': [list of dicts]}``; for every chunk we yield two
    plain ``list[str]`` values — first the 'en' side, then the 'vi'
    side — which is exactly what the HF tokenizer trainer expects.
    """
    for chunk in dataset.iter(batch_size=batch_size):
        pairs = chunk["translation"]
        # One yield per language so the trainer sees both sides of the corpus.
        for lang in ("en", "vi"):
            yield [pair[lang] for pair in pairs]
54
+
55
+
56
def instantiate_tokenizer() -> Tokenizer:
    """
    Build an untrained BPE tokenizer with the project's text pipeline.

    Pipeline: NFKC unicode normalization + lowercasing, whitespace
    pre-tokenization (BPE then learns merges inside those "words"),
    and a BPE decoder for reconstructing strings from tokens.
    """
    bpe_tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))

    # Normalize unicode forms first, then fold everything to lowercase.
    bpe_tokenizer.normalizer = normalizers.Sequence(
        [normalizers.NFKC(), normalizers.Lowercase()]
    )

    # Split raw text on whitespace/punctuation before BPE merging.
    bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

    # Inverse mapping: reassemble sub-word pieces during decoding.
    bpe_tokenizer.decoder = decoders.BPEDecoder()

    print("Tokenizer (empty) initialized.")
    return bpe_tokenizer
78
+
79
+
80
def train_tokenizer() -> None:
    """
    Train the project BPE tokenizer on the IWSLT-15 en-vi train split
    and save it to TOKENIZER_SAVE_PATH.

    Uses module-level constants for data path, vocab size, special
    tokens, and batch size. Interruptible via Ctrl-C (the partially
    trained tokenizer is still saved afterwards).
    """
    # Initialize the BpeTrainer
    trainer = trainers.BpeTrainer(vocab_size=VOCAB_SIZE, special_tokens=SPECIAL_TOKENS)

    print("Tokenizer Trainer initialized.")

    # Only the (filtered) train split is needed for tokenizer training.
    train_dataset = get_raw_data(DATA_PATH, for_tokenizer=True)
    if not isinstance(train_dataset, Dataset):
        train_dataset = Dataset.from_list(train_dataset)
    print(f"Starting tokenizer training on {len(train_dataset)} pairs...")

    # 1. Define the iterator AND batch size
    text_iterator = get_training_corpus(
        train_dataset,
        batch_size=BATCH_SIZE_FOR_TOKENIZER,
    )

    # 2. Calculate total steps for the progress bar
    # (x2 because the corpus generator yields en and vi batches separately)
    total_steps = (len(train_dataset) // BATCH_SIZE_FOR_TOKENIZER) * 2
    if total_steps == 0:
        total_steps = 1  # (Avoid division by zero if dataset is tiny)

    tokenizer: Tokenizer = instantiate_tokenizer()
    # 3. Train with tqdm progress bar
    try:
        tokenizer.train_from_iterator(
            tqdm(
                text_iterator,
                total=total_steps,
                desc="Training Tokenizer (IWSLT-Local)",
            ),
            trainer=trainer,
            length=total_steps,
        )
    except KeyboardInterrupt:
        print("\nTokenizer training interrupted by user.")

    print("Tokenizer training complete.")

    # Persist whatever was learned (complete or interrupted run).
    tokenizer.save(str(TOKENIZER_SAVE_PATH))

    print(f"Tokenizer saved to: {TOKENIZER_SAVE_PATH}")
    print(f"Total vocabulary size: {tokenizer.get_vocab_size()}")
123
+
124
+
125
if __name__ == "__main__":
    # Train the BPE tokenizer and write it to TOKENIZER_SAVE_PATH.
    train_tokenizer()

    # Log the trained tokenizer file to Weights & Biases as a versioned artifact.
    run = wandb.init(
        entity="alaindelong-hcmut",
        project="Attention Is All You Build",
        job_type="tokenizer-train",
    )

    tokenizer_artifact = wandb.Artifact(
        name="iwslt_en-vi_tokenizer",
        type="tokenizer",
        description="BPE Tokenizer trained on IWSLT 15 (133k+ pairs en-vi)",
        metadata={
            # Keep metadata in sync with the training constants above
            # (was hard-coded 32000 before).
            "vocab_size": VOCAB_SIZE,
            "algorithm": "BPE",
            "framework": "huggingface",
            "training_data": "iwslt-15-en-vi-133k",
            # The tokenizer's normalizer pipeline applies Lowercase()
            # (see instantiate_tokenizer), so the vocab IS lower-cased;
            # the previous False here was incorrect.
            "lower_case": True,
        },
    )
    tokenizer_artifact.add_file(local_path=str(TOKENIZER_SAVE_PATH))
    run.log_artifact(tokenizer_artifact, aliases=["baseline"])

    run.finish()
src/utils.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import random
3
+ import re
4
+ from datetime import datetime
5
+ import numpy as np
6
+ from datasets import DatasetDict, Dataset, load_dataset
7
+ import torch
8
+ from torch import Tensor
9
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
10
+ from jaxtyping import Bool, Int
11
+
12
+ # from src import model
13
+ import model
14
+
15
+
16
# Reproducibility helper: pin every RNG the project touches.
def seed_everything(seed: int = 42) -> None:
    """
    Seed Python's `random`, NumPy, and PyTorch (CPU + all CUDA devices),
    and force deterministic cuDNN behavior for reproducible runs.

    Args:
        seed (int): The seed value to use.
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,      # no-op when CUDA is unavailable
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
30
+
31
+
32
def make_run_name(model_name: str, d_model: int) -> str:
    """Build a unique run name: ``<model>-<d_model>d-<YYYYmmdd_HHMMSS>``."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return "-".join([model_name, f"{d_model}d", stamp])
35
+
36
+
37
# --- Helper functions for cleaning ---
def is_valid_pair(example: dict) -> bool:
    """Return True when the example has non-blank 'en' AND 'vi' text."""
    pair = example.get("translation", {})
    # Whitespace-only strings count as empty.
    return bool(pair.get("en", "").strip() and pair.get("vi", "").strip())
44
+
45
+
46
def filter_empty(dataset: Dataset, num_proc: int) -> Dataset:
    """
    Remove pairs with a blank 'en' or 'vi' side from one split.

    Uses the highly optimized ``Dataset.filter()`` with ``num_proc``
    parallel workers, and logs how many rows were dropped.
    """
    print(f" Filtering empty strings from split...")
    before = len(dataset)

    # is_valid_pair keeps only rows where both languages are non-blank.
    kept: Dataset = dataset.filter(is_valid_pair, num_proc=num_proc)

    print(f" Filtered {before - len(kept)} empty/invalid pairs.")
    return kept
62
+
63
+
64
# --- Dataset Loading & Splitting ---
def get_raw_data(
    dataset_path: str | Path, for_tokenizer: bool = False, num_workers: int = 8
) -> Dataset | tuple[Dataset, Dataset, Dataset]:
    """
    Load all dataset splits from a path and drop empty sentence pairs.

    Args:
        dataset_path (str | Path): Path to the dataset directory or config.
        for_tokenizer (bool): If True, return only the cleaned train split
            (all that tokenizer training needs). Otherwise return the
            (train, validation, test) triple for model training/eval.
        num_workers (int): Number of workers for parallel filtering.

    Returns:
        Dataset: Filtered train split (if for_tokenizer=True).
        tuple(Dataset, Dataset, Dataset): Filtered train/validation/test
            splits (if for_tokenizer=False).
    """
    print(f"Loading datasets from: {dataset_path}")
    splits: DatasetDict = load_dataset(path=str(dataset_path))
    print(splits)

    print("--- Filtering Datasets (Removing empty sentences) ---")
    # Clean each split in the same train -> validation -> test order.
    cleaned: dict[str, Dataset] = {
        name: filter_empty(splits[name], num_workers)
        for name in ("train", "validation", "test")
    }

    if for_tokenizer:
        return cleaned["train"]
    return cleaned["train"], cleaned["validation"], cleaned["test"]
94
+
95
+
96
# Load the trained tokenizer file and register the project's special tokens.
def load_tokenizer(tokenizer_path: str | Path) -> PreTrainedTokenizerFast:
    """
    Load a trained tokenizer from file as a PreTrainedTokenizerFast.

    Args:
        tokenizer_path (str | Path): Path to the tokenizer JSON file.

    Returns:
        PreTrainedTokenizerFast: Tokenizer with [PAD]/[UNK]/[SOS]/[EOS]
        registered as its pad/unk/bos/eos special tokens.
    """
    print(f"Loading tokenizer from {tokenizer_path}...")
    # tokenizer = Tokenizer.from_file(str(tokenizer_path))
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
    # Map the project's special-token strings onto the HF attribute slots
    # so downstream code can use tokenizer.pad_token_id etc.
    tokenizer.pad_token = "[PAD]"
    tokenizer.unk_token = "[UNK]"
    tokenizer.bos_token = "[SOS]"  # bos = Beginning Of Sentence
    tokenizer.eos_token = "[EOS]"  # eos = End Of Sentence
    return tokenizer
115
+
116
+
117
def create_padding_mask(
    input_ids: Int[Tensor, "B T_k"], pad_token_id: int
) -> Bool[Tensor, "B 1 1 T_k"]:
    """
    Build a broadcastable key-padding mask for attention.

    Positions holding the pad token are marked False ("mask out");
    real tokens are True ("keep"). The two singleton dims broadcast
    over heads and query positions in a (B, H, T_q, T_k) score tensor.

    Args:
        input_ids (Tensor): The input token IDs. Shape (B, T_k).
        pad_token_id (int): The ID of the padding token.

    Returns:
        Tensor: A boolean mask of shape (B, 1, 1, T_k).
    """
    # (B, T_k): True wherever the token is NOT padding.
    keep: Tensor = input_ids.ne(pad_token_id)

    # Insert the head and T_q broadcast dims: (B, T_k) -> (B, 1, 1, T_k).
    return keep[:, None, None, :]
146
+
147
+
148
def create_look_ahead_mask(seq_len: int) -> Bool[Tensor, "1 1 T_q T_q"]:
    """
    Build the causal (look-ahead) mask for decoder self-attention.

    Position i may only attend to positions <= i: the lower triangle
    (including the diagonal) is True ("keep"), the upper triangle
    (future tokens) is False ("mask out").

    Args:
        seq_len (int): The sequence length (T_q).

    Returns:
        Tensor: A boolean mask of shape (1, 1, T_q, T_q).
    """
    # Lower-triangular boolean matrix (including the diagonal), e.g. T_q=3:
    # [[T, F, F],
    #  [T, T, F],
    #  [T, T, T]]
    causal: Tensor = torch.tril(torch.ones(seq_len, seq_len)) == 1

    # Add the batch and head broadcast dims: (T_q, T_q) -> (1, 1, T_q, T_q).
    return causal[None, None, :, :]
183
+
184
+
185
def greedy_decode_sentence(
    model: model.Transformer,
    src: Int[Tensor, "1 T_src"],  # Input: one sentence
    src_mask: Bool[Tensor, "1 1 1 T_src"],
    max_len: int,
    sos_token_id: int,
    eos_token_id: int,
    device: torch.device,
) -> Int[Tensor, "T_out"]:
    """
    Performs greedy decoding for a single sentence.
    This is an autoregressive process (token by token): the encoder
    runs once, then the decoder is re-run on the growing target
    sequence, always appending the highest-probability next token.

    Args:
        model: The trained Transformer model (already on device).
        src: The source token IDs (e.g., English).
        src_mask: The padding mask for the source.
        max_len: The maximum length to generate.
        sos_token_id: The ID for [SOS] token.
        eos_token_id: The ID for [EOS] token.
        device: The device to run on.

    Returns:
        Tensor: The generated target token IDs (e.g., Vietnamese),
        as a 1-D tensor of shape (T_out) — the batch dim is squeezed.
    """

    # Set model to eval mode (disables dropout)
    model.eval()

    # No gradients needed
    with torch.no_grad():

        # --- 1. Encode the source *once* ---
        # (B, T_src) -> (B, T_src, D)
        src_embedded = model.src_embed(src)
        src_with_pos = model.pos_enc(src_embedded)
        enc_output: Tensor = model.encoder(src_with_pos, src_mask)

        # --- 2. Initialize the Decoder input ---
        # Start with the [SOS] token. Shape: (1, 1)
        decoder_input: Tensor = torch.tensor(
            [[sos_token_id]], dtype=torch.long, device=device
        )  # Shape: (B=1, T_tgt=1)

        # --- 3. Autoregressive Loop ---
        for _ in range(max_len - 1):  # (Max length - 1, since we have [SOS])

            # --- a. Get Target Embedding + Position ---
            # (B, T_tgt) -> (B, T_tgt, D)
            tgt_embedded = model.tgt_embed(decoder_input)
            tgt_with_pos = model.pos_enc(tgt_embedded)

            # --- b. Create Target Mask (Causal) ---
            # We must re-create the mask every loop,
            # as T_tgt (decoder_input.size(1)) is growing.
            # Shape: (1, 1, T_tgt, T_tgt)
            T_tgt = decoder_input.size(1)
            tgt_mask = create_look_ahead_mask(T_tgt).to(device)

            # --- c. Run Decoder and Generator ---
            # (B, T_tgt, D)
            dec_output: Tensor = model.decoder(
                tgt_with_pos, enc_output, src_mask, tgt_mask
            )
            # (B, T_tgt, vocab_size)
            logits: Tensor = model.generator(dec_output)

            # --- d. Get the *last* token's logits ---
            # (B, T_tgt, vocab_size) -> (B, vocab_size)
            last_token_logits = logits[:, -1, :]

            # --- e. Greedy Search (get highest prob. token) ---
            # (B, vocab_size) -> (B, 1)
            next_token: Tensor = torch.argmax(last_token_logits, dim=-1).unsqueeze(-1)

            # --- f. Append the new token ---
            # (B, T_tgt) + (B, 1) -> (B, T_tgt + 1)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            # --- g. Check for [EOS] ---
            # If the *last* token we added is [EOS], stop generating.
            if next_token.item() == eos_token_id:
                break

    return decoder_input.squeeze(0)  # Return shape (T_out)
270
+
271
+
272
def filter_and_detokenize(token_list: list[str], skip_special: bool = True) -> str:
    """
    Join whitespace-level tokens into a sentence and tidy punctuation.

    Optionally drops the special tokens first, then applies two
    heuristic regex fixes to undo whitespace tokenization artifacts:
    removing the space before punctuation ("project ." -> "project.")
    and re-attaching contractions ("don 't" -> "don't").
    """
    if skip_special:
        # Drop the reserved vocabulary entries before joining.
        specials = {"[PAD]", "[UNK]", "[SOS]", "[EOS]"}
        token_list = [token for token in token_list if token not in specials]

    sentence = " ".join(token_list)

    # "word ." -> "word."  (space before punctuation)
    sentence = re.sub(r'\s([.,!?\'":;])', r"\1", sentence)
    # "don 't" -> "don't"  (split contractions)
    return re.sub(r"(\w)\s(\'\w)", r"\1\2", sentence)
293
+
294
+
295
# High-level, production-ready inference entry point:
# raw English string in, Vietnamese string out.
def translate(
    model: model.Transformer,
    tokenizer: PreTrainedTokenizerFast,
    sentence_en: str,
    device: torch.device,
    max_len: int,
    sos_token_id: int,
    eos_token_id: int,
    pad_token_id: int,
) -> str:
    """
    Translates a single English sentence to Vietnamese.

    Pipeline: tokenize -> build padding mask -> greedy autoregressive
    decode -> detokenize.

    Args:
        model: The trained Transformer model.
        tokenizer: The (PreTrainedTokenizerFast) tokenizer.
        sentence_en: The raw English input string.
        device: The device to run on.
        max_len: The max sequence length (from config).
        sos_token_id: The ID for [SOS].
        eos_token_id: The ID for [EOS].
        pad_token_id: The ID for [PAD].

    Returns:
        str: The translated Vietnamese string.
    """
    # Disable dropout and run without gradient tracking.
    model.eval()
    with torch.no_grad():

        # 1. Tokenize the English source; the encoder input carries
        #    no [SOS]/[EOS], and is truncated to the model's max length.
        encoding = tokenizer(
            sentence_en,
            truncation=True,
            max_length=max_len,
            add_special_tokens=False,
        )

        # 2. Wrap as a (1, T_src) long tensor on the target device.
        src_ids: Tensor = torch.tensor(
            [encoding["input_ids"]], dtype=torch.long
        ).to(device)

        # 3. Source padding mask, shape (1, 1, 1, T_src).
        src_mask: Tensor = create_padding_mask(src_ids, pad_token_id).to(device)

        # 4. Greedy autoregressive decode -> 1-D tensor of target IDs.
        output_ids: Tensor = greedy_decode_sentence(
            model,
            src_ids,
            src_mask,
            max_len=max_len,
            sos_token_id=sos_token_id,
            eos_token_id=eos_token_id,
            device=device,
        )

        # 5. IDs -> token strings (1-D CPU list keeps convert_ids_to_tokens
        #    returning List[str]), then join/clean while dropping specials.
        token_strings = tokenizer.convert_ids_to_tokens(output_ids.cpu().tolist())
        return filter_and_detokenize(token_strings, skip_special=True)
374
+
375
+ print("Inference function `translate()` defined.")