# Scraped repository-page header preserved as a comment:
# Urfavghost — "Added models and code" (commit 39a7504)
from torchinfo import summary
from model import build_transformer
from util import create_resources
import yaml
import torch
from pathlib import Path
# Load hyper-parameters and build the data pipeline + tokenizers.
with open("config.yaml", "r") as file:
    config = yaml.safe_load(file)

train_dataloader, valid_dataloader, test_dataloader, tokenizer_src, tokenizer_tgt = create_resources()

src_vocab_size = tokenizer_src.get_vocab_size()
# Bug fix: the target vocabulary size must come from the *target* tokenizer —
# the original read tokenizer_src twice, silently mis-sizing the decoder
# embedding/projection whenever the two vocabularies differ.
tgt_vocab_size = tokenizer_tgt.get_vocab_size()
# Build the encoder–decoder transformer sized to the two vocabularies.
# Both encoder and decoder use the same configured sequence length.
model = build_transformer(
    src_vocab_size,
    tgt_vocab_size,
    config["seq_len"],
    config["seq_len"],
    config["num_enc_dec_blocks"],
    config["num_of_heads"],
    config["d_model"],
)

batch_size = config["batch_size"]
num_epochs = config.get("epochs", 10)  # default to 10 epochs when unset

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Bug fix: the loss is computed over *target*-side tokens, so padding must be
# masked with the target tokenizer's [PAD] id (the original used tokenizer_src;
# the two ids often coincide, but only by accident of tokenizer construction).
# Both names are kept bound because later code may refer to either alias.
criterion = loss_fn = torch.nn.CrossEntropyLoss(
    ignore_index=tokenizer_tgt.token_to_id('[PAD]'),
    label_smoothing=0.1,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"], eps=1e-9)
def save_checkpoint(epoch, model, optimizer, path):
    """Write a resumable training checkpoint to *path*.

    The checkpoint records the epoch number plus the model's and
    optimizer's state dicts under fixed keys that load_checkpoint expects.
    """
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }
    torch.save(state, path)
    print(f"Checkpoint saved at epoch {epoch} to {path}")
def load_checkpoint(path, model, optimizer=None, map_location="cpu"):
    """Restore model (and, when given, optimizer) state from a checkpoint.

    Returns the epoch number stored in the checkpoint, defaulting to 0
    when the file carries no "epoch" entry.
    """
    state = torch.load(path, map_location=map_location)
    model.load_state_dict(state["model_state_dict"])
    # Optimizer restore is best-effort: only when an optimizer was passed
    # and the checkpoint actually contains its state.
    if optimizer and "optimizer_state_dict" in state:
        optimizer.load_state_dict(state["optimizer_state_dict"])
    start_epoch = state.get("epoch", 0)
    print(f"Loaded checkpoint from epoch {start_epoch}")
    return start_epoch
def train_one_epoch(device):
    # NOTE(review): this function looks truncated — it only puts the
    # module-level `model` into training mode and initialises a loss
    # accumulator that is never updated, used, or returned. The per-batch
    # forward/backward loop appears to be missing; confirm against the
    # original repository before relying on it. The `device` parameter is
    # also unused in the visible body.
    model.train()
    running_loss = 0.0
def train_model(model):
device = "cuda" if torch.cuda.is_available() else "cpu"
if (device == 'cuda'):
print(f"Device name: {torch.cuda.get_device_name(device.index)}")
print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB")
Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)
train_dataloader,valid_dataloader,test_dataloader,tokenizer_src,tokenizer_tgt = create_resources()