| from torchinfo import summary | |
| from model import build_transformer | |
| from util import create_resources | |
| import yaml | |
| import torch | |
| from pathlib import Path | |
# Load run settings from the YAML file in the current working directory.
with open("config.yaml", "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# create_resources() is unpacked into three dataloaders and two tokenizers,
# in exactly this order.
(
    train_dataloader,
    valid_dataloader,
    test_dataloader,
    tokenizer_src,
    tokenizer_tgt,
) = create_resources()

src_vocab_size = tokenizer_src.get_vocab_size()
# Fix: the target vocabulary size must come from the target tokenizer. It was
# previously read from tokenizer_src, which sized the target side of the
# model from the wrong vocabulary.
tgt_vocab_size = tokenizer_tgt.get_vocab_size()

# config["seq_len"] is passed twice -- presumably the source and target
# sequence lengths; confirm against build_transformer's signature.
model = build_transformer(
    src_vocab_size,
    tgt_vocab_size,
    config["seq_len"],
    config["seq_len"],
    config["num_enc_dec_blocks"],
    config["num_of_heads"],
    config["d_model"],
)

batch_size = config["batch_size"]
# "epochs" is optional in the config; fall back to 10 when it is absent.
num_epochs = config.get("epochs", 10)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Both names (criterion and loss_fn) are kept bound to the same loss object
# because code outside this section may use either.
# NOTE(review): the ignored padding id is taken from the *source* tokenizer.
# If labels are padded with the target tokenizer's [PAD] id and the two ids
# differ, this should be tokenizer_tgt -- confirm against create_resources().
criterion = loss_fn = torch.nn.CrossEntropyLoss(
    ignore_index=tokenizer_src.token_to_id("[PAD]"),
    label_smoothing=0.1,
).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config["learning_rate"],
    eps=1e-9,
)
def save_checkpoint(epoch, model, optimizer, path):
    """Serialize a training checkpoint to ``path`` and report it on stdout.

    The saved object is a dict with three entries: ``"epoch"`` (the value
    passed in), ``"model_state_dict"`` and ``"optimizer_state_dict"`` (the
    results of calling ``state_dict()`` on ``model`` and ``optimizer``).

    Args:
        epoch: Epoch value stored under the ``"epoch"`` key.
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        path: Destination handed to ``torch.save``.

    Returns:
        None.
    """
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    snapshot = {
        "epoch": epoch,
        "model_state_dict": model_state,
        "optimizer_state_dict": optimizer_state,
    }
    torch.save(snapshot, path)
    print(f"Checkpoint saved at epoch {epoch} to {path}")
def load_checkpoint(path, model, optimizer=None, map_location="cpu"):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The file is expected to hold a dict with a ``"model_state_dict"`` entry
    and, optionally, ``"optimizer_state_dict"`` and ``"epoch"`` entries.

    Args:
        path: Checkpoint location handed to ``torch.load``.
        model: Object whose ``load_state_dict`` receives the saved model state.
        optimizer: Optional object whose ``load_state_dict`` receives the
            saved optimizer state. Skipped when ``None`` or when the
            checkpoint has no ``"optimizer_state_dict"`` entry.
        map_location: Forwarded to ``torch.load``; defaults to ``"cpu"``.

    Returns:
        The checkpoint's ``"epoch"`` value, or ``0`` when it is missing.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    # NOTE(security): depending on the installed torch version and its
    # defaults, torch.load may unpickle arbitrary objects -- only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])

    # Compare against None explicitly: relying on truthiness would silently
    # skip any optimizer-like object that happens to evaluate as false.
    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    start_epoch = checkpoint.get("epoch", 0)
    print(f"Loaded checkpoint from epoch {start_epoch}")
    return start_epoch
def train_one_epoch(device):
    """Switch the module-level ``model`` into training mode.

    NOTE(review): this function appears unfinished -- ``device`` is never
    used, and the loss accumulator below is initialised but neither updated
    nor returned.

    Args:
        device: Currently unused.

    Returns:
        None.
    """
    running_loss = 0.0  # accumulator placeholder; nothing reads it yet
    model.train()
| def train_model(model): | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| if (device == 'cuda'): | |
| print(f"Device name: {torch.cuda.get_device_name(device.index)}") | |
| print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB") | |
| Path(config["model_folder"]).mkdir(parents=True, exist_ok=True) | |
| train_dataloader,valid_dataloader,test_dataloader,tokenizer_src,tokenizer_tgt = create_resources() | |