import os
import re

import matplotlib.pyplot as plt
import torch
import wandb
import yaml
from matplotlib import font_manager
from torchinfo import summary
from tqdm import tqdm

from model import build_transformer
from util import create_resources


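# Mangal font properties for rendering Devanagari text with matplotlib.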
mangal_font_path = "Mangal.TTf"
devanagari_font = font_manager.FontProperties(fname=mangal_font_path)


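# Noam learning-rate schedule from "Attention Is All You Need":
#   lr = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)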
class NoamScheduler:
    def __init__(self, optimizer, d_model, warmup_steps):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def step(self):
        self.step_num += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr
        return lr

    def get_lr(self):
        step = max(self.step_num, 1)
        arg1 = step ** (-0.5)
        arg2 = step * (self.warmup_steps ** (-1.5))
        return (self.d_model ** (-0.5)) * min(arg1, arg2)


class Trainer:
    def __init__(
        self,
        model,
        optimizer,
        scheduler,
        criterion,
        device,
        tokenizer_src,
        tokenizer_tgt,
        seq_len,
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.device = device
        self.tgt_tokenizer = tokenizer_tgt
        self.src_tokenizer = tokenizer_src
        self.seq_len = seq_len

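    # One training pass over the dataloader; returns the average loss per non-padding token.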
    def train_epoch(self, dataloader):
        self.model.train()
        torch.cuda.empty_cache()
        running_loss = 0.0
        total_tokens = 0
        progress_bar = tqdm(
            enumerate(dataloader), desc="Training", total=len(dataloader)
        )

        for batch_idx, batch in progress_bar:
            encoder_input = batch["encoder_input"].to(self.device)
            decoder_input = batch["decoder_input"].to(self.device)
            encoder_mask = batch["encoder_mask"].to(self.device)
            decoder_mask = batch["decoder_mask"].to(self.device)

            encoder_output = self.model.encode(encoder_input, encoder_mask)
            decoder_output = self.model.decode(
                decoder_input, encoder_output, encoder_mask, decoder_mask
            )
            projection_output = self.model.project(decoder_output)

            label = batch["label"].to(self.device)

            loss = self.criterion(
                projection_output.view(-1, self.tgt_tokenizer.get_vocab_size()),
                label.view(-1),
            )

            loss.backward()

            # Step the Noam schedule before the optimizer step so every update,
            # including the very first, uses the warmup learning rate.
            current_lr = self.scheduler.step()
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Mask padding with the target tokenizer's [PAD] id (labels are target-side ids).
            pad_id = self.tgt_tokenizer.token_to_id("[PAD]")
            with torch.no_grad():
                non_pad = label.ne(pad_id)
                num_nonpad_tokens = non_pad.sum().item()
                running_loss += loss.item() * num_nonpad_tokens
                total_tokens += num_nonpad_tokens

            if (batch_idx + 1) % 50 == 0:
                wandb.log(
                    {
                        "batch_loss": loss.item(),
                        "learning_rate": current_lr,
                        "batch": batch_idx + 1,
                    }
                )

        epoch_loss = running_loss / total_tokens if total_tokens > 0 else 0.0
        return epoch_loss

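    # Checkpoints bundle model, optimizer, and scheduler step state so training can resume.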
    def save_checkpoint(self, epoch, output_dir):
        os.makedirs(output_dir, exist_ok=True)
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.step_num,
        }
        torch.save(checkpoint, os.path.join(output_dir, f"model_epoch_{epoch}.pth"))
        print(f"Checkpoint saved at epoch {epoch}")

    def run(self, train_loader, epochs, output_dir, start_epoch=1):
        for epoch in range(start_epoch, epochs + 1):
            train_loss = self.train_epoch(train_loader)
            current_lr = self.scheduler.get_lr()

            wandb.log(
                {"epoch": epoch, "train_loss": train_loss, "learning_rate": current_lr}
            )

            self.save_checkpoint(epoch, output_dir)


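# Resume helper: load the highest-numbered "model_epoch_<n>.pth" checkpoint, if any.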
def load_latest_checkpoint(model, optimizer, scheduler, model_directory, device):
    if not os.path.isdir(model_directory):
        return None, 1
    checkpoint_files = []
    for filename in os.listdir(model_directory):
        if filename.endswith(".pth"):
            match = re.search(r"model_epoch_(\d+)\.pth", filename)
            if match:
                epoch = int(match.group(1))
                checkpoint_files.append((epoch, filename))

    if not checkpoint_files:
        return None, 1

    latest_epoch, latest_filename = max(checkpoint_files, key=lambda x: x[0])
    ckpt_path = os.path.join(model_directory, latest_filename)
    ckpt = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    scheduler.step_num = ckpt["scheduler_state"]
    start_epoch = ckpt["epoch"] + 1
    print(f"Resuming Training from epoch {ckpt['epoch']}")
    return ckpt, start_epoch


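# Entry point: build data pipelines, model, loss, optimizer, and Noam scheduler, then train with W&B logging.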
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    (
        train_dataloader,
        valid_dataloader,
        test_dataloader,
        tokenizer_src,
        tokenizer_tgt,
    ) = create_resources()
    src_vocab_size = tokenizer_src.get_vocab_size()
    tgt_vocab_size = tokenizer_tgt.get_vocab_size()

    with open("config.yaml", "r") as file:
        config = yaml.safe_load(file)

    run = wandb.init(
        entity="training-transformers-vast",
        project="AttentionTranslate-sai",
        config=config,
    )

    model = build_transformer(
        src_vocab_size,
        tgt_vocab_size,
        config["seq_len"],
        config["seq_len"],
        config["num_enc_dec_blocks"],
        config["num_of_heads"],
        config["d_model"],
    )

    model = model.to(device)

    wandb.watch(model, log="all")

    # Labels are target-tokenizer ids, so ignore the target tokenizer's [PAD] token.
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=tokenizer_tgt.token_to_id("[PAD]"), label_smoothing=0.1
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config["learning_rate"], betas=(0.9, 0.98), eps=1e-9
    )

    scheduler = NoamScheduler(optimizer, config["d_model"], config["warmup_steps"])

    start_epoch = 1

    if config["resume_training"]:
        ckpt, start_epoch = load_latest_checkpoint(
            model, optimizer, scheduler, config["model_directory"], device
        )

    if start_epoch == 1:
        print("Training from scratch.")

    trainer = Trainer(
        model,
        optimizer,
        scheduler,
        criterion,
        device,
        tokenizer_src,
        tokenizer_tgt,
        config["seq_len"],
    )

    trainer.run(
        train_dataloader, config["epochs"], config["model_directory"], start_epoch
    )

    run.finish()


if __name__ == "__main__":
    main()