Upload folder using huggingface_hub
- .gitignore +1 -0
- CODE_OF_CONDUCT.md +128 -0
- LICENSE +21 -0
- README.md +8 -3
- images/google.png +0 -0
- images/inference.png +0 -0
- pyproject.toml +5 -0
- src/__pycache__/config.cpython-39.pyc +0 -0
- src/__pycache__/dataset.cpython-39.pyc +0 -0
- src/__pycache__/model.cpython-39.pyc +0 -0
- src/__pycache__/train.cpython-39.pyc +0 -0
- src/config.py +24 -0
- src/dataset.py +111 -0
- src/inference.py +36 -0
- src/model.py +433 -0
- src/train.py +294 -0
- tokenizer_en.json +0 -0
- tokenizer_it.json +0 -0
- uv.lock +0 -0
- weights/tmodel_19.pt +3 -0

.gitignore
ADDED
@@ -0,0 +1 @@
.venv

CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,128 @@
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
shubh622005@gmail.com.
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.0, available at
https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.

Community Impact Guidelines were inspired by [Mozilla's code of conduct
enforcement ladder](https://github.com/mozilla/diversity).

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see the FAQ at
https://www.contributor-covenant.org/faq. Translations are available at
https://www.contributor-covenant.org/translations.

LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Shubham

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md
CHANGED
@@ -1,3 +1,8 @@
Implementation of the transformer architecture based on [Attention is All You Need](https://arxiv.org/abs/1706.03762)

### Inference outputs

<div align="center">
<img src="/images/inference.png" alt="Inference" width="400" height="300" style="margin: 0 10px;">
<img src="/images/google.png" alt="Google" width="400" height="300" style="margin: 0 10px;">
</div>

images/google.png
ADDED

images/inference.png
ADDED

pyproject.toml
ADDED
@@ -0,0 +1,5 @@
[project]
name = "transformers"
version = "0.1.0"
requires-python = ">=3.9"
dependencies = ["torch==2.8.0", "datasets==4.0.0", "tokenizers==0.22.0"]

src/__pycache__/config.cpython-39.pyc
ADDED
Binary file (789 Bytes).

src/__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (2.72 kB).

src/__pycache__/model.cpython-39.pyc
ADDED
Binary file (15.9 kB).

src/__pycache__/train.cpython-39.pyc
ADDED
Binary file (6.4 kB).

src/config.py
ADDED
@@ -0,0 +1,24 @@
from pathlib import Path


def get_config():
    return {
        "batch_size": 8,
        "num_epochs": 20,
        "lr": 10**-4,
        "seq_len": 350,
        "d_model": 512,
        "lang_src": "en",
        "lang_target": "it",
        "model_folder": "weights",
        "model_basename": "tmodel_",
        "preload": None,
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
    }


def get_weights_file_path(config, epoch: str):
    model_folder = config["model_folder"]
    model_filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path(".") / model_folder / model_filename)

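A quick sketch of how these two helpers compose (not part of the commit; assumes the defaults above):

    from config import get_config, get_weights_file_path

    config = get_config()
    # model_folder="weights" and model_basename="tmodel_" mean that epoch "19"
    # resolves to the checkpoint shipped in this commit:
    print(get_weights_file_path(config, "19"))  # weights/tmodel_19.pt
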
src/dataset.py
ADDED
@@ -0,0 +1,111 @@
import torch
from torch.utils.data import Dataset


class BilingualDataset(Dataset):
    def __init__(
        self, dataset, tokenizer_src, tokenizer_target, src_lang, target_lang, seq_len
    ):
        """
        Initializes a new instance of this Dataset for one language pair of
        https://huggingface.co/datasets/Helsinki-NLP/opus_books
        """
        super().__init__()
        self.seq_len = seq_len
        self.src_lang = src_lang
        self.tokenizer_target = tokenizer_target
        self.tokenizer_src = tokenizer_src
        self.target_lang = target_lang
        self.dataset = dataset

        self.start_of_sentence_token = torch.tensor(
            [tokenizer_target.token_to_id("[SOS]")], dtype=torch.int64
        )
        self.end_of_sentence_token = torch.tensor(
            [tokenizer_target.token_to_id("[EOS]")], dtype=torch.int64
        )
        self.padding_token = torch.tensor(
            [tokenizer_target.token_to_id("[PAD]")], dtype=torch.int64
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """
        Takes the text of a sentence pair from the dataset, tokenizes it with
        tokenizer_src and tokenizer_target respectively, and constructs the tensors passed to the transformer
        """
        src_target_pair = self.dataset[index]
        src_text = src_target_pair["translation"][self.src_lang]
        target_text = src_target_pair["translation"][self.target_lang]

        encoder_input_tokens = self.tokenizer_src.encode(src_text).ids
        decoder_input_tokens = self.tokenizer_target.encode(target_text).ids

        enc_num_padding_tokens = self.seq_len - len(encoder_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(decoder_input_tokens) - 1

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        encoder_input = torch.cat(
            [
                self.start_of_sentence_token,
                torch.tensor(encoder_input_tokens, dtype=torch.int64),
                self.end_of_sentence_token,
                torch.tensor(
                    [self.padding_token] * enc_num_padding_tokens, dtype=torch.int64
                ),
            ],
            dim=0,
        )

        decoder_input = torch.cat(
            [
                self.start_of_sentence_token,
                torch.tensor(decoder_input_tokens, dtype=torch.int64),
                torch.tensor(
                    [self.padding_token] * dec_num_padding_tokens, dtype=torch.int64
                ),
            ],
            dim=0,
        )

        label = torch.cat(
            [
                torch.tensor(decoder_input_tokens, dtype=torch.int64),
                self.end_of_sentence_token,
                torch.tensor(
                    [self.padding_token] * dec_num_padding_tokens, dtype=torch.int64
                ),
            ],
            dim=0,
        )

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": (encoder_input != self.padding_token)
            .unsqueeze(0)
            .unsqueeze(0)
            .int(),  # (1, 1, seq_len) adding the sequence dimension and batch dimension
            "decoder_mask": (decoder_input != self.padding_token).unsqueeze(0).int()
            & causal_mask(
                decoder_input.size(0)
            ),  # (1, seq_len) & (1, seq_len, seq_len)
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": target_text,
        }


def causal_mask(size):
    # torch.triu keeps everything above the diagonal, so we invert it with
    # mask == 0 in the return, since we need everything on and below the diagonal
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0

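For intuition, a small check of what causal_mask produces (a sketch, not part of the commit):

    from dataset import causal_mask

    print(causal_mask(3))
    # tensor([[[ True, False, False],
    #          [ True,  True, False],
    #          [ True,  True,  True]]])
    # Position i may only attend to positions <= i; this is what decoder_mask
    # combines with the padding mask above.
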
src/inference.py
ADDED
@@ -0,0 +1,36 @@
import torch
from train import get_model, greedy_decode, get_or_build_tokenizer
from config import get_config

INPUT_TEXT = "sun rises in the night"


def inference():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    config = get_config()

    tokenizer_src = get_or_build_tokenizer(config, None, config["lang_src"])
    tokenizer_tgt = get_or_build_tokenizer(config, None, config["lang_target"])

    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

    model_filename = "weights/tmodel_19.pt"
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    model.eval()

    tokens = tokenizer_src.encode(INPUT_TEXT).ids
    tokens = [tokenizer_src.token_to_id("[SOS]")] + tokens + [tokenizer_src.token_to_id("[EOS]")]
    encoder_input = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(device)
    encoder_mask = (encoder_input != tokenizer_src.token_to_id("[PAD]")).unsqueeze(0).unsqueeze(0).to(device)

    model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, config["seq_len"], device)
    output_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

    print("Source:", INPUT_TEXT)
    print("Predicted:", output_text)


if __name__ == "__main__":
    inference()

src/model.py
ADDED
@@ -0,0 +1,433 @@
# In PyTorch, the forward function of each nn.Module is called automatically when the module is called, so we do not need to invoke it explicitly.

import torch
import torch.nn as nn
import math


class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        """
        vocab_size: number of words in the vocabulary
        d_model: dimension of the model
        1. Creates an embedding of size d_model for each word in the vocab
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embeddings = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """
        x: (batch_size, seq_len)
        return: (batch_size, seq_len, d_model)
        Convert the input words to their corresponding embeddings
        """
        # multiplying by sqrt(self.d_model) to scale the embeddings
        return self.embeddings(x) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        """
        seq_len: maximum length of the input sentence
        d_model: dimension of the model
        dropout: dropout rate
        1. Create a matrix of shape (seq_len, d_model) with all values set to 0
        2. Create a position vector of shape (seq_len, 1) with values from 0 to seq_len-1
        3. Create a denominator vector of shape (d_model/2) with values from 0 to d_model/2-1
           and apply the formula: exp(-log(10000) * (2i/d_model))
        4. Apply the sine function to the even indices of the positional encoding matrix
           and the cosine function to the odd indices
        5. Add a batch dimension to the positional encoding matrix and register it as a buffer
        """
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        # dropout prevents overfitting of the model, randomly zeroes some values
        self.dropout = nn.Dropout(dropout)

        positional_encoding = torch.zeros(seq_len, d_model)  # (seq_len, d_model)
        position_vector = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(
            1
        )  # (seq_len, 1)
        denominator = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10_000.0) / d_model)
        )  # (d_model/2, )

        positional_encoding[:, 0::2] = torch.sin(position_vector * denominator)
        positional_encoding[:, 1::2] = torch.cos(position_vector * denominator)

        # we unsqueeze to make it broadcastable over the batch dimension (batch_size, seq_len, d_model) + (1, seq_len, d_model)
        positional_encoding = positional_encoding.unsqueeze(0)  # (1, seq_len, d_model)
        self.register_buffer("positional_encoding", positional_encoding)

    def forward(self, x):
        """
        x: (batch_size, seq_len, d_model)
        return: (batch_size, seq_len, d_model)
        Add positional encoding to the input embeddings
        """
        x = x + (self.positional_encoding[:, : x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)


class LayerNormalization(nn.Module):
    def __init__(self, features: int, epsilon: float = 10**-6) -> None:
        """
        features: number of features over which we perform layer normalization, i.e. d_model
        epsilon: a very small number to prevent division by a very small number or 0
        """
        super().__init__()
        self.epsilon = epsilon

        self.alpha = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        """
        x: (batch_size, seq_len, features)
        return: (batch_size, seq_len, features)
        Implements the layer normalization formula
        """
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.epsilon) + self.beta


class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        """
        d_model: dimension of the model. It is the input dimension of the first layer of our feed forward network.
        d_ff: dimension of the hidden layer. It is usually larger than the input dimension, i.e. d_model

        Architecture:
        Input (batch_size, seq_len, d_model)
            -> Linear(d_model → d_ff)
            -> ReLU (non-linearity)
            -> Dropout
            -> Linear(d_ff → d_model)
        Output (batch_size, seq_len, d_model)
        """
        super().__init__()
        self.layer_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.layer_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.layer_2(self.dropout(torch.relu(self.layer_1(x))))


class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, head: int, dropout: float) -> None:
        """
        d_model: dimension of the model.
        head: number of parts we have to break the multihead attention block into
        Initialize four linear layers of size d_model by d_model which we will use later
        """
        super().__init__()
        self.d_model = d_model
        self.heads = head
        assert d_model % head == 0, "Head should completely divide the model dimensions"

        self.d_k = d_model // head
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """
        query, key and value are the input matrices used to calculate the attention.
        mask is used where we need to ignore the interactions between certain values.
        For example, in a decoder we mask all the keys ahead of the current word.
        Similarly, we ignore all the padded elements in a sentence.

        This function implements the attention calculation logic.
        """
        d_k = query.shape[-1]
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(
            d_k
        )  # "@" represents matrix multiplication in pytorch

        if mask is not None:
            attention_scores.masked_fill_(mask == 0, float("-inf"))
        attention_scores = attention_scores.softmax(dim=-1)

        if dropout is not None:
            attention_scores = dropout(attention_scores)

        return (attention_scores @ value), attention_scores

    def forward(self, query, key, value, mask):
        query = self.w_q(query)
        key = self.w_k(key)
        value = self.w_v(value)

        # We now divide the matrices into `heads` parts.
        # (batch_size, seq_len, d_model) --> (batch_size, seq_len, head, (d_model // head)) --> (batch_size, head, seq_len, (d_model // head))
        query = query.view(
            query.shape[0], query.shape[1], self.heads, self.d_k
        ).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.heads, self.d_k).transpose(1, 2)
        value = value.view(
            value.shape[0], value.shape[1], self.heads, self.d_k
        ).transpose(1, 2)

        # Calculate the attention values and the final output after multiplying it with `value`
        x, self.attention_scores = MultiHeadAttentionBlock.attention(
            query, key, value, mask, self.dropout
        )
        # (batch_size, head, seq_len, (d_model // head)) --> (batch_size, seq_len, head, (d_model // head)) --> (batch_size, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.heads * self.d_k)

        return self.w_o(x)


class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float) -> None:
        """
        This class is basically a wrapper around all the blocks that we'll use in the transformer.
        It passes the input through the given sublayer and automatically applies dropout and layer normalization to keep values from going out of bounds.

        [LayerNorm -> Sublayer -> Dropout] + Input
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features=features)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))


class EncoderBlock(nn.Module):
    def __init__(
        self,
        features: int,
        self_attention_block: MultiHeadAttentionBlock,
        feed_forward_block: FeedForwardBlock,
        dropout: float,
    ) -> None:
        """
        This defines the structure of the encoder block.
        First is the multihead self attention block and second is the feed forward block
        """
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.dropout = dropout
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(features, dropout) for _ in range(2)]
        )

    def forward(self, x, src_mask):
        x = self.residual_connections[0](
            x, lambda x: self.self_attention_block(x, x, x, src_mask)
        )
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x


class Encoder(nn.Module):
    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        """
        This is the main Encoder class built up of multiple "EncoderBlock" classes
        """
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features=features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class DecoderBlock(nn.Module):
    def __init__(
        self,
        self_attention_block: MultiHeadAttentionBlock,
        cross_attention_block: MultiHeadAttentionBlock,
        feed_forward_layer: FeedForwardBlock,
        features: int,
        dropout: float,
    ) -> None:
        """
        This class defines the structure of the decoder block.
        First is the masked multihead self attention layer, which takes in the target embeddings.
        Second is the cross multihead attention layer, which takes the query from the decoder but the key and value from the encoder.
        Third is the feed forward layer, which takes the output of the cross multihead attention.
        """
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_layer = feed_forward_layer
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(features, dropout) for _ in range(3)]
        )

    def forward(self, x, encoder_output, target_mask, src_mask):
        x = self.residual_connections[0](
            x, lambda x: self.self_attention_block(x, x, x, target_mask)
        )
        x = self.residual_connections[1](
            x,
            lambda x: self.cross_attention_block(
                x, encoder_output, encoder_output, src_mask
            ),
        )
        x = self.residual_connections[2](x, self.feed_forward_layer)
        return x


class Decoder(nn.Module):
    def __init__(self, layers: nn.ModuleList, features: int) -> None:
        """
        This is the main "Decoder" class built up of multiple "DecoderBlock" classes
        """
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features=features)

    def forward(self, x, encoder_output, target_mask, src_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, target_mask, src_mask)
        return self.norm(x)


class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        """
        The output of the decoder block is passed through a linear layer and then a softmax to convert the vector embedding back to the vocabulary
        """
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return torch.log_softmax(self.proj(x), dim=-1)


class Transformer(nn.Module):
    def __init__(
        self,
        encoder: Encoder,
        decoder: Decoder,
        src_embedding: InputEmbeddings,
        target_embedding: InputEmbeddings,
        src_position: PositionalEncoding,
        target_position: PositionalEncoding,
        projection_layer: ProjectionLayer,
    ) -> None:
        """
        This is the main transformer class that encompasses the encoder, the decoder and the projection layer.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embedding = src_embedding
        self.target_embedding = target_embedding
        self.src_position = src_position
        self.target_position = target_position
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        src = self.src_embedding(src)
        src = self.src_position(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output, src_mask, target, target_mask):
        target = self.target_embedding(target)
        target = self.target_position(target)
        return self.decoder(target, encoder_output, target_mask, src_mask)

    def projection(self, x):
        return self.projection_layer(x)


def build_transformer(
    src_vocab_size: int,
    target_vocab_size: int,
    src_seq_len: int,
    target_seq_len: int,
    d_model: int = 512,
    N: int = 6,
    head: int = 8,
    dropout: float = 0.1,
    d_ff: int = 2048,
) -> Transformer:
    """
    src_vocab_size: number of words in the source vocab
    target_vocab_size: number of words in the target vocab
    src_seq_len: the maximum number of words in a source sentence
    target_seq_len: the maximum number of words in a target sentence, usually equal to src_seq_len
    d_model: the size of the model, i.e. the size of the embedding vector
    N: number of times the encoder/decoder blocks are repeated in the architecture
    head: number of splits to make in multihead attention
    dropout: dropout applied after each step
    d_ff: neurons in the inner layer of the feed forward network
    """
    src_embeddings = InputEmbeddings(d_model, src_vocab_size)
    target_embeddings = InputEmbeddings(d_model, target_vocab_size)

    src_positional_embeddings = PositionalEncoding(d_model, src_seq_len, dropout)
    target_positional_embeddings = PositionalEncoding(d_model, target_seq_len, dropout)

    encoder_blocks = []
    for i in range(N):
        encoder_self_multi_head_attention_block = MultiHeadAttentionBlock(
            d_model, head, dropout
        )
        feed_forward_layer = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_blocks.append(
            EncoderBlock(
                d_model,
                encoder_self_multi_head_attention_block,
                feed_forward_layer,
                dropout,
            )
        )

    decoder_blocks = []
    for i in range(N):
        decoder_masked_multi_head_attention_block = MultiHeadAttentionBlock(
            d_model, head, dropout
        )
        cross_multihead_attention_block = MultiHeadAttentionBlock(
            d_model, head, dropout
        )
        feed_forward_layer = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_blocks.append(
            DecoderBlock(
                decoder_masked_multi_head_attention_block,
                cross_multihead_attention_block,
                feed_forward_layer,
                d_model,
                dropout,
            )
        )

    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(nn.ModuleList(decoder_blocks), d_model)

    projection_layer = ProjectionLayer(d_model, target_vocab_size)

    transformer = Transformer(
        encoder,
        decoder,
        src_embeddings,
        target_embeddings,
        src_positional_embeddings,
        target_positional_embeddings,
        projection_layer,
    )

    # Initialize the parameters with sensible defaults (Xavier uniform)
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer

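A minimal shape check for the builder (a sketch with made-up vocabulary sizes, not part of the commit):

    import torch
    from model import build_transformer

    model = build_transformer(
        src_vocab_size=100, target_vocab_size=100,
        src_seq_len=10, target_seq_len=10,
        d_model=64, N=2, head=4, d_ff=128,
    )
    src = torch.randint(0, 100, (1, 10))  # (batch, seq_len) of token ids
    src_mask = torch.ones(1, 1, 1, 10)    # nothing masked out
    print(model.encode(src, src_mask).shape)  # torch.Size([1, 10, 64])
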
src/train.py
ADDED
@@ -0,0 +1,294 @@
import torch
import torch.nn as nn
from config import get_config, get_weights_file_path
from torch.utils.data import random_split, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer
from dataset import BilingualDataset, causal_mask
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from model import build_transformer, Transformer
from tqdm import tqdm
import warnings


def greedy_decode(
    model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device
):
    """
    Inference -
    Start with just the SOS token in the target.
    Every iteration gives us a new next word, which we concatenate onto the decoder input before rerunning the cycle.
    Loop until we get EOS.
    """
    sos_idx = tokenizer_tgt.token_to_id("[SOS]")
    eos_idx = tokenizer_tgt.token_to_id("[EOS]")

    # Only calculate the encoder output once
    encoder_output = model.encode(source, source_mask)
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # build the causal mask for the current decoder input length
        decoder_mask = (
            causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
        )

        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        prob = model.projection(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [
                decoder_input,
                torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device),
            ],
            dim=1,
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)


def run_validation(
    model,
    validation_dataset,
    tokenizer_src,
    tokenizer_target,
    max_len,
    device,
    print_msg,
    num_examples=2,
):
    model.eval()
    count = 0

    console_width = 80
    with torch.no_grad():
        for batch in validation_dataset:
            count += 1
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)  # (b, 1, 1, seq_len)

            # check that the batch size is 1
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(
                model,
                encoder_input,
                encoder_mask,
                tokenizer_src,
                tokenizer_target,
                max_len,
                device,
            )

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_target.decode(model_out.detach().cpu().numpy())

            print_msg("-" * console_width)
            print_msg(f"{'SOURCE: ':>12}{source_text}")
            print_msg(f"{'TARGET: ':>12}{target_text}")
            print_msg(f"{'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg("-" * console_width)
                break


def get_all_sentences(dataset, lang):
    for item in dataset:
        yield item["translation"][lang]


def get_or_build_tokenizer(config, dataset, lang):
    """
    Takes in the dataset and splits all the sentences into tokens.
    Adds four extra tokens to the token list -> "[UNK]", "[SOS]", "[EOS]" and "[PAD]".
    The minimum frequency for a word to be in our tokenizer is 2, i.e. each word should appear at least 2 times
    to be included
    """
    tokenizer_path = Path(config["tokenizer_file"].format(lang))
    if not Path.exists(tokenizer_path):
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(
            special_tokens=["[UNK]", "[SOS]", "[EOS]", "[PAD]"], min_frequency=2
        )
        tokenizer.train_from_iterator(get_all_sentences(dataset, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer


def get_dataset(config):
    dataset_raw = load_dataset(
        "opus_books", f"{config['lang_src']}-{config['lang_target']}", split="train"
    )

    tokenizer_src = get_or_build_tokenizer(config, dataset_raw, config["lang_src"])
    tokenizer_target = get_or_build_tokenizer(
        config, dataset_raw, config["lang_target"]
    )

    # Split the dataset into training and validation
    train_dataset_size = int(0.9 * len(dataset_raw))
    validation_dataset_size = len(dataset_raw) - train_dataset_size

    train_dataset_raw, validation_dataset_raw = random_split(
        dataset_raw, [train_dataset_size, validation_dataset_size]
    )

    # Initialize the classes
    train_dataset = BilingualDataset(
        train_dataset_raw,
        tokenizer_src,
        tokenizer_target,
        config["lang_src"],
        config["lang_target"],
        config["seq_len"],
    )

    validation_dataset = BilingualDataset(
        validation_dataset_raw,
        tokenizer_src,
        tokenizer_target,
        config["lang_src"],
        config["lang_target"],
        config["seq_len"],
    )

    # Calculate the max_len
    max_len_src = 0
    max_len_target = 0

    for item in dataset_raw:
        src_ids = tokenizer_src.encode(item["translation"][config["lang_src"]]).ids
        target_ids = tokenizer_src.encode(
            item["translation"][config["lang_target"]]
        ).ids

        max_len_src = max(len(src_ids), max_len_src)
        max_len_target = max(len(target_ids), max_len_target)

    train_dataloader = DataLoader(
        train_dataset, batch_size=config["batch_size"], shuffle=True
    )
    validation_dataloader = DataLoader(validation_dataset, batch_size=1, shuffle=True)

    return train_dataloader, validation_dataloader, tokenizer_src, tokenizer_target


def get_model(config, vocab_src_len, vocab_target_length) -> Transformer:
    model = build_transformer(
        vocab_src_len,
        vocab_target_length,
        config["seq_len"],
        config["seq_len"],
        d_model=config["d_model"],
        N=4,
        head=4,
        dropout=0.1,
        d_ff=256,
    )

    return model


def train_model(config) -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)

    train_dataloader, validation_dataloader, tokenizer_src, tokenizer_target = (
        get_dataset(config)
    )
    model = get_model(
        config, tokenizer_src.get_vocab_size(), tokenizer_target.get_vocab_size()
    ).to(device)

    # Adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], eps=1e-9)
    initial_epoch = 0
    global_step = 0

    if config["preload"]:
        model_filename = get_weights_file_path(config, config["preload"])
        state = torch.load(model_filename)
        initial_epoch = state["epoch"] + 1
        optimizer.load_state_dict(state["optimizer_state_dict"])
        global_step = state["global_step"]

    # Loss function
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer_src.token_to_id("[PAD]"), label_smoothing=0.1
    ).to(device)

    for epoch in range(initial_epoch, config["num_epochs"]):
        batch_iterator = tqdm(train_dataloader, desc=f"Processing epoch : {epoch:02d}")
        for batch in batch_iterator:
            model.train()
            encoder_input = batch["encoder_input"].to(device)  # (B, seq_len)
            decoder_input = batch["decoder_input"].to(device)  # (B, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)  # (B, 1, 1, seq_len)
            decoder_mask = batch["decoder_mask"].to(device)  # (B, 1, seq_len, seq_len)

            encoder_output = model.encode(
                encoder_input, encoder_mask
            )  # (B, seq_len, d_model)
            decoder_output = model.decode(
                encoder_output, encoder_mask, decoder_input, decoder_mask
            )  # (B, seq_len, d_model)
            proj_output = model.projection(decoder_output)  # (B, seq_len, vocab_size)

            label = batch["label"].to(device)  # (B, seq_len)

            # Compare the predicted output with the label
            loss = loss_fn(
                proj_output.view(-1, tokenizer_target.get_vocab_size()), label.view(-1)
            )
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1

        # Inference after each epoch to see the results
        run_validation(
            model,
            validation_dataloader,
            tokenizer_src,
            tokenizer_target,
            config["seq_len"],
            device,
            lambda msg: batch_iterator.write(msg),
        )

        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "global_step": global_step,
            },
            model_filename,
        )


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    config = get_config()
    train_model(config)

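Assuming the layout in this commit (tokenizer_*.json and weights/ at the repository root), the intended workflow appears to be to run both scripts from the root:

    python src/train.py      # trains for num_epochs, saving weights/tmodel_XX.pt after every epoch
    python src/inference.py  # loads weights/tmodel_19.pt and greedily decodes INPUT_TEXT
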
tokenizer_en.json
ADDED
The diff for this file is too large to render.

tokenizer_it.json
ADDED
The diff for this file is too large to render.

uv.lock
ADDED
The diff for this file is too large to render.

weights/tmodel_19.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb36d9b55db46492ea5335b961e4008878aa71496e89b29b6147fc329e89d1fb
size 551199243