Commit ·
39a7504
1
Parent(s): 552d382
Added models and code
Browse files- config.yaml +15 -0
- config_big.yaml +15 -0
- dataset.py +138 -0
- hub.py +19 -0
- inference.py +95 -0
- main.py +83 -0
- model.py +291 -0
- models/model_epoch_14.pth +3 -0
- models/model_epoch_15.pth +3 -0
- models/model_epoch_8.pth +3 -0
- models/model_epoch_9.pth +3 -0
- test.py +93 -0
- tokenizers/tokenizers_eng.json +0 -0
- tokenizers/tokenizers_hindi.json +0 -0
- train.py +243 -0
- util.py +208 -0
config.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Training configuration for the English -> Hindi transformer (base model).
batch_size: 16
d_model: 512                       # embedding / model width
dataset_path: data/english2hindi_data.json
epochs: 15
learning_rate: 0.0003
model_directory: models            # where checkpoints are written
num_enc_dec_blocks: 6              # encoder and decoder layer count
num_of_heads: 8                    # attention heads (d_model must be divisible by this)
resume_training: true
seq_len: 281                       # fixed padded sequence length
src_lang: en_text                  # JSON key for the source sentence
src_tokenizer_file: tokenizers/tokenizers_eng.json
tgt_lang: hi_text                  # JSON key for the target sentence
tgt_tokenizer_file: tokenizers/tokenizers_hindi.json
warmup_steps: 4000
config_big.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Training configuration for the larger English -> Hindi transformer variant.
batch_size: 32
d_model: 520                       # 520 / 10 heads = head dim 52
dataset_path: data/english2hindi_data.json
epochs: 5
learning_rate: 0.0012
model_directory: models
num_enc_dec_blocks: 8
num_of_heads: 10
resume_training: true
seq_len: 281
src_lang: en_text
src_tokenizer_file: tokenizers/tokenizers_eng.json
tgt_lang: hi_text
tgt_tokenizer_file: tokenizers/tokenizers_hindi.json
warmup_steps: 4000
dataset.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset
|
| 3 |
+
|
| 4 |
+
from torch.utils.data import Dataset
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class English2HindiDataset(Dataset):
    """Map-style dataset producing fixed-length tensors for seq2seq training.

    Each item contains:
      - encoder_input: [SOS] + src tokens + [EOS] + padding   (seq_len,)
      - decoder_input: [SOS] + tgt tokens + padding           (seq_len,)
      - label:         tgt tokens + [EOS] + padding           (seq_len,)
      - encoder_mask:  1 where encoder_input is not pad       (1, 1, seq_len)
      - decoder_mask:  pad mask AND causal mask               (1, seq_len, seq_len)

    Token ids are int64 because nn.Embedding requires long tensors.
    """

    def __init__(self, data, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.data = data

        # Special tokens are taken from the target tokenizer (assumes both
        # tokenizers agree on the special-token ids -- TODO confirm).
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.data)

    def causal_mask(self, size):
        """Return a (1, size, size) boolean mask, True where attending is allowed."""
        mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
        return mask == 0

    def __getitem__(self, idx):
        trans_pairs = self.data[idx]

        # Bug fix: the keys were hard-coded to "en_text"/"hi_text", silently
        # ignoring the src_lang / tgt_lang constructor arguments.
        src_text = trans_pairs[self.src_lang]
        tgt_text = trans_pairs[self.tgt_lang]

        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # -2: encoder input carries both [SOS] and [EOS].
        # -1: decoder input carries only [SOS]; the label carries the [EOS].
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            # Previously this produced a confusing torch error; fail loudly instead.
            raise ValueError(f"Sentence at index {idx} is longer than seq_len={self.seq_len}")

        pad_id = int(self.pad_token.item())

        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.full((enc_num_padding_tokens,), pad_id, dtype=torch.int64),
            ],
            dim=0,
        )

        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.full((dec_num_padding_tokens,), pad_id, dtype=torch.int64),
            ],
            dim=0,
        )

        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.full((dec_num_padding_tokens,), pad_id, dtype=torch.int64),
            ],
            dim=0,
        )

        # (1, 1, seq_len): broadcast over heads and over query positions.
        encoder_mask = (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()
        # (1, seq_len) & (1, seq_len, seq_len) -> (1, seq_len, seq_len):
        # a decoder position may attend only to non-pad AND non-future tokens.
        decoder_mask = (decoder_input != self.pad_token).unsqueeze(0).int() & self.causal_mask(decoder_input.size(0))

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "encoder_mask": encoder_mask,
            "decoder_mask": decoder_mask,
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text,
        }
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class English2HindiDatasetTest(Dataset):
    """Test-time dataset: loads a JSON list of translation-pair dicts and
    returns the raw dict per index (no tokenization or padding)."""

    def __init__(self, json_path):
        super().__init__()
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
|
hub.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""One-shot script: create the Hub repo (if missing) and upload the models dir."""
from huggingface_hub import HfApi

REPO_ID = "thecr7guy2/trainin_transformers".replace("thecr7guy2", "thecr7guy")  # keep original repo id
REPO_ID = "thecr7guy/trainin_transformers"
MODEL_DIR = "./models"

api = HfApi()
# Bug fix: `create_repo(name=...)` was removed from huggingface_hub; the
# positional/keyword argument is now `repo_id`. Duplicate imports dropped.
api.create_repo(repo_id=REPO_ID, exist_ok=True)

# The git-based `Repository` class is deprecated; `upload_folder` performs the
# add/commit/push sequence over the HTTP API in one call.
api.upload_folder(
    folder_path=MODEL_DIR,
    repo_id=REPO_ID,
    commit_message="Added models",
)
|
inference.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from model import build_transformer
|
| 4 |
+
from util import create_resources
|
| 5 |
+
import torch
|
| 6 |
+
import sys
|
| 7 |
+
import yaml
|
| 8 |
+
import sacrebleu
|
| 9 |
+
|
| 10 |
+
def translate(sentence: str):
    """Greedy-decode a target-language translation of `sentence`.

    Loads config and the checkpoint on every call (inefficient but simple),
    encodes the source once, then feeds the decoder one token at a time until
    [EOS] or seq_len. Returns the decoded target string.
    """
    with open("config.yaml", "r") as file:
        config = yaml.safe_load(file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    train_dataloader, valid_dataloader, test_dataloader, tokenizer_src, tokenizer_tgt = create_resources()

    src_vocab_size = tokenizer_src.get_vocab_size()
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()

    model = build_transformer(
        src_vocab_size,
        tgt_vocab_size,
        config["seq_len"],
        config["seq_len"],
        config["num_enc_dec_blocks"],
        config["num_of_heads"],
        config["d_model"],
    )
    model = model.to(device)

    # model_filename = "models/big_models_res/model_epoch_2.pth"
    model_filename = "models/model_epoch_15.pth"
    # Bug fix: without map_location, a GPU-trained checkpoint fails to load on
    # a CPU-only host.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state["model_state_dict"])

    model.eval()
    with torch.no_grad():
        encoded = tokenizer_src.encode(sentence)
        print(encoded, encoded.ids)

        # Robustness: truncate inputs that would not fit with [SOS]/[EOS].
        ids = encoded.ids[: config["seq_len"] - 2]
        pad_id = tokenizer_src.token_to_id("[PAD]")

        source = torch.cat(
            [
                torch.tensor([tokenizer_src.token_to_id("[SOS]")], dtype=torch.int64),
                torch.tensor(ids, dtype=torch.int64),
                torch.tensor([tokenizer_src.token_to_id("[EOS]")], dtype=torch.int64),
                torch.tensor([pad_id] * (config["seq_len"] - len(ids) - 2), dtype=torch.int64),
            ],
            dim=0,
        )

        source = source.to(device)
        source = source.unsqueeze(0)  # (1, seq_len)
        print(source.shape)

        source_mask = (source != pad_id).unsqueeze(0).unsqueeze(0).int().to(device)

        encoder_output = model.encode(source, source_mask)
        decoder_input = torch.full(
            (1, 1), tokenizer_tgt.token_to_id("[SOS]"), dtype=torch.long, device=device
        )

        while decoder_input.size(1) < config["seq_len"]:
            size = decoder_input.size(1)
            # Bug fix: the mask was previously the raw upper triangle, which is
            # the INVERSE of the causal mask used in training (dataset.py returns
            # triu == 0, i.e. 1/True where attending is allowed, and the model
            # masks positions where mask == 0). The old mask let the decoder
            # attend to future positions and blocked past ones.
            decoder_mask = (
                (torch.triu(torch.ones((1, size, size)), diagonal=1) == 0)
                .int()
                .to(device)
            )

            out = model.decode(decoder_input, encoder_output, source_mask, decoder_mask)
            prob = model.project(out[:, -1])
            _, next_word = torch.max(prob, dim=1)

            next_token = torch.full((1, 1), next_word.item(), dtype=torch.long, device=device)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            print(f"{tokenizer_tgt.decode([next_word.item()])}", end=' ')

            if next_word.item() == tokenizer_tgt.token_to_id('[EOS]'):
                break

    return tokenizer_tgt.decode(decoder_input[0].tolist())
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Driver: run a few sample translations and append the result to output.txt.
#a = translate("Why does Earth have only one moon, while other planets have many?")
a = translate("Defending champion Kolkata Knight Riders (KKR) hosts Royal Challengers Bengaluru (RCB) at the Eden Gardens in the Indian Premier League 2025 opener on Saturday.")
# NOTE(review): this second call overwrites `a` — only the LAST translation is
# written to output.txt; the first result is printed during decoding but lost.
a = translate("Which South American country is home to the Amazon Rainforest and the Christ the Redeemer statue?")
# a = translate("Imagine you are an astronaut stepping onto Mars for the first time. Write a monologue expressing your emotions and observations.")
# a = translate("The theory of evolution, proposed by Charles Darwin, explains the process by which species of organisms change over time through natural selection. While the theory has been widely accepted in the scientific community, it continues to spark debates in various social and religious contexts. Discuss how the theory of evolution has shaped our understanding of human origins and the controversies that surround it.")

# Append-mode log of the final translation.
with open("output.txt", "a") as w:
    w.write(f"\n{a}")
|
main.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchinfo import summary
|
| 2 |
+
from model import build_transformer
|
| 3 |
+
from util import create_resources
|
| 4 |
+
import yaml
|
| 5 |
+
import torch
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Load hyper-parameters and build the model, loss and optimizer at module scope.
with open("config.yaml", "r") as file:
    config = yaml.safe_load(file)

train_dataloader, valid_dataloader, test_dataloader, tokenizer_src, tokenizer_tgt = create_resources()
src_vocab_size = tokenizer_src.get_vocab_size()
# Bug fix: this previously queried tokenizer_src, so the decoder/projection got
# the WRONG vocabulary size whenever the two tokenizers differ.
tgt_vocab_size = tokenizer_tgt.get_vocab_size()

model = build_transformer(
    src_vocab_size,
    tgt_vocab_size,
    config["seq_len"],
    config["seq_len"],
    config["num_enc_dec_blocks"],
    config["num_of_heads"],
    config["d_model"],
)

batch_size = config["batch_size"]
num_epochs = config.get("epochs", 10)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# NOTE(review): ignore_index uses the SOURCE tokenizer's [PAD] id while labels
# are target-language — assumes both tokenizers share the same [PAD] id; verify.
# Both aliases kept for backward compatibility with any code using either name.
criterion = loss_fn = torch.nn.CrossEntropyLoss(
    ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"], eps=1e-9)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def save_checkpoint(epoch, model, optimizer, path):
    """Persist the epoch counter plus model and optimizer state dicts to `path`."""
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved at epoch {epoch} to {path}")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_checkpoint(path, model, optimizer=None, map_location="cpu"):
    """Restore model (and optionally optimizer) state from `path`.

    Returns the epoch stored in the checkpoint (0 if absent).
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    # Optimizer state is restored only when an optimizer is supplied AND the
    # checkpoint actually contains one.
    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint.get("epoch", 0)
    print(f"Loaded checkpoint from epoch {start_epoch}")
    return start_epoch
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def train_one_epoch(device):
    # NOTE(review): incomplete stub in this commit — it switches the
    # module-level `model` to train mode and initialises an accumulator, but
    # performs no forward/backward steps yet.
    model.train()
    running_loss = 0.0
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def train_model(model):
    """Prepare device, checkpoint directory and dataloaders for training.

    NOTE(review): still incomplete in this commit — it stops after building the
    dataloaders and never runs an epoch loop.
    """
    # Bug fix: `device` was a plain string, so `device.index` below resolved to
    # the built-in str.index METHOD, which torch.cuda.* cannot use. A proper
    # torch.device exposes `.type` and is accepted directly by the cuda queries.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        print(f"Device name: {torch.cuda.get_device_name(device)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")

    # Bug fix: config.yaml defines `model_directory`; `model_folder` raised KeyError.
    Path(config["model_directory"]).mkdir(parents=True, exist_ok=True)
    train_dataloader, valid_dataloader, test_dataloader, tokenizer_src, tokenizer_tgt = create_resources()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
model.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class InputEmbeddings(torch.nn.Module):
    """Token-id -> d_model embedding, scaled by sqrt(d_model) as in the
    original Transformer paper."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Attribute name (with its typo) kept so existing checkpoints still load.
        self.embeddingss = torch.nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.embeddingss(x) * scale
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PositionalEncoding(torch.nn.Module):
    """Fixed sinusoidal positional encoding added to the embeddings, then dropout.

    Perf fix: the (seq_len x d_model) table was built with a Python double loop
    doing per-element tensor ops; it is now fully vectorized. The debug
    `print(pe.shape)` in the constructor was removed.

    NOTE: the exponent is 2*j/d_model for EVERY column j (the original loop did
    the same), which differs slightly from the paper's 2*(j//2)/d_model. It is
    kept as-is so the buffer matches checkpoints trained with this module.
    """

    def __init__(self, d_model, seq_len, dropout):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = torch.nn.Dropout(dropout)

        position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)   # (seq_len, 1)
        dim = torch.arange(d_model, dtype=torch.float32)                     # (d_model,)
        denom = torch.pow(torch.tensor(10000.0), (2.0 * dim) / d_model)      # (d_model,)
        angles = position / denom                                            # (seq_len, d_model)

        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles[:, 0::2])   # even columns -> sin
        pe[:, 1::2] = torch.cos(angles[:, 1::2])   # odd columns  -> cos

        pe = pe.unsqueeze(0)  # (1, seq_len, d_model) for batch broadcasting
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Slice to the actual sequence length; the buffer is not trained.
        x = x + (self.pe[:, : x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class LayerNormm(torch.nn.Module):
    """Thin wrapper around torch.nn.LayerNorm (eps=1e-5) over the last dimension."""

    def __init__(self, features):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm(features, eps=1e-5)

    def forward(self, x):
        normalized = self.layer_norm(x)
        return normalized
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class FeedForward(torch.nn.Module):
    """Position-wise FFN: Linear(d_model->dff) -> ReLU -> Dropout -> Linear(dff->d_model)."""

    def __init__(self, d_model, dff, dropout):
        super().__init__()
        # Creation order preserved (linear_1, dropout, linear_2, activation) so
        # RNG state and state_dict keys match the original implementation.
        self.linear_1 = torch.nn.Linear(d_model, dff)
        self.dropout = torch.nn.Dropout(dropout)
        self.linear_2 = torch.nn.Linear(dff, d_model)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        hidden = self.dropout(self.activation(self.linear_1(x)))
        return self.linear_2(hidden)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class MHA(torch.nn.Module):
    """Multi-head attention with separate Q/K/V input projections and an
    output projection. d_model must be divisible by the number of heads."""

    def __init__(self, d_model, number_of_heads, dropout):
        super().__init__()
        # Creation order preserved for RNG / state_dict compatibility.
        self.dropout = torch.nn.Dropout(dropout)
        self.d_model = d_model
        self.noh = number_of_heads
        self.dk = self.d_model // self.noh
        self.wq = torch.nn.Linear(d_model, d_model)
        self.wk = torch.nn.Linear(d_model, d_model)
        self.wv = torch.nn.Linear(d_model, d_model)
        self.wo = torch.nn.Linear(d_model, d_model)

    @staticmethod
    def calculate_self_attention(qprime, kprime, vprime, mask, dropout):
        """Scaled dot-product attention; returns (values, attention weights)."""
        dk = qprime.shape[-1]
        scores = (qprime @ kprime.transpose(-2, -1)) / math.sqrt(dk)
        if mask is not None:
            # Positions where mask == 0 may not be attended to.
            scores.masked_fill_(mask == 0, -1e9)
        # Softmax over the key dimension so each query's weights sum to 1.
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return (weights @ vprime), weights

    def _split_heads(self, t):
        # (batch, seq, d_model) -> (batch, noh, seq, dk)
        batch, seq = t.shape[0], t.shape[1]
        return t.view(batch, seq, self.noh, self.dk).transpose(1, 2)

    def forward(self, q, k, v, mask):
        qprime = self._split_heads(self.wq(q))
        kprime = self._split_heads(self.wk(k))
        vprime = self._split_heads(self.wv(v))

        x, attention_scores = MHA.calculate_self_attention(
            qprime, kprime, vprime, mask, self.dropout
        )
        # (batch, noh, seq, dk) -> (batch, seq, d_model), then final projection.
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.noh * self.dk)
        return self.wo(x)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class SkipConnection(torch.nn.Module):
    """Pre-norm residual connection: x + dropout(sublayer(layernorm(x)))."""

    def __init__(self, features, dropout):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.layernorm = LayerNormm(features)

    def forward(self, x, sublayer):
        residual = self.dropout(sublayer(self.layernorm(x)))
        return x + residual
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class EncoderBlock(torch.nn.Module):
    """One encoder layer: self-attention then feed-forward, each wrapped in a
    pre-norm residual connection."""

    def __init__(self, features, mha_block, feedforward_block, dropout):
        super().__init__()
        self.attention_block = mha_block
        self.feedforward_block = feedforward_block
        self.skip_connections = torch.nn.ModuleList(
            SkipConnection(features, dropout) for _ in range(2)
        )
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x, src_mask):
        def self_attend(t):
            # Q = K = V = the layer input (self-attention).
            return self.attention_block(t, t, t, src_mask)

        x = self.skip_connections[0](x, self_attend)
        return self.skip_connections[1](x, self.feedforward_block)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class Encoder(torch.nn.Module):
    """Stack of encoder layers followed by a final layer norm."""

    def __init__(self, features: int, layers: torch.nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormm(features)

    def forward(self, x, mask):
        out = x
        for block in self.layers:
            out = block(out, mask)
        return self.norm(out)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class DecoderBlock(torch.nn.Module):
    """One decoder layer: masked self-attention, cross-attention over the
    encoder output, then feed-forward — each in a pre-norm residual."""

    def __init__(self, features, mha_block, mha_block2, feedforward_block, dropout):
        super().__init__()
        self.attention_block = mha_block
        self.cross_attention_block = mha_block2
        self.feedforward_block = feedforward_block
        self.skip_connections = torch.nn.ModuleList(
            SkipConnection(features, dropout) for _ in range(3)
        )
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        def self_attend(t):
            return self.attention_block(t, t, t, tgt_mask)

        def cross_attend(t):
            # Queries come from the decoder; keys/values from the encoder.
            return self.cross_attention_block(t, enc_output, enc_output, src_mask)

        x = self.skip_connections[0](x, self_attend)
        x = self.skip_connections[1](x, cross_attend)
        return self.skip_connections[2](x, self.feedforward_block)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class Decoder(torch.nn.Module):
    """Stack of decoder layers followed by a final layer norm."""

    def __init__(self, features: int, layers: torch.nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormm(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        out = x
        for block in self.layers:
            out = block(out, encoder_output, src_mask, tgt_mask)
        return self.norm(out)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class ProjectionLayer(torch.nn.Module):
    """Final linear map from the model dimension to target-vocabulary logits."""

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        # Annotation fixed: this returns logits of shape (..., vocab_size),
        # not None. Softmax is applied downstream (loss / decoding).
        return self.proj(x)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class Transformer(torch.nn.Module):
    """Encoder-decoder model wiring embeddings, positional encodings, the
    encoder/decoder stacks and the final projection. Sub-modules are injected,
    so this class only orchestrates the data flow."""

    def __init__(
        self,
        encoder,
        decoder,
        src_pos_enc,
        tgt_pos_enc,
        src_emb,
        tgt_emb,
        projection_layer,
    ) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pos_enc = src_pos_enc
        self.tgt_pos_enc = tgt_pos_enc
        self.src_emb = src_emb
        self.tgt_emb = tgt_emb
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed + position-encode the source, then run the encoder stack."""
        embedded = self.src_pos_enc(self.src_emb(src))
        return self.encoder(embedded, src_mask)

    def decode(self, tgt, enc_output, src_mask, tgt_mask):
        """Embed + position-encode the target, then run the decoder stack."""
        embedded = self.tgt_pos_enc(self.tgt_emb(tgt))
        return self.decoder(embedded, enc_output, src_mask, tgt_mask)

    def project(self, x):
        """Map decoder features to vocabulary logits."""
        return self.projection_layer(x)
|
| 244 |
+
|
| 245 |
+
def build_transformer(
    src_vocab_size,
    tgt_vocab_size,
    src_seq_len,
    tgt_seq_len,
    nlayers=6,
    noh=8,
    d_model=512,
    dropout=0.1,
    dff=2048,
):
    """Assemble a full Transformer and Xavier-initialise all weight matrices.

    Sub-module creation order matches the original implementation so that RNG
    state (and therefore initial weights under a fixed seed) is unchanged.
    """
    src_emb = InputEmbeddings(d_model, src_vocab_size)
    tgt_emb = InputEmbeddings(d_model, tgt_vocab_size)

    src_pos_enc = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos_enc = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Per encoder layer: one self-attention block + one feed-forward block.
    enc_blocks = torch.nn.ModuleList(
        EncoderBlock(
            d_model,
            MHA(d_model, noh, dropout),
            FeedForward(d_model, dff, dropout),
            dropout,
        )
        for _ in range(nlayers)
    )
    encoder = Encoder(d_model, enc_blocks)

    # Per decoder layer: self-attention, cross-attention, feed-forward.
    dec_blocks = torch.nn.ModuleList(
        DecoderBlock(
            d_model,
            MHA(d_model, noh, dropout),
            MHA(d_model, noh, dropout),
            FeedForward(d_model, dff, dropout),
            dropout,
        )
        for _ in range(nlayers)
    )
    decoder = Decoder(d_model, dec_blocks)

    proj = ProjectionLayer(d_model, tgt_vocab_size)

    transformer = Transformer(
        encoder, decoder, src_pos_enc, tgt_pos_enc, src_emb, tgt_emb, proj
    )

    # Xavier init for every matrix-shaped parameter (biases are left as-is).
    for p in transformer.parameters():
        if p.dim() > 1:
            torch.nn.init.xavier_uniform_(p)

    return transformer
|
models/model_epoch_14.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:883fffd40352bf4f925da84a67882fad7284d4cd6ca6f732fd3a9bc7f1ca12fe
|
| 3 |
+
size 1425518098
|
models/model_epoch_15.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3c473c1bcd9fa94a37c5b2e0f0196a77ddc00774af1a98ee57c7cd32852e1dab
|
| 3 |
+
size 1425518098
|
models/model_epoch_8.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1b5abd48289c13b380f3cf2288591daf8be891709a56072b1962a41060739224
|
| 3 |
+
size 1425514236
|
models/model_epoch_9.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b94e9d6c6818e8b9ac23bf37aa303689e038f58a0da4b5bef7c8dcc33645ca3b
|
| 3 |
+
size 1425517052
|
test.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from model import build_transformer
|
| 3 |
+
from util import create_resources
|
| 4 |
+
import torch
|
| 5 |
+
import sys
|
| 6 |
+
import yaml
|
| 7 |
+
|
| 8 |
+
def translate(sentence: str):
    """Greedily translate an English sentence to Hindi with the trained model.

    Loads config and tokenizers, restores the checkpoint, encodes *sentence*,
    then decodes token-by-token (greedy argmax) until [EOS] or seq_len.

    Args:
        sentence: Source-language (English) sentence to translate.

    Returns:
        The decoded target-language (Hindi) string.

    Raises:
        ValueError: if the encoded sentence does not fit in the model's seq_len.
    """
    with open("config.yaml", "r") as file:
        config = yaml.safe_load(file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    # Only the tokenizers are needed here; the dataloaders are unused.
    train_dataloader, valid_dataloader, test_dataloader, tokenizer_src, tokenizer_tgt = create_resources()

    src_vocab_size = tokenizer_src.get_vocab_size()
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()

    model = build_transformer(
        src_vocab_size,
        tgt_vocab_size,
        config["seq_len"],
        config["seq_len"],
        config["num_enc_dec_blocks"],
        config["num_of_heads"],
        config["d_model"],
    )

    model = model.to(device)
    model_filename = "models/model.pth"
    # Fix: map_location makes the checkpoint loadable on CPU-only machines too.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    model.eval()
    with torch.no_grad():
        source = tokenizer_src.encode(sentence)

        # [SOS] + tokens + [EOS], padded with [PAD] up to seq_len.
        pad_count = config["seq_len"] - len(source.ids) - 2
        if pad_count < 0:
            raise ValueError("Input sentence is longer than the model's seq_len.")
        source = torch.cat([
            torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64),
            torch.tensor(source.ids, dtype=torch.int64),
            torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
            torch.tensor([tokenizer_src.token_to_id('[PAD]')] * pad_count, dtype=torch.int64)
        ], dim=0)

        source = source.to(device)
        source = source.unsqueeze(0)  # (1, seq_len)

        # 1 where the token is real, 0 at padding positions.
        source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)

        encoder_output = model.encode(source, source_mask)
        decoder_input = torch.full((1, 1), tokenizer_tgt.token_to_id('[SOS]'),
                                   dtype=torch.long, device=device)

        while decoder_input.size(1) < config["seq_len"]:
            size = decoder_input.size(1)
            # Causal mask: position i may attend only to positions <= i.
            # Fix: the original passed the raw upper triangle (1s strictly
            # above the diagonal), which permits exactly the *future*
            # positions — the `== 0` inversion was missing.
            decoder_mask = (torch.triu(torch.ones((1, size, size)), diagonal=1) == 0) \
                .to(device, dtype=torch.int)

            out = model.decode(decoder_input, encoder_output, source_mask, decoder_mask)
            prob = model.project(out[:, -1])
            _, next_word = torch.max(prob, dim=1)

            next_token = torch.full((1, 1), next_word.item(), dtype=torch.long, device=device)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            if next_word.item() == tokenizer_tgt.token_to_id('[EOS]'):
                break

    # Fix: return the decoded translation (the intended return was commented
    # out and a placeholder 0 was returned instead).
    return tokenizer_tgt.decode(decoder_input[0].tolist())
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
    # Run the smoke-test translation only when executed directly, not on import.
    a = translate("My Name is sai and I love computers")
    print(a)
|
tokenizers/tokenizers_eng.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizers/tokenizers_hindi.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
train.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchinfo import summary
|
| 2 |
+
from model import build_transformer
|
| 3 |
+
from util import create_resources
|
| 4 |
+
import yaml
|
| 5 |
+
import torch
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import os
|
| 8 |
+
import wandb
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
from matplotlib import font_manager
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
mangal_font_path = "Mangal.TTf"
|
| 15 |
+
devanagari_font = font_manager.FontProperties(fname=mangal_font_path)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class NoamScheduler:
    """Warmup-then-decay learning-rate schedule from "Attention Is All You Need".

    lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    """

    def __init__(self, optimizer, d_model, warmup_steps):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0  # number of optimizer steps taken so far

    def step(self):
        """Advance one step, write the new rate into every param group, return it."""
        self.step_num += 1
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
            group["lr"] = new_lr
        return new_lr

    def get_lr(self):
        """Compute the rate for the current step (step 0 is treated as step 1)."""
        current = self.step_num if self.step_num > 0 else 1
        decay_term = current ** (-0.5)
        warmup_term = current * (self.warmup_steps ** (-1.5))
        return (self.d_model ** (-0.5)) * min(decay_term, warmup_term)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Trainer:
    """Owns the training loop: per-epoch training, wandb logging, checkpointing."""

    def __init__(
        self,
        model,
        optimizer,
        scheduler,
        criterion,
        device,
        tokenizer_src,
        tokenizer_tgt,
        seq_len,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler      # NoamScheduler; stepped once per batch
        self.criterion = criterion
        self.device = device
        self.tgt_tokenizer = tokenizer_tgt
        self.src_tokenizer = tokenizer_src
        self.seq_len = seq_len

    def train_epoch(self, dataloader):
        """Run one training epoch and return the loss averaged per non-PAD token."""
        self.model.train()
        torch.cuda.empty_cache()
        running_loss = 0.0
        total_tokens = 0
        progress_bar = tqdm(
            enumerate(dataloader), desc="Training", total=len(dataloader)
        )

        for batch_idx, batch in progress_bar:
            encoder_input = batch["encoder_input"].to(self.device)  # (batch_size, seq_len)
            decoder_input = batch["decoder_input"].to(self.device)  # (batch_size, seq_len)

            encoder_mask = batch["encoder_mask"].to(self.device)
            decoder_mask = batch["decoder_mask"].to(self.device)

            encoder_output = self.model.encode(encoder_input, encoder_mask)
            decoder_output = self.model.decode(
                decoder_input, encoder_output, encoder_mask, decoder_mask
            )
            projection_output = self.model.project(decoder_output)

            label = batch["label"].to(self.device)

            # Flatten (batch, seq, vocab) -> (batch*seq, vocab) vs (batch*seq,).
            loss = self.criterion(
                projection_output.view(-1, self.tgt_tokenizer.get_vocab_size()),
                label.view(-1),
            )

            loss.backward()

            self.optimizer.step()
            current_lr = self.scheduler.step()
            self.optimizer.zero_grad()

            # Fix: derive the PAD id from the target tokenizer instead of the
            # hard-coded literal 1, so token accounting survives any change in
            # special-token ordering or vocabulary.
            pad_id = self.tgt_tokenizer.token_to_id("[PAD]")
            with torch.no_grad():
                # Weight the running loss by real (non-PAD) target tokens so
                # the epoch average is a true per-token loss.
                non_pad = label.ne(pad_id)
                num_nonpad_tokens = non_pad.sum().item()
                running_loss += loss.item() * num_nonpad_tokens
                total_tokens += num_nonpad_tokens

            if (batch_idx + 1) % 50 == 0:
                wandb.log(
                    {
                        "batch_loss": loss.item(),
                        "learning_rate": current_lr,
                        "batch": batch_idx + 1,
                    }
                )

        epoch_loss = running_loss / total_tokens if total_tokens > 0 else 0.0
        return epoch_loss

    def save_checkpoint(self, epoch, output_dir):
        """Persist model/optimizer/scheduler state as model_epoch_<epoch>.pth."""
        os.makedirs(output_dir, exist_ok=True)
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.step_num,
        }
        torch.save(checkpoint, os.path.join(output_dir, f"model_epoch_{epoch}.pth"))
        print(f"Checkpoint saved at epoch {epoch}")

    def run(self, train_loader, epochs, output_dir, start_epoch=1):
        """Train epochs [start_epoch, epochs], logging and checkpointing each one."""
        for epoch in range(start_epoch, epochs + 1):
            train_loss = self.train_epoch(train_loader)
            current_lr = self.scheduler.get_lr()

            wandb.log(
                {"epoch": epoch, "train_loss": train_loss, "learning_rate": current_lr}
            )

            self.save_checkpoint(epoch, output_dir)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def load_latest_checkpoint(model, optimizer, scheduler, model_directory, device):
    """Restore model/optimizer/scheduler state from the newest checkpoint on disk.

    Scans *model_directory* for files named ``model_epoch_<N>.pth`` and loads
    the one with the highest N in place.

    Returns:
        (checkpoint_dict, next_epoch), or (None, 1) when there is nothing
        to resume from (missing directory or no matching files).
    """
    if not os.path.isdir(model_directory):
        return None, 1

    epoch_pattern = re.compile(r"model_epoch_(\d+)\.pth")
    candidates = [
        (int(found.group(1)), entry)
        for entry in os.listdir(model_directory)
        if entry.endswith(".pth") and (found := epoch_pattern.search(entry))
    ]

    if not candidates:
        return None, 1

    # Highest epoch number wins.
    latest_epoch, latest_filename = max(candidates, key=lambda pair: pair[0])
    ckpt_path = os.path.join(model_directory, latest_filename)
    ckpt = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    scheduler.step_num = ckpt["scheduler_state"]
    start_epoch = ckpt["epoch"] + 1
    print(f"Resuming Training from epoch {ckpt['epoch']}")
    return ckpt, start_epoch
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def main():
    """Entry point: build data resources, model, optimizer and trainer, then train."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    (
        train_dataloader,
        valid_dataloader,
        test_dataloader,
        tokenizer_src,
        tokenizer_tgt,
    ) = create_resources()
    src_vocab_size = tokenizer_src.get_vocab_size()
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()

    # Re-read the config after create_resources(), which may rewrite seq_len.
    with open("config.yaml", "r") as file:
        config = yaml.safe_load(file)

    run = wandb.init(
        entity="training-transformers-vast",
        project="AttentionTranslate-sai", config=config)

    model = build_transformer(
        src_vocab_size,
        tgt_vocab_size,
        config["seq_len"],
        config["seq_len"],
        config["num_enc_dec_blocks"],
        config["num_of_heads"],
        config["d_model"],
    )

    model = model.to(device)

    wandb.watch(model, log="all")

    # Fix: the labels fed to the loss are *target*-language token ids, so the
    # PAD index to ignore must come from the target tokenizer (the original
    # looked it up on tokenizer_src).
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=tokenizer_tgt.token_to_id("[PAD]"), label_smoothing=0.1
    ).to(device)

    # Betas/eps follow the original Transformer paper's Adam settings.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config["learning_rate"], betas=(0.9, 0.98), eps=1e-9
    )

    scheduler = NoamScheduler(optimizer, config["d_model"], config["warmup_steps"])

    start_epoch = 1

    if config["resume_training"]:
        ckpt, start_epoch = load_latest_checkpoint(
            model, optimizer, scheduler, config["model_directory"], device
        )

    if start_epoch == 1:
        print("Training from scratch.")

    trainer = Trainer(
        model,
        optimizer,
        scheduler,
        criterion,
        device,
        tokenizer_src,
        tokenizer_tgt,
        config["seq_len"],
    )

    trainer.run(
        train_dataloader, config["epochs"], config["model_directory"], start_epoch
    )

    run.finish()


if __name__ == "__main__":
    main()
|
util.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tokenizers import Tokenizer
|
| 2 |
+
from tokenizers.models import BPE
|
| 3 |
+
from tokenizers.trainers import BpeTrainer
|
| 4 |
+
from tokenizers.pre_tokenizers import (
|
| 5 |
+
WhitespaceSplit,
|
| 6 |
+
Punctuation,
|
| 7 |
+
Sequence as PreSequence,
|
| 8 |
+
)
|
| 9 |
+
from tokenizers.normalizers import NFD, Lowercase, Sequence, StripAccents
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import yaml
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
from dataset import English2HindiDataset, English2HindiDatasetTest
|
| 14 |
+
from torch.utils.data import DataLoader
|
| 15 |
+
import json
|
| 16 |
+
import random
|
| 17 |
+
from sklearn.model_selection import train_test_split
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_config(config_path):
    """Parse the YAML file at *config_path* and return it as a dict."""
    with open(config_path, "r") as handle:
        return yaml.safe_load(handle)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_all_sentences(ds, lang):
    """Lazily yield the value stored under the *lang* key of every record in *ds*."""
    for record in ds:
        yield record[lang]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_or_build_tokenizer(tokenizer_path, ds, lang):
    """Load a tokenizer from disk, training and saving a fresh BPE one if absent.

    English ("en_text") gets accent stripping on top of NFD + lowercasing;
    any other language (Hindi here) keeps its combining marks intact.
    """
    tokenizer_path = Path(tokenizer_path)
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    if lang == "en_text":
        normalizer_steps = [NFD(), StripAccents(), Lowercase()]
    else:
        normalizer_steps = [NFD(), Lowercase()]
    tokenizer.normalizer = Sequence(normalizer_steps)

    tokenizer.pre_tokenizer = PreSequence([WhitespaceSplit(), Punctuation()])
    trainer = BpeTrainer(
        special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"],
        min_frequency=3,
        vocab_size=60000,
    )
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _first_human_value(conversations):
    """Return the text of the first 'human' turn, falling back to the first turn.

    Returns None when *conversations* is empty.
    """
    for turn in conversations:
        if "human" in turn["from"]:
            return turn["value"]
    if conversations:
        return conversations[0]["value"]
    return None


def _within_limits(text, word_limit, char_limit):
    """True when *text* respects both the word-count and character budgets."""
    return len(text.split()) <= word_limit and len(text) <= char_limit


def preprocess_to_json(dataset_hf, output_json_path, word_limit=77, char_limit=300):
    """Extract (English, Hindi) sentence pairs and dump them as a JSON list.

    Keeps only rows flagged ``translated``, takes the first human turn of each
    side (English from 'en_conversations', Hindi from 'conversations'), drops
    pairs whose either side exceeds *word_limit* words or *char_limit* chars,
    and writes ``[{"en_text": ..., "hi_text": ...}, ...]`` to
    *output_json_path*.

    Args:
        dataset_hf: Iterable of row dicts (e.g. a HuggingFace dataset split).
        output_json_path: Destination path for the JSON file.
        word_limit: Maximum words allowed per side of a pair.
        char_limit: Maximum characters allowed per side of a pair.
    """
    filtered_data = []

    for row in dataset_hf:
        if not row.get("translated", False):
            continue

        en_value = _first_human_value(row.get("en_conversations", []))
        hi_value = _first_human_value(row.get("conversations", []))

        if not (en_value and hi_value):
            continue
        # Skip pairs that are too long on either side.
        if not (_within_limits(en_value, word_limit, char_limit)
                and _within_limits(hi_value, word_limit, char_limit)):
            continue

        filtered_data.append({"en_text": en_value, "hi_text": hi_value})

    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(filtered_data, f, ensure_ascii=False, indent=2)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def create_resources():
    """Build the full data pipeline from config.yaml.

    Ensures the JSON dataset exists (building it from the HuggingFace dataset
    if not), loads/trains both tokenizers, auto-derives seq_len when the
    config holds 0 (rewriting config.yaml), splits the data 80/10/10, and
    wraps each split in a DataLoader.

    Returns:
        (train_dataloader, valid_dataloader, test_dataloader,
         tokenizer_src, tokenizer_tgt)
    """
    config_path = "config.yaml"
    config = load_config(config_path)

    dataset_json_path = config.get("dataset_path", "data/english2hindi_data.json")

    # Build the filtered en/hi pair file once; reuse it on later runs.
    if not Path(dataset_json_path).exists():
        print(f"Dataset file {dataset_json_path} not found. Creating it...")
        dataset_hf = load_dataset("BhabhaAI/openhermes-2.5-hindi", split="train")
        preprocess_to_json(dataset_hf, dataset_json_path)
    else:
        print(f"Dataset file {dataset_json_path} already exists. Skipping preprocessing.")

    with open(dataset_json_path, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    # Tokenizers are trained on the raw pairs the first time, loaded afterwards.
    tokenizer_src = get_or_build_tokenizer(
        config["src_tokenizer_file"], raw_data, config["src_lang"]
    )
    tokenizer_tgt = get_or_build_tokenizer(
        config["tgt_tokenizer_file"], raw_data, config["tgt_lang"]
    )

    # NOTE(review): this dataset is built only to print its size below; it is
    # never returned — confirm whether it is still needed.
    test_pt_dataset = English2HindiDatasetTest(config["dataset_path"])

    print(len(test_pt_dataset))

    max_len_src = 0
    max_len_tgt = 0

    seq_len = config["seq_len"]

    # seq_len == 0 is a sentinel meaning "measure the corpus and derive it".
    if seq_len == 0:
        print("seq_len is 0, starting process...")

        # Find the longest tokenized sentence on either side.
        for item in raw_data:
            src_ids = tokenizer_src.encode(item[config['src_lang']]).ids
            tgt_ids = tokenizer_tgt.encode(item[config['tgt_lang']]).ids
            max_len_src = max(max_len_src, len(src_ids))

            max_len_tgt = max(max_len_tgt, len(tgt_ids))

        print(max_len_src,max_len_tgt )

        # +30 leaves headroom for special tokens and slightly longer inputs.
        final_max_len = max(max_len_src, max_len_tgt) + 30

        config['seq_len'] = final_max_len

        # Persist the derived value so later runs skip this scan.
        with open("config.yaml", 'w') as f:
            yaml.safe_dump(config, f, default_flow_style=False)

        print(f'Updated seq_len to {final_max_len}')

    else:
        print("seq_len is not 0, skipping process.")

    random.seed(42)

    # 80% train, then split the remaining 20% evenly into test and validation.
    train_data, temp_data = train_test_split(raw_data, test_size=0.2, random_state=42)
    test_data, valid_data = train_test_split(temp_data, test_size=0.5, random_state=42)

    print("######################################################")

    train_dataset = English2HindiDataset(
        train_data,
        tokenizer_src,
        tokenizer_tgt,
        config["src_lang"],
        config["tgt_lang"],
        config["seq_len"],
    )

    valid_dataset = English2HindiDataset(
        valid_data,
        tokenizer_src,
        tokenizer_tgt,
        config["src_lang"],
        config["tgt_lang"],
        config["seq_len"],
    )

    test_dataset = English2HindiDataset(
        test_data,
        tokenizer_src,
        tokenizer_tgt,
        config["src_lang"],
        config["tgt_lang"],
        config["seq_len"],
    )

    # NOTE(review): shuffle=False on the *training* loader is unusual — training
    # batches are normally shuffled; confirm this is intentional.
    train_dataloader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=False)
    valid_dataloader = DataLoader(valid_dataset, batch_size=config["batch_size"],shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=config["batch_size"],shuffle=True)

    return train_dataloader,valid_dataloader,test_dataloader,tokenizer_src,tokenizer_tgt
|