Commit
·
e27ab6a
1
Parent(s):
1a0fc46
Create translate app
Browse files- .gitignore +227 -0
- Dockerfile +9 -1
- requirements.txt +4 -1
- src/callbacks.py +24 -0
- src/config.py +70 -0
- src/dataset.py +280 -0
- src/embedding.py +105 -0
- src/engine.py +278 -0
- src/layers.py +186 -0
- src/model.py +207 -0
- src/modules.py +323 -0
- src/streamlit_app.py +176 -38
- src/tokenizer.py +156 -0
- src/utils.py +375 -0
.gitignore
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[codz]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py.cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# UV
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
#uv.lock
|
| 102 |
+
|
| 103 |
+
# poetry
|
| 104 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 105 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 106 |
+
# commonly ignored for libraries.
|
| 107 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 108 |
+
#poetry.lock
|
| 109 |
+
#poetry.toml
|
| 110 |
+
|
| 111 |
+
# pdm
|
| 112 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 113 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 114 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 115 |
+
#pdm.lock
|
| 116 |
+
#pdm.toml
|
| 117 |
+
.pdm-python
|
| 118 |
+
.pdm-build/
|
| 119 |
+
|
| 120 |
+
# pixi
|
| 121 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 122 |
+
#pixi.lock
|
| 123 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 124 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 125 |
+
.pixi
|
| 126 |
+
|
| 127 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 128 |
+
__pypackages__/
|
| 129 |
+
|
| 130 |
+
# Celery stuff
|
| 131 |
+
celerybeat-schedule
|
| 132 |
+
celerybeat.pid
|
| 133 |
+
|
| 134 |
+
# SageMath parsed files
|
| 135 |
+
*.sage.py
|
| 136 |
+
|
| 137 |
+
# Environments
|
| 138 |
+
.env
|
| 139 |
+
.envrc
|
| 140 |
+
.venv
|
| 141 |
+
env/
|
| 142 |
+
venv/
|
| 143 |
+
ENV/
|
| 144 |
+
env.bak/
|
| 145 |
+
venv.bak/
|
| 146 |
+
|
| 147 |
+
# Spyder project settings
|
| 148 |
+
.spyderproject
|
| 149 |
+
.spyproject
|
| 150 |
+
|
| 151 |
+
# Rope project settings
|
| 152 |
+
.ropeproject
|
| 153 |
+
|
| 154 |
+
# mkdocs documentation
|
| 155 |
+
/site
|
| 156 |
+
|
| 157 |
+
# mypy
|
| 158 |
+
.mypy_cache/
|
| 159 |
+
.dmypy.json
|
| 160 |
+
dmypy.json
|
| 161 |
+
|
| 162 |
+
# Pyre type checker
|
| 163 |
+
.pyre/
|
| 164 |
+
|
| 165 |
+
# pytype static type analyzer
|
| 166 |
+
.pytype/
|
| 167 |
+
|
| 168 |
+
# Cython debug symbols
|
| 169 |
+
cython_debug/
|
| 170 |
+
|
| 171 |
+
# PyCharm
|
| 172 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 173 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 174 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 175 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 176 |
+
#.idea/
|
| 177 |
+
|
| 178 |
+
# Abstra
|
| 179 |
+
# Abstra is an AI-powered process automation framework.
|
| 180 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 181 |
+
# Learn more at https://abstra.io/docs
|
| 182 |
+
.abstra/
|
| 183 |
+
|
| 184 |
+
# Visual Studio Code
|
| 185 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 186 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 187 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 188 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 189 |
+
# .vscode/
|
| 190 |
+
|
| 191 |
+
# Ruff stuff:
|
| 192 |
+
.ruff_cache/
|
| 193 |
+
|
| 194 |
+
# PyPI configuration file
|
| 195 |
+
.pypirc
|
| 196 |
+
|
| 197 |
+
# Cursor
|
| 198 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 199 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 200 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 201 |
+
.cursorignore
|
| 202 |
+
.cursorindexingignore
|
| 203 |
+
|
| 204 |
+
# Marimo
|
| 205 |
+
marimo/_static/
|
| 206 |
+
marimo/_lsp/
|
| 207 |
+
__marimo__/
|
| 208 |
+
|
| 209 |
+
# # Image
|
| 210 |
+
# images/
|
| 211 |
+
|
| 212 |
+
# Dataset
|
| 213 |
+
data/en-vi.txt/
|
| 214 |
+
data/IWSLT'15 en-vi/
|
| 215 |
+
notebooks/processed_data/
|
| 216 |
+
notebooks/IWSLT-15-en-vi/
|
| 217 |
+
|
| 218 |
+
# MLflow
|
| 219 |
+
mlruns/
|
| 220 |
+
|
| 221 |
+
# Temp Files
|
| 222 |
+
scratch/
|
| 223 |
+
notebooks/
|
| 224 |
+
test_push_to_hub.ipynb
|
| 225 |
+
|
| 226 |
+
# Weights & Biases
|
| 227 |
+
wandb/
|
Dockerfile
CHANGED
|
@@ -13,8 +13,16 @@ COPY src/ ./src/
|
|
| 13 |
|
| 14 |
RUN pip3 install -r requirements.txt
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
EXPOSE 8501
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
-
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
|
|
|
|
|
| 13 |
|
| 14 |
RUN pip3 install -r requirements.txt
|
| 15 |
|
| 16 |
+
RUN mkdir -p /app/hf_cache
|
| 17 |
+
|
| 18 |
+
ENV HF_HOME="/app/hf_cache"
|
| 19 |
+
|
| 20 |
+
RUN chmod -R 777 /app
|
| 21 |
+
|
| 22 |
+
|
| 23 |
EXPOSE 8501
|
| 24 |
|
| 25 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 26 |
|
| 27 |
+
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
| 28 |
+
# ENTRYPOINT ["streamlit", "run", "streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
requirements.txt
CHANGED
|
@@ -1,3 +1,6 @@
|
|
| 1 |
altair
|
| 2 |
pandas
|
| 3 |
-
streamlit
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
altair
|
| 2 |
pandas
|
| 3 |
+
streamlit
|
| 4 |
+
torch==2.6.0
|
| 5 |
+
transformers==4.52.4
|
| 6 |
+
jaxtyping
|
src/callbacks.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Tracks the best validation loss seen so far and counts consecutive
    epochs without a meaningful improvement; once that count reaches
    `patience`, the `should_stop` flag is raised for the training loop
    to check.
    """

    def __init__(self, patience=5, min_delta=1e-4, verbose=True):
        # Number of non-improving epochs tolerated before stopping.
        self.patience = patience
        # Minimum decrease in loss that counts as an improvement.
        self.min_delta = min_delta
        self.verbose = verbose
        self.best_loss = float("inf")
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss):
        """Record one epoch's validation loss and update the stop flag."""
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.should_stop = True
            if self.verbose:
                print(
                    f"[EarlyStopping] No improvement for {self.patience} epochs → stopping."
                )
|
src/config.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
import torch

# Path Configuration
# Fix: paths are built with `/` so they resolve on both Windows and POSIX.
# The previous raw strings (e.g. Path(r"data\IWSLT-15-en-vi")) embed a
# literal backslash on Linux, which breaks inside the Docker image this
# app is deployed in.
DATA_PATH = Path("data") / "IWSLT-15-en-vi"

# TOKENIZER_NAME = ""
# TOKENIZER_NAME = "iwslt_en-vi_tokenizer_16k.json"
TOKENIZER_NAME = "iwslt_en-vi_tokenizer_32k.json"
TOKENIZER_PATH = Path("artifacts") / "tokenizers" / TOKENIZER_NAME

MODEL_DIR = Path("artifacts") / "models"

# MODEL_NAME = ""
# MODEL_NAME = "transformer_en_vi_iwslt_1.pt"
MODEL_NAME = "transformer_en_vi_iwslt_1.safetensors"

# MODEL_SAVE_PATH = MODEL_DIR / MODEL_NAME
MODEL_SAVE_PATH = MODEL_DIR / "transformer_en_vi_iwslt_kaggle_1.safetensors"
# MODEL_SAVE_PATH = Path("notebooks") / "models" / MODEL_NAME

CHECKPOINT_PATH = Path("artifacts") / "checkpoints" / MODEL_NAME

CACHE_DIR = ""


# Hardware & Data Config
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_WORKERS: int = 4

VOCAB_SIZE: int = 32_000

# Order matters: token IDs below must match positions in this list.
SPECIAL_TOKENS: list[str] = ["[PAD]", "[UNK]", "[SOS]", "[EOS]"]

NUM_SAMPLES_TO_USE: int = 1000
# NUM_SAMPLES_TO_USE: int = 1_000_000


# Tokenizer Constants (must agree with SPECIAL_TOKENS above)
PAD_TOKEN_ID: int = 0
UNK_TOKEN_ID: int = 1
SOS_TOKEN_ID: int = 2
EOS_TOKEN_ID: int = 3


# Model Hyperparameters
# D_MODEL: int = 256 # (Dimension of model)
D_MODEL: int = 512
N_LAYERS: int = 6  # (N=6 in paper)
N_HEADS: int = 8  # (h=8 in paper)
# D_FF: int = 1024 # (d_ff = 4 * d_model = 1024)
D_FF: int = 2048
DROPOUT: float = 0.1  # (Dropout = 0.1 in paper)
MAX_SEQ_LEN: int = 150  # (Max length for Positional Encoding)


# Training Configuration
# LEARNING_RATE: float = 1e-4
LEARNING_RATE: float = 5e-4
BATCH_SIZE: int = 32
EPOCHS: int = 5
# EPOCHS: int = 50

# HuggingFace
REPO_ID: str = "AlainDeLong/transformer-en-vi-base"
FILENAME: str = "transformer_en_vi_iwslt_kaggle_1.safetensors"

if __name__ == "__main__":
    print(f"Using device: {DEVICE}")
|
src/dataset.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from datasets import Dataset as ArrowDataset
|
| 6 |
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 7 |
+
|
| 8 |
+
import config
|
| 9 |
+
from src import utils
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TranslationDataset(Dataset):
    """
    Lazily tokenizing map-style dataset for translation pairs.

    Nothing is pre-encoded: each item is tokenized on access in
    __getitem__, using the high-level PreTrainedTokenizerFast wrapper.
    SOS/EOS are added manually to the target only.
    """

    def __init__(
        self,
        dataset: ArrowDataset,
        tokenizer: PreTrainedTokenizerFast,
        max_len_src: int,
        max_len_tgt: int,
        src_lang: str = "en",
        tgt_lang: str = "vi",
    ):
        super().__init__()
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_len_src = max_len_src
        self.max_len_tgt = max_len_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> dict[str, list[int]]:
        pair = self.dataset[index]["translation"]

        # Tokenize with add_special_tokens=False so SOS/EOS placement
        # stays under our manual control.
        src_encoding = self.tokenizer(
            pair[self.src_lang],
            truncation=True,
            max_length=self.max_len_src,
            add_special_tokens=False,  # (source carries no SOS/EOS)
        )

        # Reserve two positions so SOS + EOS still fit after truncation.
        tgt_encoding = self.tokenizer(
            pair[self.tgt_lang],
            truncation=True,
            max_length=self.max_len_tgt - 2,
            add_special_tokens=False,
        )

        # Manually wrap the target with SOS/EOS.
        tgt_with_specials = [config.SOS_TOKEN_ID]
        tgt_with_specials.extend(tgt_encoding["input_ids"])
        tgt_with_specials.append(config.EOS_TOKEN_ID)

        return {
            "src_ids": src_encoding["input_ids"],
            "tgt_ids": tgt_with_specials,
        }
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class DataCollator:
    """
    Custom collate_fn for the translation DataLoaders.

    Given a list of per-item dicts from TranslationDataset (targets
    already carry SOS/EOS), this callable:
      1. builds the shifted decoder inputs and labels,
      2. pads every sequence dynamically to the batch maximum,
      3. builds the source/target attention masks,
      4. returns one dict of batched tensors.
    """

    def __init__(self, pad_token_id: int):
        self.pad_token_id = pad_token_id

    def _pad(self, sequences: list[list[int]]) -> Tensor:
        # Dynamic padding up to the longest sequence in the batch.
        # batch_first=True yields shape (B, T).
        return nn.utils.rnn.pad_sequence(
            [torch.tensor(seq) for seq in sequences],
            batch_first=True,
            padding_value=self.pad_token_id,
        )

    def __call__(self, batch: list[dict[str, list[int]]]) -> dict[str, Tensor]:
        # Raw ID lists; targets already include SOS/EOS.
        src_sequences = [sample["src_ids"] for sample in batch]
        tgt_sequences = [sample["tgt_ids"] for sample in batch]

        # Teacher-forcing shift:
        #   decoder input: [SOS, w1, w2, w3]
        #   labels:        [w1, w2, w3, EOS]
        src_ids_padded = self._pad(src_sequences)
        dec_input_ids_padded = self._pad([seq[:-1] for seq in tgt_sequences])
        # Labels padded with pad_token_id — the loss is expected to
        # ignore that ID.
        labels_padded = self._pad([seq[1:] for seq in tgt_sequences])

        tgt_len = dec_input_ids_padded.size(1)

        # Masks are built on CPU.
        # (B, 1, 1, T_src) — encoder self-attn & cross-attn padding mask.
        src_mask = utils.create_padding_mask(src_ids_padded, self.pad_token_id)
        # (B, 1, 1, T_tgt) — decoder padding mask.
        tgt_padding_mask = utils.create_padding_mask(
            dec_input_ids_padded, self.pad_token_id
        )
        # (1, 1, T_tgt, T_tgt) — causal (look-ahead) mask.
        look_ahead_mask = utils.create_look_ahead_mask(tgt_len)
        # (B, 1, T_tgt, T_tgt) — combined decoder mask.
        tgt_mask = tgt_padding_mask & look_ahead_mask

        return {
            "src_ids": src_ids_padded,  # (B, T_src)
            "tgt_input_ids": dec_input_ids_padded,  # (B, T_tgt)
            "labels": labels_padded,  # (B, T_tgt)
            "src_mask": src_mask,  # (B, 1, 1, T_src)
            "tgt_mask": tgt_mask,  # (B, 1, T_tgt, T_tgt)
        }
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_translation_datasets(
    tokenizer: PreTrainedTokenizerFast,
) -> tuple[TranslationDataset, TranslationDataset, TranslationDataset]:
    """
    Factory that wires the raw corpus into PyTorch datasets.

    Steps:
      1. Load and clean the raw splits (via src.utils), keeping the
         caller free of raw-data handling logic.
      2. Subsample the training split to config.NUM_SAMPLES_TO_USE.
      3. Wrap each split in a TranslationDataset.

    Args:
        tokenizer: The trained tokenizer.

    Returns:
        Tuple containing (train_ds, val_ds, test_ds)
    """
    # 1. Raw, cleaned splits.
    train_data, val_data, test_data = utils.get_raw_data(
        config.DATA_PATH, num_workers=config.NUM_WORKERS
    )
    train_data = train_data.select(range(config.NUM_SAMPLES_TO_USE))

    print("Building PyTorch Datasets...")

    # 2./3. All splits share the same tokenizer and length budget.
    shared_kwargs = dict(
        tokenizer=tokenizer,
        max_len_src=config.MAX_SEQ_LEN,
        max_len_tgt=config.MAX_SEQ_LEN,
    )
    train_ds = TranslationDataset(dataset=train_data, **shared_kwargs)
    val_ds = TranslationDataset(dataset=val_data, **shared_kwargs)
    test_ds = TranslationDataset(dataset=test_data, **shared_kwargs)

    print(
        f"Datasets created: Train={len(train_ds)}, Val={len(val_ds)}, Test={len(test_ds)}"
    )

    return train_ds, val_ds, test_ds
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def get_dataloaders(
    tokenizer: PreTrainedTokenizerFast,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """
    A high-level Factory function to create DataLoaders.

    This function abstracts away all the data pipeline complexity:
    - Loading/Cleaning raw data
    - Creating PyTorch Datasets
    - Instantiating the DataCollator (dynamic padding)
    - Creating DataLoaders with the correct batch size and workers

    Args:
        tokenizer: The trained tokenizer.

    Returns:
        Tuple containing (train_loader, val_loader, test_loader)
    """
    # 1. Create the Datasets.
    train_ds, val_ds, test_ds = get_translation_datasets(tokenizer)

    # 2. One collator shared by all loaders (needs PAD id for padding).
    collator = DataCollator(pad_token_id=config.PAD_TOKEN_ID)

    # Fix: compare the device *type* instead of the torch.device object
    # against the string "cuda" — the old `config.DEVICE == "cuda"` is
    # fragile across torch versions and misses devices like "cuda:1".
    # Also hoisted: previously computed once per loader.
    pin_memory = config.DEVICE.type == "cuda"

    print(
        f"Building DataLoaders (Batch Size: {config.BATCH_SIZE}, Workers: {config.NUM_WORKERS})..."
    )

    # 3. Train DataLoader — shuffling is critical for training.
    train_loader = DataLoader(
        train_ds,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        collate_fn=collator,
        pin_memory=pin_memory,
        prefetch_factor=2,
        persistent_workers=True,
    )

    # 4. Validation DataLoader — no shuffle (reproducible validation),
    #    larger batch since there is no backward pass.
    val_loader = DataLoader(
        val_ds,
        batch_size=2 * config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS,
        collate_fn=collator,
        pin_memory=pin_memory,
        prefetch_factor=2,
        persistent_workers=True,
    )

    # 5. Test DataLoader — evaluated once, so fewer workers and no
    #    persistent_workers (keeping workers alive buys nothing here).
    test_loader = DataLoader(
        test_ds,
        batch_size=2 * config.BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        collate_fn=collator,
        pin_memory=pin_memory,
        prefetch_factor=2,
    )

    print(f"DataLoader (train) created with {len(train_loader)} batches.")
    print(f"DataLoader (val) created with {len(val_loader)} batches.")
    print(f"DataLoader (test) created with {len(test_loader)} batches.")

    return train_loader, val_loader, test_loader
|
src/embedding.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from jaxtyping import Int, Float
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class InputEmbeddings(nn.Module):
    """
    Token-ID → embedding lookup with sqrt(d_model) scaling.

    Maps a batch of token IDs to their embedding vectors and multiplies
    them by sqrt(d_model), as prescribed in "Attention Is All You Need",
    Section 3.4.
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Initializes the InputEmbedding layer.

        Args:
            d_model (int): The dimension of the embedding vector (D).
            vocab_size (int): The size of the vocabulary.
        """
        super().__init__()

        self.d_model: int = d_model
        self.vocab_size: int = vocab_size

        # Learned lookup table: vocab_size rows of dimension d_model.
        self.token_emb: nn.Embedding = nn.Embedding(vocab_size, d_model)
        # Scale factor applied to every looked-up vector (Section 3.4).
        self._scale: float = math.sqrt(d_model)

    def forward(self, x: Int[Tensor, "B T"]) -> Float[Tensor, "B T D"]:
        """
        Look up and scale embeddings.

        Args:
            x (Tensor): Token IDs. Shape (B, T). B: batch_size, T: seq_len

        Returns:
            Tensor: Embedding vectors scaled by sqrt(d_model).
                Shape (B, T, D).
        """
        # (B, T) -> (B, T, D), then scale.
        return self.token_emb(x) * self._scale
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal Positional Encoding.
    (Ref: "Attention Is All You Need", Section 3.5)

    Pre-computes a (1, T_max, D) table of sin/cos values, adds the
    relevant slice to the input embeddings, and applies dropout to
    the sum.
    """

    def __init__(self, d_model: int, max_seq_len: int, dropout: float = 0.1) -> None:
        """
        Initializes the PositionalEncoding module.

        Args:
            d_model (int): The dimension of the model (D).
            max_seq_len (int): The maximum sequence length (T_max) to pre-compute.
            dropout (float): Dropout probability applied after the addition.
        """
        super().__init__()

        self.dropout: nn.Dropout = nn.Dropout(p=dropout)

        # Positions 0..T_max-1 as a column vector: (T_max, 1).
        positions: Tensor = torch.arange(max_seq_len).float().unsqueeze(1)

        # Frequency term 10000^(-2i/d_model) for even indices: (D/2,).
        frequencies: Tensor = torch.exp(
            -math.log(10000) / d_model * torch.arange(0, d_model, 2).float()
        )

        # Outer product of positions and frequencies: (T_max, D/2).
        angles = positions * frequencies

        # Interleave sin on even dims, cos on odd dims: (T_max, D).
        table: Tensor = torch.zeros(max_seq_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Add a batch axis so the table broadcasts over (B, T, D) input,
        # and register as a buffer (follows .to(device), not trained).
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x: Float[Tensor, "B T D"]) -> Float[Tensor, "B T D"]:
        """
        Adds positional encoding to the input embeddings and applies dropout.

        Args:
            x (Tensor): Input tensor (token embeddings, already scaled).
                Shape (B, T, D).

        Returns:
            Tensor: Output tensor with positional information and dropout.
                Shape (B, T, D).
        """
        seq_len = x.size(1)
        return self.dropout(x + self.pe[:, :seq_len, :])
|
src/engine.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 5 |
+
from torchmetrics.text import BLEUScore, SacreBLEUScore
|
| 6 |
+
from tqdm.auto import tqdm
|
| 7 |
+
import config
|
| 8 |
+
from src import model, utils
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Target vocabulary size; used to flatten logits for CrossEntropyLoss below.
TGT_VOCAB_SIZE: int = config.VOCAB_SIZE
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def train_one_epoch(
    model: model.Transformer,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    device: torch.device,
    logger=None,
) -> float:
    """
    Runs a single training epoch.

    Args:
        model: The Transformer model.
        dataloader: The training DataLoader.
        optimizer: The optimizer.
        criterion: The loss function (e.g., CrossEntropyLoss).
        scheduler: Per-step LR scheduler; ``scheduler.step()`` is called
            after every optimizer step.
        device: The device to run on (e.g., 'cuda').
        logger: Optional experiment logger exposing a ``.log(dict)`` method;
            batch metrics are logged every 100 batches when provided.

    Returns:
        The average training loss for the epoch.
    """
    # Training mode: enables dropout, etc.
    model.train()

    total_loss = 0.0

    # Use tqdm for a progress bar.
    progress_bar = tqdm(dataloader, desc="Training", leave=False)

    # start=1 so the first batch has index 1, matching the
    # "every 100 batches" logging condition below.
    for batch_idx, batch in enumerate(progress_bar, start=1):
        # 1. Move every tensor in the batch to the target device.
        batch_gpu = {
            k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)
        }

        # 2. Zero gradients before the forward pass.
        optimizer.zero_grad()

        # 3. Forward pass -> (B, T_tgt, vocab_size) logits.
        logits = model(
            src=batch_gpu["src_ids"],
            tgt=batch_gpu["tgt_input_ids"],
            src_mask=batch_gpu["src_mask"],
            tgt_mask=batch_gpu["tgt_mask"],
        )

        # 4. CrossEntropyLoss expects (N, C) logits and (N,) targets, so
        # flatten the batch and time dimensions together:
        # (B, T_tgt, C) -> (B * T_tgt, C) and (B, T_tgt) -> (B * T_tgt).
        loss = criterion(logits.view(-1, TGT_VOCAB_SIZE), batch_gpu["labels"].view(-1))

        # 5. Backward pass (compute gradients).
        loss.backward()

        # 6. Gradient clipping — guards against exploding gradients;
        # 1.0 is a common max-norm value.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # 7. Update weights.
        optimizer.step()

        # 8. Advance the per-step learning-rate schedule.
        scheduler.step()

        # 9. Update running stats.
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

        # 10. Periodically log batch metrics.
        if logger and batch_idx % 100 == 0:
            logger.log(
                {
                    "train/batch_loss": loss.item(),
                    "train/learning_rate": optimizer.param_groups[0]["lr"],
                }
            )

    # Average loss over all batches in the epoch.
    return total_loss / len(dataloader)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def validate_one_epoch(
    model: model.Transformer,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """
    Runs a single validation epoch.

    Args:
        model: The Transformer model.
        dataloader: The validation DataLoader.
        criterion: The loss function (e.g., CrossEntropyLoss).
        device: The device to run on (e.g., 'cuda').

    Returns:
        The average validation loss for the epoch.
    """
    # Evaluation mode: disables dropout.
    model.eval()

    running_loss = 0.0
    progress_bar = tqdm(dataloader, desc="Validating", leave=False)

    # No gradients needed: saves VRAM and speeds up inference.
    with torch.no_grad():
        for batch in progress_bar:
            # Move every tensor in the batch to the target device.
            batch_gpu = {
                key: tensor.to(device)
                for key, tensor in batch.items()
                if isinstance(tensor, torch.Tensor)
            }

            # Forward pass -> (B, T_tgt, vocab_size) logits.
            logits = model(
                src=batch_gpu["src_ids"],
                tgt=batch_gpu["tgt_input_ids"],
                src_mask=batch_gpu["src_mask"],
                tgt_mask=batch_gpu["tgt_mask"],
            )

            # Flatten (B, T_tgt, C) / (B, T_tgt) for CrossEntropyLoss,
            # matching the reshaping used during training.
            batch_loss = criterion(
                logits.view(-1, TGT_VOCAB_SIZE), batch_gpu["labels"].view(-1)
            )

            running_loss += batch_loss.item()
            progress_bar.set_postfix(loss=batch_loss.item())

    # Average loss over all batches in the epoch.
    return running_loss / len(dataloader)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def evaluate_model(
    model: model.Transformer,
    dataloader: DataLoader,
    tokenizer: PreTrainedTokenizerFast,
    device: torch.device,
    table=None,
) -> tuple[float, float]:
    """
    Runs final evaluation on the test set using greedy decoding and
    calculates corpus-level BLEU and SacreBLEU scores.

    Args:
        model: The trained Transformer model.
        dataloader: The evaluation DataLoader.
        tokenizer: Tokenizer used to map token ids back to token strings.
        device: The device to run on (e.g., 'cuda').
        table: Optional logging table (e.g., a wandb.Table) with an
            ``add_data(expected, predicted)`` method; skipped when None.

    Returns:
        (BLEU, SacreBLEU) scores, both scaled to 0-100.
    """
    print("\n--- Starting Evaluation (BLEU + SacreBLEU) ---")

    # Evaluation mode: disables dropout.
    model.eval()

    all_predicted_strings = []
    all_expected_strings = []

    # --- No gradients needed ---
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch_gpu = {
                k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)
            }

            src_ids = batch_gpu["src_ids"]
            src_mask = batch_gpu["src_mask"]
            expected_ids = batch_gpu["labels"]  # (B, T_tgt) [on GPU]

            B = src_ids.size(0)

            # --- Detokenize the 2D expected IDs ---
            batch_expected_strings = []

            # Convert 2D GPU tensor -> 2D CPU list so the tokenizer
            # call is safe.
            for id_list in expected_ids.cpu().tolist():
                token_list = tokenizer.convert_ids_to_tokens(id_list)
                batch_expected_strings.append(
                    utils.filter_and_detokenize(token_list, skip_special=True)
                )

            # --- Generate (greedy-decode) one sentence at a time ---
            batch_predicted_strings = []
            for i in tqdm(range(B), desc="Decoding Batch", leave=False):
                src_sentence = src_ids[i].unsqueeze(0)
                src_sentence_mask = src_mask[i].unsqueeze(0)

                # predicted_ids is a 1D tensor (T_out,) on the device.
                predicted_ids = utils.greedy_decode_sentence(
                    model,
                    src_sentence,
                    src_sentence_mask,
                    max_len=config.MAX_SEQ_LEN,
                    sos_token_id=config.SOS_TOKEN_ID,
                    eos_token_id=config.EOS_TOKEN_ID,
                    device=device,
                )

                # Convert 1D GPU tensor -> 1D CPU list before detokenizing.
                predicted_token_list = tokenizer.convert_ids_to_tokens(
                    predicted_ids.cpu().tolist()
                )
                batch_predicted_strings.append(
                    utils.filter_and_detokenize(predicted_token_list, skip_special=True)
                )

            # --- Store strings for the final metric calculation ---
            all_predicted_strings.extend(batch_predicted_strings)
            # torchmetrics expects a list of references per prediction.
            all_expected_strings.extend([[s] for s in batch_expected_strings])

    # Build the metrics on the same device the evaluation ran on
    # (was config.DEVICE, which could disagree with the `device` argument).
    bleu_metric = BLEUScore(n_gram=4, smooth=True).to(device)
    sacrebleu_metric = SacreBLEUScore(
        n_gram=4, smooth=True, tokenize="intl", lowercase=False
    ).to(device)

    # --- Calculate final scores ---
    print("\nCalculating final BLEU score...")
    final_bleu = bleu_metric(all_predicted_strings, all_expected_strings)

    print("\nCalculating final SacreBLEU score...")
    final_sacrebleu = sacrebleu_metric(all_predicted_strings, all_expected_strings)

    # --- Show some examples ---
    print("\n--- Translation Examples (Pred vs Exp) ---")
    for i in range(min(5, len(all_predicted_strings))):
        print(f" PRED: {all_predicted_strings[i]}")
        print(f" EXP: {all_expected_strings[i][0]}")
        print(" ---")

        # Bug fix: `table` defaults to None; calling add_data on it
        # unconditionally raised AttributeError.
        if table is not None:
            table.add_data(all_expected_strings[i][0], all_predicted_strings[i])

    return final_bleu.item() * 100, final_sacrebleu.item() * 100
|
src/layers.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import Tensor
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from jaxtyping import Bool, Float
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Scaled Dot-Product Attention.

    Terminology (jaxtyping):
        B: batch_size
        T_q: target sequence length (query)
        T_k: source sequence length (key/value)
        D: d_model (model dimension)
        H: n_heads (number of heads)
        d_k: dimension of each head (d_model / n_heads)
    """

    def __init__(self, d_model: int, n_heads: int) -> None:
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model: int = d_model
        self.n_heads: int = n_heads
        self.d_k: int = d_model // n_heads

        # Bias-free projections for queries, keys, values, and output.
        self.w_q: nn.Linear = nn.Linear(d_model, d_model, bias=False)
        self.w_k: nn.Linear = nn.Linear(d_model, d_model, bias=False)
        self.w_v: nn.Linear = nn.Linear(d_model, d_model, bias=False)
        self.w_o: nn.Linear = nn.Linear(d_model, d_model, bias=False)

        # Attention weights from the most recent forward pass, kept for
        # inspection/visualization.
        self.attention_weights: Tensor | None = None

    @staticmethod
    def attention(
        query: Float[Tensor, "B H T_q d_k"],
        key: Float[Tensor, "B H T_k d_k"],
        value: Float[Tensor, "B H T_k d_k"],
        mask: Bool[Tensor, "... 1 T_q T_k"] | None,
    ) -> tuple[Float[Tensor, "B H T_q d_k"], Float[Tensor, "B H T_q T_k"]]:
        """
        Scaled Dot-Product Attention (stateless helper, easy to test).
        (Ref: "Attention Is All You Need", Equation 1)

        Args:
            query (Tensor): Query tensor.
            key (Tensor): Key tensor.
            value (Tensor): Value tensor.
            mask (Tensor | None): Optional mask (padding or look-ahead);
                positions where the mask is 0 are excluded.

        Returns:
            tuple[Tensor, Tensor]:
                - context_vector: The output of the attention mechanism.
                - attention_weights: The softmax-normalized attention weights.
        """
        head_dim: int = query.shape[-1]

        # (B, H, T_q, d_k) @ (B, H, d_k, T_k) -> (B, H, T_q, T_k)
        scores: Tensor = query @ key.transpose(-2, -1)
        scores = scores / math.sqrt(head_dim)

        if mask is not None:
            # -inf before softmax drives masked positions to zero weight.
            scores = scores.masked_fill(mask == 0, value=float("-inf"))

        weights: Tensor = scores.softmax(dim=-1)

        # (B, H, T_q, T_k) @ (B, H, T_k, d_k) -> (B, H, T_q, d_k)
        return weights @ value, weights

    def forward(
        self,
        q: Float[Tensor, "B T_q D"],
        k: Float[Tensor, "B T_k D"],
        v: Float[Tensor, "B T_k D"],
        mask: Bool[Tensor, "... 1 T_q T_k"] | None = None,  # Optional mask
    ) -> Float[Tensor, "B T_q D"]:
        """
        Forward pass for Multi-Head Attention.

        In Self-Attention (Encoder), q, k, and v are all the same tensor.
        In Cross-Attention (Decoder), q comes from the Decoder, while k and v
        come from the Encoder's output.

        Args:
            q: Query tensor.
            k: Key tensor.
            v: Value tensor.
            mask: Optional mask to apply (padding or look-ahead).

        Returns:
            The context vector after multi-head attention and output projection.
        """
        batch, len_q, _ = q.shape
        len_k = k.shape[1]  # == length of v

        # Project, then split D into H heads:
        # (B, T, D) -> (B, T, H, d_k) -> (B, H, T, d_k)
        heads_q = self.w_q(q).view(batch, len_q, self.n_heads, self.d_k).transpose(1, 2)
        heads_k = self.w_k(k).view(batch, len_k, self.n_heads, self.d_k).transpose(1, 2)
        heads_v = self.w_v(v).view(batch, len_k, self.n_heads, self.d_k).transpose(1, 2)

        context, self.attention_weights = self.attention(
            heads_q, heads_k, heads_v, mask
        )

        # Merge heads back:
        # (B, H, T_q, d_k) -> (B, T_q, H, d_k) -> (B, T_q, D)
        context = context.transpose(1, 2).contiguous().view(batch, len_q, self.d_model)

        # Final output projection: (B, T_q, D) -> (B, T_q, D)
        return self.w_o(context)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network (FFN) sublayer.
    (Ref: "Attention Is All You Need", Section 3.3)

    A two-layer MLP applied independently to each position:
        FFN(x) = ReLU(x @ W_1 + b_1) @ W_2 + b_2

    Terminology (jaxtyping):
        B: batch_size
        T: seq_len (context_length)
        D: d_model (model dimension)
        D_FF: d_ff (inner feed-forward dimension)
    """

    def __init__(self, d_model: int, d_ff: int) -> None:
        """
        Initializes the FFN.

        Args:
            d_model (int): Dimension of the model (e.g., 512).
            d_ff (int): Inner dimension of the FFN (e.g., 2048).
                The paper suggests d_ff = 4 * d_model.
        """
        super().__init__()

        # Expansion: (B, T, D) -> (B, T, D_FF)
        self.linear_1: nn.Linear = nn.Linear(d_model, d_ff)
        # Non-linearity applied element-wise in the expanded space.
        self.activation: nn.ReLU = nn.ReLU()
        # Contraction back: (B, T, D_FF) -> (B, T, D)
        self.linear_2: nn.Linear = nn.Linear(d_ff, d_model)

    def forward(self, x: Float[Tensor, "B T D"]) -> Float[Tensor, "B T D"]:
        """
        Applies two linear transformations with a ReLU in between.

        Args:
            x: Input tensor from the previous sublayer
                (e.g., MultiHeadAttention output).

        Returns:
            Output tensor of the same shape.
        """
        # (B, T, D) -> (B, T, D_FF) -> (B, T, D_FF) -> (B, T, D)
        return self.linear_2(self.activation(self.linear_1(x)))
|
src/model.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from safetensors.torch import load_model
|
| 5 |
+
from jaxtyping import Bool, Int, Float
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
from embedding import InputEmbeddings, PositionalEncoding
|
| 8 |
+
from modules import Encoder, Decoder
|
| 9 |
+
import config
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Generator(nn.Module):
    """
    Final Linear (Projection) layer producing vocabulary logits.

    Takes the Decoder stack output (B, T, D) and projects it onto the
    vocabulary space (B, T, vocab_size).

    (This layer's weights can be tied with the target embedding layer,
    which is handled in the main 'Transformer' model class.)
    """

    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        Initializes the Generator (Output Projection) layer.

        Args:
            d_model (int): The dimension of the model (D).
            vocab_size (int): The size of the target vocabulary.
        """
        super().__init__()
        # (D -> vocab_size), no bias term.
        self.proj: nn.Linear = nn.Linear(d_model, vocab_size, bias=False)

    def forward(
        self, x: Float[Tensor, "B T_tgt D"]
    ) -> Float[Tensor, "B T_tgt vocab_size"]:
        """
        Projects decoder states onto the vocabulary.

        Args:
            x (Tensor): The final output tensor from the Decoder stack.

        Returns:
            Tensor: The output logits over the vocabulary.
        """
        # (B, T_tgt, D) -> (B, T_tgt, vocab_size)
        return self.proj(x)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class Transformer(nn.Module):
    """
    The main Transformer model architecture, combining the Encoder
    and Decoder stacks, as described in "Attention Is All You Need".

    This implementation follows modern best practices (Pre-LN) and
    is designed for a sequence-to-sequence task (e.g., translation).
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int,
        n_heads: int,
        n_layers: int,  # N=6 in the paper
        d_ff: int,
        dropout: float = 0.1,
        max_seq_len: int = 512,  # Max length for positional encoding
    ) -> None:
        """
        Initializes the full Transformer model.

        Args:
            src_vocab_size (int): Vocabulary size for the source language.
            tgt_vocab_size (int): Vocabulary size for the target language.
            d_model (int): The dimension of the model (D).
            n_heads (int): The number of attention heads (H).
            n_layers (int): The number of Encoder/Decoder layers (N).
            d_ff (int): The inner dimension of the Feed-Forward Network (D_FF).
            dropout (float): The dropout rate.
            max_seq_len (int): The maximum sequence length for positional encoding.
        """
        super().__init__()

        self.d_model = d_model

        # Separate token embeddings for source and target vocabularies.
        self.src_embed: InputEmbeddings = InputEmbeddings(d_model, src_vocab_size)
        self.tgt_embed: InputEmbeddings = InputEmbeddings(d_model, tgt_vocab_size)

        # A single PositionalEncoding module shared by source and target.
        self.pos_enc: PositionalEncoding = PositionalEncoding(
            d_model, max_seq_len, dropout
        )

        # Encoder and Decoder stacks (N layers each).
        self.encoder: Encoder = Encoder(d_model, n_heads, d_ff, n_layers, dropout)
        self.decoder: Decoder = Decoder(d_model, n_heads, d_ff, n_layers, dropout)

        # Final projection from d_model to target-vocabulary logits.
        self.generator: Generator = Generator(d_model, tgt_vocab_size)

        # Weight tying: the output projection shares its weight matrix with
        # the target embedding. This saves parameters and improves performance.
        self.generator.proj.weight = self.tgt_embed.token_emb.weight

        # Initialize weights; crucial for stable training.
        # NOTE(review): apply() visits the tied weight twice (once as a
        # Linear, once as an Embedding), so traversal order decides which
        # init the shared matrix ends up with — confirm this is intended.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """
        Applies Xavier/Glorot uniform initialization to linear layers and
        scaled normal initialization to embeddings.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Embedding):
            # Normal init with std = 1/sqrt(d_model).
            nn.init.normal_(module.weight, mean=0, std=self.d_model**-0.5)

    def forward(
        self,
        src: Int[Tensor, "B T_src"],  # Source token IDs (e.g., English)
        tgt: Int[Tensor, "B T_tgt"],  # Target token IDs (e.g., Vietnamese)
        src_mask: Bool[Tensor, "B 1 1 T_src"],  # Source padding mask
        tgt_mask: Bool[Tensor, "B 1 T_tgt T_tgt"],  # Target combined mask
    ) -> Float[Tensor, "B T_tgt vocab_size"]:
        """
        Main forward pass: encode the source, decode the target against the
        encoder memory, and project to vocabulary logits.

        Args:
            src (Tensor): Source sequence token IDs.
            tgt (Tensor): Target sequence token IDs (shifted right).
            src_mask (Tensor): Padding mask for the source sequence.
            tgt_mask (Tensor): Combined padding and look-ahead mask
                for the target sequence.

        Returns:
            Tensor: The output logits from the model (B, T_tgt, vocab_size).
        """
        # 1. Encode: (B, T_src) -> (B, T_src, D). The result is the
        # "memory" consumed by every DecoderLayer's cross-attention.
        memory: Tensor = self.encoder(self.pos_enc(self.src_embed(src)), src_mask)

        # 2. Decode: (B, T_tgt) -> (B, T_tgt, D).
        decoded: Tensor = self.decoder(
            self.pos_enc(self.tgt_embed(tgt)), memory, src_mask, tgt_mask
        )

        # 3. Project: (B, T_tgt, D) -> (B, T_tgt, vocab_size).
        return self.generator(decoded)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def load_trained_model(
    config_obj, checkpoint_path, device: torch.device
) -> Transformer:
    """
    Downloads pretrained weights from the Hugging Face Hub and loads them
    into a freshly instantiated Transformer.

    Args:
        config_obj: Config object/module providing the model hyperparameters
            (VOCAB_SIZE, D_MODEL, N_HEADS, N_LAYERS, D_FF, DROPOUT,
            MAX_SEQ_LEN).
        checkpoint_path: Local checkpoint path. Currently unused — weights
            are always fetched from the Hub; kept for backward compatibility
            with existing callers.
        device: Device to place the model on.

    Returns:
        Transformer: The model with trained weights loaded, on ``device``.
    """
    print("Downloading safetensors from Hub...")
    # NOTE(review): repo/filename come from the module-level `config`, not
    # from `config_obj` — confirm both refer to the same configuration.
    model_path = hf_hub_download(repo_id=config.REPO_ID, filename=config.FILENAME)

    print("Instantiating the Transformer model...")
    model = Transformer(
        src_vocab_size=config_obj.VOCAB_SIZE,
        tgt_vocab_size=config_obj.VOCAB_SIZE,
        d_model=config_obj.D_MODEL,
        n_heads=config_obj.N_HEADS,
        n_layers=config_obj.N_LAYERS,
        d_ff=config_obj.D_FF,
        dropout=config_obj.DROPOUT,
        max_seq_len=config_obj.MAX_SEQ_LEN,
    ).to(device)

    print(f"Loading model from: {model_path}")
    load_model(model, filename=model_path)

    print(f"Successfully loaded trained weights from {model_path}")
    return model
|
src/modules.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import Tensor
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Callable
|
| 4 |
+
from jaxtyping import Bool, Float
|
| 5 |
+
from layers import MultiHeadAttention, PositionwiseFeedForward
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ResidualConnection(nn.Module):
    """
    Pre-LN residual wrapper: x + Dropout(Sublayer(LayerNorm(x))).

    Wraps a sublayer (e.g. attention or FFN) with layer normalization
    applied *before* the sublayer — the GPT-2 style "Pre-LN" variant,
    which trains more stably than the Post-LN layout of the original
    "Attention Is All You Need" paper.
    """

    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
        """
        Initializes the Residual Connection.

        Args:
            d_model (int): Model (embedding) dimension D.
            dropout (float): Dropout probability applied to the sublayer output.
        """
        super().__init__()
        self.dropout: nn.Dropout = nn.Dropout(dropout)
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Float[Tensor, "B T D"],
        sublayer: Callable[[Float[Tensor, "B T D"]], Float[Tensor, "B T D"]],
    ) -> Float[Tensor, "B T D"]:
        """
        Apply the residual connection around *sublayer*.

        Args:
            x (Tensor): Input tensor from the previous layer, shape (B, T, D).
            sublayer (Callable): Sublayer (e.g. MHA or FFN) applied to the
                normalized input.

        Returns:
            Tensor: x + Dropout(sublayer(LayerNorm(x))), shape (B, T, D).
        """
        # Normalize first (Pre-LN), run the sublayer, then dropout and add.
        return x + self.dropout(sublayer(self.norm(x)))
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class EncoderLayer(nn.Module):
    """
    One Transformer encoder block.

    Two sublayers — multi-head self-attention and a position-wise FFN —
    each wrapped in a Pre-LN ResidualConnection (LayerNorm + Dropout).

    Data flow:
        x -> residual_1(x, self-attention) -> residual_2(., FFN) -> out
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1
    ) -> None:
        """
        Initializes the Encoder Layer.

        Args:
            d_model (int): Model dimension D.
            n_heads (int): Number of attention heads H.
            d_ff (int): Inner dimension of the feed-forward network.
            dropout (float): Dropout rate used inside the residual connections.
        """
        super().__init__()
        self.self_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward: PositionwiseFeedForward = PositionwiseFeedForward(
            d_model, d_ff
        )
        self.residual_1: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_2: ResidualConnection = ResidualConnection(d_model, dropout)

    def forward(
        self, x: Float[Tensor, "B T D"], src_mask: Bool[Tensor, "B 1 1 T_k"]
    ) -> Float[Tensor, "B T D"]:
        """
        Forward pass for the Encoder Layer.

        Args:
            x (Tensor): (B, T, D) activations from the previous layer/embedding.
            src_mask (Tensor): (B, 1, 1, T_k) padding mask; the singleton dims
                broadcast over heads and query positions to (B, H, T_q, T_k).

        Returns:
            Tensor: (B, T, D) encoder-block output.
        """

        def _self_attend(h: Float[Tensor, "B T D"]) -> Float[Tensor, "B T D"]:
            # Self-attention: queries, keys and values all come from `h`.
            return self.self_attn(q=h, k=h, v=h, mask=src_mask)

        x = self.residual_1(x, _self_attend)
        return self.residual_2(x, self.feed_forward)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class Encoder(nn.Module):
    """
    Full Transformer Encoder: a stack of N identical EncoderLayers.

    Consumes token embeddings + positional encodings and produces the
    "memory" tensor the Decoder cross-attends to. Because each layer is
    Pre-LN, one final LayerNorm is applied at the end of the stack.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, n_layers: int, dropout: float = 0.1
    ) -> None:
        """
        Initializes the Encoder stack.

        Args:
            d_model (int): Model dimension D.
            n_heads (int): Number of attention heads H.
            d_ff (int): Inner dimension of the feed-forward network.
            n_layers (int): Number of stacked EncoderLayer blocks N.
            dropout (float): Dropout rate for the residual connections.
        """
        super().__init__()
        self.layers: nn.ModuleList = nn.ModuleList(
            EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        )
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self, x: Float[Tensor, "B T D"], src_mask: Bool[Tensor, "B 1 1 T"]
    ) -> Float[Tensor, "B T D"]:
        """
        Forward pass for the entire Encoder stack.

        Args:
            x (Tensor): (B, T, D) token embeddings + positional encodings.
            src_mask (Tensor): Padding mask for the source sentence.

        Returns:
            Tensor: Final encoder output (the "context"/"memory" tensor).
        """
        for block in self.layers:
            x = block(x, src_mask)
        # Final Pre-LN normalization before handing the memory to the decoder.
        return self.norm(x)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class DecoderLayer(nn.Module):
    """
    One Transformer decoder block.

    Three sublayers, each wrapped in a Pre-LN ResidualConnection:
        1. Masked multi-head self-attention over the target sequence.
        2. Multi-head cross-attention over the encoder output.
        3. Position-wise feed-forward network.

    Data flow:
        x -> residual_1(x, masked self-attn)
          -> residual_2(., cross-attn with enc_output)
          -> residual_3(., FFN) -> out
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1
    ) -> None:
        """
        Initializes the Decoder Layer.

        Args:
            d_model (int): Model dimension D.
            n_heads (int): Number of attention heads H.
            d_ff (int): Inner dimension of the feed-forward network.
            dropout (float): Dropout rate for the residual connections.
        """
        super().__init__()
        self.self_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.cross_attn: MultiHeadAttention = MultiHeadAttention(d_model, n_heads)
        self.feed_forward: PositionwiseFeedForward = PositionwiseFeedForward(
            d_model, d_ff
        )
        self.residual_1: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_2: ResidualConnection = ResidualConnection(d_model, dropout)
        self.residual_3: ResidualConnection = ResidualConnection(d_model, dropout)

    def forward(
        self,
        x: Float[Tensor, "B T_tgt D"],
        enc_output: Float[Tensor, "B T_src D"],
        src_mask: Bool[Tensor, "B 1 1 T_src"],
        tgt_mask: Bool[Tensor, "B 1 1 T_tgt"],
    ) -> Float[Tensor, "B T_tgt D"]:
        """
        Forward pass for the Decoder Layer.

        Args:
            x (Tensor): Input from the previous decoder layer.
            enc_output (Tensor): Encoder output (supplies K, V for cross-attn).
            src_mask (Tensor): Padding mask for the source (encoder) input.
            tgt_mask (Tensor): Combined look-ahead + padding mask for the target.

        Returns:
            Tensor: (B, T_tgt, D) decoder-block output.
        """

        def _self_attend(h: Float[Tensor, "B T_tgt D"]) -> Float[Tensor, "B T_tgt D"]:
            # Causal self-attention over the (normalized) target sequence.
            return self.self_attn(q=h, k=h, v=h, mask=tgt_mask)

        def _cross_attend(h: Float[Tensor, "B T_tgt D"]) -> Float[Tensor, "B T_tgt D"]:
            # Queries from the decoder, keys/values from the encoder memory.
            return self.cross_attn(q=h, k=enc_output, v=enc_output, mask=src_mask)

        x = self.residual_1(x, _self_attend)
        x = self.residual_2(x, _cross_attend)
        return self.residual_3(x, self.feed_forward)
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
class Decoder(nn.Module):
    """
    Full Transformer Decoder: a stack of N identical DecoderLayers.

    Consumes target embeddings + positional encodings, attends to the
    encoder memory, and returns activations ready for the final output
    projection (Generator). Pre-LN style: one final LayerNorm at the end.
    """

    def __init__(
        self, d_model: int, n_heads: int, d_ff: int, n_layers: int, dropout: float = 0.1
    ) -> None:
        """
        Initializes the Decoder stack.

        Args:
            d_model (int): Model dimension D.
            n_heads (int): Number of attention heads H.
            d_ff (int): Inner dimension of the feed-forward network.
            n_layers (int): Number of stacked DecoderLayer blocks N.
            dropout (float): Dropout rate for the residual connections.
        """
        super().__init__()
        self.layers: nn.ModuleList = nn.ModuleList(
            DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        )
        self.norm: nn.LayerNorm = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Float[Tensor, "B T_tgt D"],
        enc_output: Float[Tensor, "B T_src D"],
        src_mask: Bool[Tensor, "B 1 1 T_src"],
        tgt_mask: Bool[Tensor, "1 1 T_tgt T_tgt"],
    ) -> Float[Tensor, "B T_tgt D"]:
        """
        Forward pass for the entire Decoder stack.

        Args:
            x (Tensor): Target embeddings + positional encodings.
            enc_output (Tensor): Encoder memory (K, V for cross-attention).
            src_mask (Tensor): Padding mask for the source sequence.
            tgt_mask (Tensor): Combined mask for the target sequence.

        Returns:
            Tensor: Final decoder output, ready for the Generator projection.
        """
        for block in self.layers:
            x = block(x, enc_output, src_mask, tgt_mask)
        # Final Pre-LN normalization before the output projection.
        return self.norm(x)
|
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,178 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
#
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
import time
|
| 3 |
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
import config
|
| 6 |
+
import model
|
| 7 |
+
import utils
|
| 8 |
|
| 9 |
+
|
| 10 |
+
# ==========================================
|
| 11 |
+
# 1. ASSUMPTIONS
|
| 12 |
+
# ==========================================
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@st.cache_resource
def load_artifacts():
    """
    Download and load the tokenizer and the trained Transformer model.

    Cached by Streamlit (`st.cache_resource`) so the artifacts are fetched
    and deserialized only once per server process, not on every UI rerun.

    Returns:
        tuple: (transformer_model, tokenizer). Either element is ``None``
        when its artifact could not be downloaded/loaded; callers must
        check for ``None`` and fall back to demo output.
    """
    # Optional results: stay None if loading fails (the UI handles that).
    tokenizer: PreTrainedTokenizerFast | None = None
    transformer_model: "model.Transformer | None" = None

    try:
        # Fetch the trained BPE tokenizer file from the Hugging Face Hub.
        tok_path = hf_hub_download(
            repo_id=config.REPO_ID, filename="iwslt_en-vi_tokenizer_32k.json"
        )
        tokenizer = utils.load_tokenizer(tok_path)

        print("Loading model for inference...")
        transformer_model = model.load_trained_model(
            config, config.MODEL_SAVE_PATH, config.DEVICE
        )

    except Exception as e:
        # Keep the app alive on failure: return None instead of crashing.
        # (The previous message claimed a randomly initialized model was
        # used, which was inaccurate — nothing is substituted here.)
        print(f"Warning: could not load model/tokenizer; returning None. Error: {e}")
        print(" (The app will fall back to demo output)")

    return transformer_model, tokenizer
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ==========================================
# 2. UI CONFIGURATION
# ==========================================
st.set_page_config(
    page_title="En-Vi Translator | AttentionIsAllYouBuild",
    page_icon="🤖",
    layout="centered",
    # layout="wide",
)

# Inject custom CSS so the result card and buttons look polished.
st.markdown(
    """
    <style>
    .main {
        background-color: #f5f5f5;
    }
    .stTextArea textarea {
        font-size: 16px;
    }
    .stButton button {
        width: 100%;
        background-color: #FF4B4B;
        color: white;
        font-weight: bold;
        padding: 10px;
    }
    .result-box {
        background-color: #ffffff;
        padding: 20px;
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0,0,0,0.1);
        border-left: 5px solid #FF4B4B;
    }
    .source-text {
        color: #666;
        font-style: italic;
        font-size: 14px;
        margin-bottom: 5px;
    }
    .translated-text {
        color: #333;
        font-size: 20px;
        font-weight: 600;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

# ==========================================
# 3. MAIN APP LAYOUT
# ==========================================

# Header
st.title("🤖 AI Translator: English → Vietnamese")
st.markdown("### Project: *Attention Is All You Build*")
st.markdown("---")

# Sidebar (additional model information)
with st.sidebar:
    st.header("ℹ️ Thông tin Model")
    st.info(
        """
        Đây là mô hình **Transformer (Encoder-Decoder)** được xây dựng "from scratch" bằng PyTorch.

        - **Kiến trúc**: Pre-LN Transformer
        - **Tokenizer**: BPE (32k vocab)
        - **Inference**: Greedy
        """
    )
    st.write("Created by [Your Name]")

# Input Area
input_text = st.text_area(
    label="Nhập câu tiếng Anh:",
    placeholder="Example: Artificial intelligence is transforming the world...",
    height=150,
)

# ==========================================
# 4. INFERENCE LOGIC
# ==========================================

# Translate button
if st.button("Dịch sang Tiếng Việt (Translate)"):
    if not input_text.strip():
        st.warning("⚠️ Vui lòng nhập nội dung cần dịch!")
    else:
        # Display spinner while the model is running
        with st.spinner("Wait a second... AI is thinking 🧠"):
            try:
                # Measure inference time
                start_time = time.time()

                # --- Call translate function ---
                # Cached by st.cache_resource, so this is cheap after the
                # first click.
                transformer_model, tokenizer = load_artifacts()

                # `utils` is an imported module (always truthy); the real
                # guard is the two artifact checks.
                if utils and transformer_model and tokenizer:
                    translation = utils.translate(
                        transformer_model,
                        tokenizer,
                        sentence_en=input_text,
                        device=config.DEVICE,
                        max_len=config.MAX_SEQ_LEN,
                        sos_token_id=config.SOS_TOKEN_ID,
                        eos_token_id=config.EOS_TOKEN_ID,
                        pad_token_id=config.PAD_TOKEN_ID,
                    )

                else:
                    # Mockup output when the model/tokenizer failed to load
                    time.sleep(1)  # Simulate latency
                    translation = "[DEMO OUTPUT] Hệ thống chưa load model thực tế. Đây là kết quả mẫu."

                end_time = time.time()
                inference_time = end_time - start_time

                # --- Display Result ---
                st.success(f"✅ Hoàn tất trong {inference_time:.2f}s")

                st.markdown("### Kết quả:")
                st.markdown(
                    f"""
                    <div class="result-box">
                        <div class="source-text">Original: {input_text}</div>
                        <div class="translated-text">{translation}</div>
                    </div>
                    """,
                    unsafe_allow_html=True,
                )

            except Exception as e:
                st.error(f"❌ Đã xảy ra lỗi trong quá trình dịch: {str(e)}")

# Footer
st.markdown("---")
st.caption("Powered by PyTorch & Streamlit")
|
src/tokenizer.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from datasets import Dataset
|
| 3 |
+
from tokenizers import (
|
| 4 |
+
Tokenizer,
|
| 5 |
+
models,
|
| 6 |
+
normalizers,
|
| 7 |
+
pre_tokenizers,
|
| 8 |
+
decoders,
|
| 9 |
+
trainers,
|
| 10 |
+
)
|
| 11 |
+
from tqdm.auto import tqdm
|
| 12 |
+
import wandb
|
| 13 |
+
from utils import get_raw_data
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Local path to the IWSLT'15 English-Vietnamese parallel corpus.
# NOTE(review): Windows-style relative path — adjust when running on POSIX.
DATA_PATH = Path(r"..\data\IWSLT-15-en-vi")
# TOKENIZER_NAME = "iwslt_en-vi_tokenizer_16k.json"
TOKENIZER_NAME = "iwslt_en-vi_tokenizer_32k.json"
# Where the trained tokenizer JSON is serialized.
TOKENIZER_SAVE_PATH = Path(r"..\artifacts\tokenizers") / TOKENIZER_NAME

# Target BPE vocabulary size (16k variant kept for reference).
# VOCAB_SIZE: int = 16_000
VOCAB_SIZE: int = 32_000
# Special tokens reserved at the start of the vocabulary.
SPECIAL_TOKENS: list[str] = ["[PAD]", "[UNK]", "[SOS]", "[EOS]"]

# Sentence pairs streamed to the tokenizer trainer per batch.
BATCH_SIZE_FOR_TOKENIZER: int = 10000
# Worker processes used when filtering the raw dataset.
NUM_WORKERS: int = 8
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_training_corpus(dataset: Dataset, batch_size: int = 1000):
    """
    Yield batches of raw text for tokenizer training.

    Uses ``dataset.iter(batch_size=...)`` — the highly optimized, zero-copy
    Arrow iterator — and unpacks the nested ``translation`` dicts into plain
    string lists. For each input batch two lists are yielded: first the
    English sentences, then the Vietnamese ones.
    """
    # Each `batch` looks like: {"translation": [list of `batch_size` dicts]}
    for batch in dataset.iter(batch_size=batch_size):
        pairs = batch["translation"]
        # Pull the per-language strings out of the list of {en, vi} dicts.
        yield [pair["en"] for pair in pairs]
        yield [pair["vi"] for pair in pairs]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def instantiate_tokenizer() -> Tokenizer:
    """
    Build an untrained BPE tokenizer.

    Pipeline: NFKC + lowercase normalization, whitespace pre-tokenization,
    and a matching BPE decoder for round-tripping tokens back to text.
    """
    # Empty BPE model; unknown pieces map to [UNK].
    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))

    # Normalizer: Unicode NFKC first, then lowercasing.
    tokenizer.normalizer = normalizers.Sequence(
        [normalizers.NFKC(), normalizers.Lowercase()]
    )

    # Pre-tokenizer splits text into "words" (space/punctuation);
    # BPE then learns sub-word merges from those.
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

    # Decoder reconstructs the original string from BPE tokens.
    tokenizer.decoder = decoders.BPEDecoder()

    print("Tokenizer (empty) initialized.")
    return tokenizer
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def train_tokenizer():
    """
    Train the BPE tokenizer on the filtered IWSLT train split and save it.

    Pipeline:
        1. Build a ``BpeTrainer`` with the configured vocab/special tokens.
        2. Load and clean the raw train split.
        3. Stream the corpus through ``train_from_iterator`` with a tqdm bar.
        4. Serialize the trained tokenizer to ``TOKENIZER_SAVE_PATH``.
    """
    # Initialize the BpeTrainer
    trainer = trainers.BpeTrainer(vocab_size=VOCAB_SIZE, special_tokens=SPECIAL_TOKENS)

    print("Tokenizer Trainer initialized.")

    train_dataset = get_raw_data(DATA_PATH, for_tokenizer=True)
    # Defensive: wrap plain lists into a `Dataset` so `.iter()` is available.
    if not isinstance(train_dataset, Dataset):
        train_dataset = Dataset.from_list(train_dataset)
    print(f"Starting tokenizer training on {len(train_dataset)} pairs...")

    # 1. Define the iterator AND batch size
    text_iterator = get_training_corpus(
        train_dataset,
        batch_size=BATCH_SIZE_FOR_TOKENIZER,
    )

    # 2. Total yields for the progress bar: two per batch (en list, vi list).
    total_steps = (len(train_dataset) // BATCH_SIZE_FOR_TOKENIZER) * 2
    if total_steps == 0:
        total_steps = 1  # (Keep tqdm's total positive for tiny datasets)

    tokenizer: Tokenizer = instantiate_tokenizer()
    # 3. Train with tqdm progress bar; a Ctrl-C aborts training but still
    # falls through to save whatever state the tokenizer holds.
    try:
        tokenizer.train_from_iterator(
            tqdm(
                text_iterator,
                total=total_steps,
                desc="Training Tokenizer (IWSLT-Local)",
            ),
            trainer=trainer,
            length=total_steps,
        )
    except KeyboardInterrupt:
        print("\nTokenizer training interrupted by user.")

    print("Tokenizer training complete.")

    tokenizer.save(str(TOKENIZER_SAVE_PATH))

    print(f"Tokenizer saved to: {TOKENIZER_SAVE_PATH}")
    print(f"Total vocabulary size: {tokenizer.get_vocab_size()}")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
    # Train the BPE tokenizer and persist it to TOKENIZER_SAVE_PATH.
    train_tokenizer()

    # Log the tokenizer to Weights & Biases so training runs can consume it.
    run = wandb.init(
        entity="alaindelong-hcmut",
        project="Attention Is All You Build",
        job_type="tokenizer-train",
    )

    # Log tokenizer artifact. Metadata is derived from the actual training
    # configuration instead of being hard-coded, so it cannot drift.
    tokenizer_artifact = wandb.Artifact(
        name="iwslt_en-vi_tokenizer",
        type="tokenizer",
        description="BPE Tokenizer trained on IWSLT 15 (133k+ pairs en-vi)",
        metadata={
            "vocab_size": VOCAB_SIZE,  # was hard-coded 32000
            "algorithm": "BPE",
            "framework": "huggingface",
            "training_data": "iwslt-15-en-vi-133k",
            # The normalizer applies `normalizers.Lowercase()` (see
            # instantiate_tokenizer), so the vocabulary IS lower-cased.
            # The previous value (False) was incorrect.
            "lower_case": True,
        },
    )
    tokenizer_artifact.add_file(local_path=str(TOKENIZER_SAVE_PATH))
    run.log_artifact(tokenizer_artifact, aliases=["baseline"])

    run.finish()
|
src/utils.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import random
|
| 3 |
+
import re
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
import numpy as np
|
| 6 |
+
from datasets import DatasetDict, Dataset, load_dataset
|
| 7 |
+
import torch
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
|
| 10 |
+
from jaxtyping import Bool, Int
|
| 11 |
+
|
| 12 |
+
# from src import model
|
| 13 |
+
import model
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Utility function to set random seed for reproducibility
|
| 17 |
+
def seed_everything(seed: int = 42) -> None:
    """
    Seed every RNG in use (Python, NumPy, PyTorch CPU/CUDA) and force cuDNN
    into deterministic mode so runs are reproducible.

    Args:
        seed (int): Seed value shared by all generators.
    """
    # Seed each library's global generator with the same value.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA generators: current device plus all visible devices
    # (both calls are no-ops on CPU-only machines).
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for bitwise-reproducible kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def make_run_name(model_name: str, d_model: int) -> str:
    """
    Compose a unique experiment-run name: "<model>-<d_model>d-<timestamp>".

    Args:
        model_name (str): Human-readable model identifier.
        d_model (int): Model dimension, embedded for quick scanning.

    Returns:
        str: e.g. "transformer-512d-20240101_120000".
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return "-".join((model_name, f"{d_model}d", stamp))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# --- Helper functions for cleaning ---
|
| 38 |
+
def is_valid_pair(example: dict) -> bool:
    """Return True iff the example has non-blank 'en' AND 'vi' sides."""
    pair = example.get("translation", {})
    # Whitespace-only sentences count as empty.
    has_en = bool(pair.get("en", "").strip())
    has_vi = bool(pair.get("vi", "").strip())
    return has_en and has_vi
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def filter_empty(dataset: Dataset, num_proc: int) -> Dataset:
    """
    Drop sentence pairs with an empty/blank side from a dataset split.

    Uses ``Dataset.filter`` (the optimized parallel path in `datasets`)
    with ``num_proc`` worker processes and reports how many rows were
    removed.

    Args:
        dataset (Dataset): One split of the parallel corpus.
        num_proc (int): Number of worker processes for the filter.

    Returns:
        Dataset: The split containing only valid (non-empty) pairs.
    """
    print(f" Filtering empty strings from split...")
    before = len(dataset)

    # Keep only rows where both language sides are non-blank.
    kept = dataset.filter(is_valid_pair, num_proc=num_proc)

    print(f" Filtered {before - len(kept)} empty/invalid pairs.")
    return kept
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# --- Dataset Loading & Splitting ---
def get_raw_data(
    dataset_path: str | Path, for_tokenizer: bool = False, num_workers: int = 8
) -> Dataset | tuple[Dataset, Dataset, Dataset]:
    """
    Load dataset splits from a given path and drop empty sentence pairs.

    Args:
        dataset_path (str | Path): Path to the dataset directory or config.
        for_tokenizer (bool): If True, return only the filtered train split
            (sufficient for tokenizer training). If False, return the tuple
            of (train, validation, test) splits for model training/eval.
        num_workers (int): Number of workers for parallel filtering.

    Returns:
        Dataset: Filtered train split (if for_tokenizer=True).
        tuple[Dataset, Dataset, Dataset]: Filtered train, validation, test
            splits (if for_tokenizer=False).
    """
    print(f"Loading datasets from: {dataset_path}")
    all_splits: DatasetDict = load_dataset(path=str(dataset_path))
    print(all_splits)

    print("--- Filtering Datasets (Removing empty sentences) ---")
    # Clean every split with the same empty-pair filter.
    cleaned = tuple(
        filter_empty(all_splits[name], num_workers)
        for name in ("train", "validation", "test")
    )

    if for_tokenizer:
        return cleaned[0]
    return cleaned
def load_tokenizer(tokenizer_path: str | Path) -> PreTrainedTokenizerFast:
    """
    Load a trained tokenizer from a JSON file and register its special tokens.

    Args:
        tokenizer_path (str | Path): Path to the tokenizer JSON file.

    Returns:
        PreTrainedTokenizerFast: Tokenizer with [PAD], [UNK], [SOS] (bos)
        and [EOS] (eos) special tokens configured.
    """
    # Fixed: removed a copy-pasted "set random seed" comment, dead
    # commented-out Tokenizer.from_file call, and a docstring that
    # documented a nonexistent `special_tokens` parameter and a second
    # return value that was never returned.
    print(f"Loading tokenizer from {tokenizer_path}...")
    tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
    # Map the project's special tokens onto HF's canonical attribute names.
    tokenizer.pad_token = "[PAD]"
    tokenizer.unk_token = "[UNK]"
    tokenizer.bos_token = "[SOS]"  # bos = Beginning Of Sentence
    tokenizer.eos_token = "[EOS]"  # eos = End Of Sentence
    return tokenizer
def create_padding_mask(
    input_ids: Int[Tensor, "B T_k"], pad_token_id: int
) -> Bool[Tensor, "B 1 1 T_k"]:
    """
    Build the padding mask used by the attention mechanism.

    Marks every position holding the <PAD> token so that, once broadcast,
    those positions are masked out of the attention score matrix
    (B, H, T_q, T_k).

    Args:
        input_ids (Tensor): The input token IDs. Shape (B, T_k).
        pad_token_id (int): The ID of the padding token.

    Returns:
        Tensor: A boolean mask of shape (B, 1, 1, T_k).
            'True' means "keep" (real token);
            'False' means "mask out" (padding token).
    """
    # True wherever the token is NOT padding. Shape: (B, T_k).
    keep: Tensor = input_ids.ne(pad_token_id)

    # Insert singleton dims for the T_q and head axes so the mask
    # broadcasts over attention scores: (B, T_k) -> (B, 1, 1, T_k).
    return keep[:, None, None, :]
def create_look_ahead_mask(seq_len: int) -> Bool[Tensor, "1 1 T_q T_q"]:
    """
    Create a causal (look-ahead) mask for the Decoder's self-attention.

    Prevents positions from attending to subsequent positions: the strict
    upper triangle (future) is False and the lower triangle including the
    diagonal (past/present) is True.

    Args:
        seq_len (int): The sequence length (T_q).

    Returns:
        Tensor: A boolean mask of shape (1, 1, T_q, T_q).
            'True' means "keep" (allowed to see);
            'False' means "mask out" (future token).
    """
    # Fixed: docstring previously documented a `device` argument that the
    # function does not accept; callers move the result with .to(device).

    # Lower-triangular matrix of ones (diagonal included).
    # Example (T_q=3):
    # [[1., 0., 0.],
    #  [1., 1., 0.],
    #  [1., 1., 1.]]
    lower_triangular: Tensor = torch.tril(torch.ones(seq_len, seq_len))

    # Convert to boolean and add broadcasting dimensions:
    # (T_q, T_q) -> (1, 1, T_q, T_q).
    return (lower_triangular == 1).unsqueeze(0).unsqueeze(0)
def greedy_decode_sentence(
    model: model.Transformer,
    src: Int[Tensor, "1 T_src"],  # Input: one sentence
    src_mask: Bool[Tensor, "1 1 1 T_src"],
    max_len: int,
    sos_token_id: int,
    eos_token_id: int,
    device: torch.device,
) -> Int[Tensor, "1 T_out"]:
    """
    Greedily decode a single sentence, one token at a time.

    The source is encoded exactly once; the decoder then runs
    autoregressively, always appending the highest-probability next token,
    until [EOS] is produced or ``max_len`` tokens exist.

    Args:
        model: The trained Transformer model (already on device).
        src: The source token IDs (e.g., English).
        src_mask: The padding mask for the source.
        max_len: The maximum length to generate.
        sos_token_id: The ID for [SOS] token.
        eos_token_id: The ID for [EOS] token.
        device: The device to run on.

    Returns:
        Tensor: The generated target token IDs (e.g., Vietnamese),
            shape (T_out,).
    """
    # Eval mode disables dropout for deterministic decoding.
    model.eval()

    with torch.no_grad():
        # --- Encode the source once: (1, T_src) -> (1, T_src, D) ---
        memory: Tensor = model.encoder(
            model.pos_enc(model.src_embed(src)), src_mask
        )

        # --- Decoder starts from a lone [SOS]. Shape: (1, 1) ---
        generated: Tensor = torch.tensor(
            [[sos_token_id]], dtype=torch.long, device=device
        )

        # --- Autoregressive loop (max_len - 1 steps; [SOS] already counts) ---
        for _ in range(max_len - 1):
            # Embed + position-encode everything generated so far.
            tgt_inputs = model.pos_enc(model.tgt_embed(generated))

            # The causal mask must be rebuilt each step because the
            # target length grows. Shape: (1, 1, T_tgt, T_tgt).
            tgt_mask = create_look_ahead_mask(generated.size(1)).to(device)

            # Decode and project to vocabulary logits:
            # (1, T_tgt, D) -> (1, T_tgt, vocab_size).
            dec_out: Tensor = model.decoder(tgt_inputs, memory, src_mask, tgt_mask)
            logits: Tensor = model.generator(dec_out)

            # Greedy pick from the *last* position's logits:
            # (1, vocab_size) -> (1, 1).
            next_id: Tensor = logits[:, -1, :].argmax(dim=-1, keepdim=True)

            # Grow the sequence: (1, T_tgt) + (1, 1) -> (1, T_tgt + 1).
            generated = torch.cat([generated, next_id], dim=1)

            # Stop as soon as [EOS] is emitted.
            if next_id.item() == eos_token_id:
                break

    return generated.squeeze(0)  # Return shape (T_out,)
def filter_and_detokenize(token_list: list[str], skip_special: bool = True) -> str:
    """
    Join whitespace-split tokens back into a sentence and heuristically fix
    the spacing issues that a naive space-join introduces around punctuation.
    """
    if skip_special:
        # Drop the model's special marker tokens before joining.
        specials = {"[PAD]", "[UNK]", "[SOS]", "[EOS]"}
        token_list = [tok for tok in token_list if tok not in specials]

    # Naive whitespace join...
    text = " ".join(token_list)

    # ...then heuristic cleanup:
    # "project ." -> "project."   (no space before punctuation)
    text = re.sub(r'\s([.,!?\'":;])', r"\1", text)
    # "don 't" -> "don't"         (re-attach contractions)
    text = re.sub(r"(\w)\s(\'\w)", r"\1\2", text)

    return text
# Define a high-level, production-ready
# inference function that handles all steps.
def translate(
    model: model.Transformer,
    tokenizer: PreTrainedTokenizerFast,
    sentence_en: str,
    device: torch.device,
    max_len: int,
    sos_token_id: int,
    eos_token_id: int,
    pad_token_id: int,
) -> str:
    """
    Translates a single English sentence to Vietnamese.

    Pipeline: tokenize -> tensorize -> build source padding mask ->
    greedy decode -> detokenize.

    Args:
        model: The trained Transformer model.
        tokenizer: The (PreTrainedTokenizerFast) tokenizer.
        sentence_en: The raw English input string.
        device: The device to run on.
        max_len: The max sequence length (from config).
        sos_token_id: The ID for [SOS].
        eos_token_id: The ID for [EOS].
        pad_token_id: The ID for [PAD].

    Returns:
        str: The translated Vietnamese string.
    """
    # Inference mode: disable dropout.
    model.eval()

    with torch.no_grad():
        # Tokenize the English source; the encoder needs no SOS/EOS.
        encoded = tokenizer(
            sentence_en,
            truncation=True,
            max_length=max_len,
            add_special_tokens=False,
        )

        # Batch of one: shape (1, T_src), moved to the target device.
        src_ids: Tensor = torch.tensor(
            [encoded["input_ids"]], dtype=torch.long
        ).to(device)

        # Source padding mask: shape (1, 1, 1, T_src).
        src_mask: Tensor = create_padding_mask(src_ids, pad_token_id).to(device)

        # Autoregressive greedy decoding -> 1-D tensor of target IDs,
        # shape (T_out,).
        predicted_ids: Tensor = greedy_decode_sentence(
            model,
            src_ids,
            src_mask,
            max_len=max_len,
            sos_token_id=sos_token_id,
            eos_token_id=eos_token_id,
            device=device,
        )

        # Move IDs to CPU, map them back to token strings, then join and
        # clean (drops special tokens, fixes punctuation spacing).
        id_list = predicted_ids.cpu().tolist()
        token_list = tokenizer.convert_ids_to_tokens(id_list)
        return filter_and_detokenize(token_list, skip_special=True)


print("Inference function `translate()` defined.")