# src/train.py
import torch
import torch.nn as nn
from config import get_config, get_weights_file_path
from torch.utils.data import random_split, DataLoader
from datasets import load_dataset
from tokenizers import Tokenizer
from dataset import BilingualDataset, causal_mask
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from model import build_transformer, Transformer
from tqdm import tqdm
import warnings
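
# The config module is not shown here; based on how it is used below, the dict
# returned by get_config() is assumed to provide at least the following keys
# (values are illustrative, not the actual defaults):
#
#     {
#         "lang_src": "en",
#         "lang_target": "it",
#         "seq_len": 350,
#         "batch_size": 8,
#         "d_model": 512,
#         "lr": 1e-4,
#         "num_epochs": 20,
#         "preload": None,
#         "model_folder": "weights",
#         "tokenizer_file": "tokenizer_{0}.json",
#     }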
def greedy_decode(
model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device
):
"""
Inference -
Start with just SOS token in target
Every iteration gives us a new next word which we concatenate into the decoder input and rerun the cycle
Loop till we get EOS
"""
sos_idx = tokenizer_tgt.token_to_id("[SOS]")
eos_idx = tokenizer_tgt.token_to_id("[EOS]")
    # Compute the encoder output once and reuse it for every decoding step
encoder_output = model.encode(source, source_mask)
decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
while True:
if decoder_input.size(1) == max_len:
break
        # Build a causal mask matching the current decoder-input length
decoder_mask = (
causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)
)
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        # Project the last position to vocabulary logits and take the greedy argmax
        prob = model.projection(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
decoder_input = torch.cat(
[
decoder_input,
torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device),
],
dim=1,
)
        if next_word.item() == eos_idx:
break
return decoder_input.squeeze(0)
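
# `causal_mask` comes from dataset.py (not shown). A minimal sketch of the
# assumed implementation, matching how it is used above (True where attention
# is allowed, False on future positions):
#
#     def causal_mask(size):
#         # Ones strictly above the diagonal mark the future positions
#         mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)
#         return mask == 0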
def run_validation(
model,
validation_dataset,
tokenizer_src,
tokenizer_target,
max_len,
device,
print_msg,
num_examples=2,
):
model.eval()
count = 0
console_width = 80
with torch.no_grad():
for batch in validation_dataset:
count += 1
encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)
# check that the batch size is 1
assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"
model_out = greedy_decode(
model,
encoder_input,
encoder_mask,
tokenizer_src,
tokenizer_target,
max_len,
device,
)
source_text = batch["src_text"][0]
target_text = batch["tgt_text"][0]
model_out_text = tokenizer_target.decode(model_out.detach().cpu().numpy())
print_msg("-" * console_width)
print_msg(f"{'SOURCE: ':>12}{source_text}")
print_msg(f"{'TARGET: ':>12}{target_text}")
print_msg(f"{'PREDICTED: ':>12}{model_out_text}")
if count == num_examples:
print_msg("-" * console_width)
break
def get_all_sentences(dataset, lang):
for item in dataset:
yield item["translation"][lang]
def get_or_build_tokenizer(config, dataset, lang):
"""
This takes in the dataset and splits all the sentences into tokens
Adds four extra tokens to the token list -> "[UNK]", "[SOS]", "[EOS]" and "[PAD]"
min frequency for each word to be in our tokenizer is 2 i.e. each word should appear alteast 2 times
to be included
"""
tokenizer_path = Path(config["tokenizer_file"].format(lang))
    if not tokenizer_path.exists():
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(
special_tokens=["[UNK]", "[SOS]", "[EOS]", "[PAD]"], min_frequency=2
)
tokenizer.train_from_iterator(get_all_sentences(dataset, lang), trainer=trainer)
tokenizer.save(str(tokenizer_path))
else:
tokenizer = Tokenizer.from_file(str(tokenizer_path))
return tokenizer
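
# Illustrative usage (the ids shown are hypothetical; they depend on the corpus):
#
#     tokenizer = get_or_build_tokenizer(config, dataset_raw, "en")
#     ids = tokenizer.encode("I love books").ids  # e.g. [52, 341, 876]
#     tokenizer.decode(ids)                       # "I love books"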
def get_dataset(config):
dataset_raw = load_dataset(
"opus_books", f"{config['lang_src']}-{config['lang_target']}", split="train"
)
tokenizer_src = get_or_build_tokenizer(config, dataset_raw, config["lang_src"])
tokenizer_target = get_or_build_tokenizer(
config, dataset_raw, config["lang_target"]
)
# Split the dataset into training and validation
train_dataset_size = int(0.9 * len(dataset_raw))
validation_dataset_size = len(dataset_raw) - train_dataset_size
train_dataset_raw, validation_dataset_raw = random_split(
dataset_raw, [train_dataset_size, validation_dataset_size]
)
# Initialize the classes
train_dataset = BilingualDataset(
train_dataset_raw,
tokenizer_src,
tokenizer_target,
config["lang_src"],
config["lang_target"],
config["seq_len"],
)
validation_dataset = BilingualDataset(
validation_dataset_raw,
tokenizer_src,
tokenizer_target,
config["lang_src"],
config["lang_target"],
config["seq_len"],
)
    # Scan the corpus for the longest tokenized sentence on each side,
    # which is useful for choosing an appropriate seq_len
    max_len_src = 0
    max_len_target = 0
    for item in dataset_raw:
        src_ids = tokenizer_src.encode(item["translation"][config["lang_src"]]).ids
        target_ids = tokenizer_target.encode(
            item["translation"][config["lang_target"]]
        ).ids
        max_len_src = max(len(src_ids), max_len_src)
        max_len_target = max(len(target_ids), max_len_target)
    print(f"Max length of source sentences: {max_len_src}")
    print(f"Max length of target sentences: {max_len_target}")
train_dataloader = DataLoader(
train_dataset, batch_size=config["batch_size"], shuffle=True
)
validation_dataloader = DataLoader(validation_dataset, batch_size=1, shuffle=True)
return train_dataloader, validation_dataloader, tokenizer_src, tokenizer_target
def get_model(config, vocab_src_len, vocab_tgt_len) -> Transformer:
    model = build_transformer(
        vocab_src_len,
        vocab_tgt_len,
config["seq_len"],
config["seq_len"],
d_model=config["d_model"],
N=4,
head=4,
dropout=0.1,
d_ff=256,
)
return model
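
# Quick sanity check (illustrative, not part of the training flow): count the
# trainable parameters of the model built above.
#
#     model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_target.get_vocab_size())
#     n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     print(f"Trainable parameters: {n_params:,}")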
def train_model(config) -> None:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)
train_dataloader, validation_dataloader, tokenizer_src, tokenizer_target = (
get_dataset(config)
)
model = get_model(
config, tokenizer_src.get_vocab_size(), tokenizer_target.get_vocab_size()
).to(device)
    # Adam optimizer with a small epsilon for numerical stability
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], eps=1e-9)
initial_epoch = 0
global_step = 0
if config["preload"]:
model_filename = get_weights_file_path(config, config["preload"])
state = torch.load(model_filename)
initial_epoch = state["epoch"] + 1
optimizer.load_state_dict(state["optimizer_state_dict"])
global_step = state["global_step"]
    # Loss: cross-entropy over the target vocabulary, ignoring [PAD] positions
    # and softening the one-hot targets with label smoothing.
    # Note: the [PAD] id is taken from the target tokenizer, since the labels
    # are target-side token ids.
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer_target.token_to_id("[PAD]"), label_smoothing=0.1
    ).to(device)
for epoch in range(initial_epoch, config["num_epochs"]):
batch_iterator = tqdm(train_dataloader, desc=f"Processing epoch : {epoch:02d}")
for batch in batch_iterator:
model.train()
encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
decoder_input = batch["decoder_input"].to(device) # (B, seq_len)
encoder_mask = batch["encoder_mask"].to(device) # (B, 1, 1, seq_len)
decoder_mask = batch["decoder_mask"].to(device) # (B, 1, seq_len, seq_len)
encoder_output = model.encode(
encoder_input, encoder_mask
) # (B, seq_len, d_model)
decoder_output = model.decode(
encoder_output, encoder_mask, decoder_input, decoder_mask
) # (B, seq_len, d_model)
proj_output = model.projection(decoder_output) # (B, seq_len, vocab_size)
label = batch["label"].to(device) # (B, seq_len)
            # Compare the projected output with the labels: flatten
            # (B, seq_len, vocab_size) -> (B * seq_len, vocab_size)
            loss = loss_fn(
                proj_output.view(-1, tokenizer_target.get_vocab_size()), label.view(-1)
            )
batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
            # Backpropagation
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
global_step += 1
        # Greedy-decode a few validation examples at the end of each epoch
run_validation(
model,
validation_dataloader,
tokenizer_src,
tokenizer_target,
config["seq_len"],
device,
lambda msg: batch_iterator.write(msg),
)
        # Save a checkpoint (epoch, model and optimizer state) after every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
torch.save(
{
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"global_step": global_step,
},
model_filename,
)
if __name__ == "__main__":
warnings.filterwarnings("ignore")
config = get_config()
train_model(config)