File size: 2,382 Bytes
39a7504
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from torchinfo import summary
from model import build_transformer
from util import create_resources
import yaml
import torch
from pathlib import Path





# Read the run configuration from config.yaml in the current working directory.
# The encoding is pinned so parsing does not depend on the platform default.
with open("config.yaml", "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)


# Data loaders and the two tokenizers are produced by the project helper.
(
    train_dataloader,
    valid_dataloader,
    test_dataloader,
    tokenizer_src,
    tokenizer_tgt,
) = create_resources()
src_vocab_size = tokenizer_src.get_vocab_size()
# The target vocabulary size has to come from the target tokenizer; reading it
# from the source tokenizer mis-sizes the model whenever the vocabularies differ.
tgt_vocab_size = tokenizer_tgt.get_vocab_size()

# NOTE(review): config["seq_len"] is passed twice - presumably the source and
# target sequence lengths; confirm against model.build_transformer.
model = build_transformer(
    src_vocab_size,
    tgt_vocab_size,
    config["seq_len"],
    config["seq_len"],
    config["num_enc_dec_blocks"],
    config["num_of_heads"],
    config["d_model"],
)

batch_size = config["batch_size"]
# "epochs" is optional in the config; fall back to 10 when it is absent.
num_epochs = config.get("epochs", 10)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# `criterion` and `loss_fn` are two names for the same loss object.
# NOTE(review): the ignored '[PAD]' id is looked up in the *source* tokenizer.
# If the labels are target-side tokens this only works when both tokenizers
# assign '[PAD]' the same id - confirm against util.create_resources.
criterion = loss_fn = torch.nn.CrossEntropyLoss(
    ignore_index=tokenizer_src.token_to_id('[PAD]'),
    label_smoothing=0.1,
).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config["learning_rate"],
    eps=1e-9,
)


def save_checkpoint(epoch, model, optimizer, path):
    """Write a training checkpoint to ``path`` using ``torch.save``.

    The saved object is a dict with three entries: ``"epoch"``,
    ``"model_state_dict"`` and ``"optimizer_state_dict"``.

    Args:
        epoch: Epoch value stored in the checkpoint and echoed in the log line.
        model: Object whose ``state_dict()`` is stored.
        optimizer: Object whose ``state_dict()`` is stored.
        path: Destination handed to ``torch.save``.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(state, path)
    print(f"Checkpoint saved at epoch {epoch} to {path}")


def load_checkpoint(path, model, optimizer=None, map_location="cpu"):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        path: Checkpoint location handed to ``torch.load``.
        model: Object that receives ``checkpoint["model_state_dict"]`` through
            ``load_state_dict``.
        optimizer: Optional optimizer. Its state is restored only when it is
            given and the checkpoint has an ``"optimizer_state_dict"`` entry.
        map_location: Forwarded to ``torch.load``; defaults to ``"cpu"``.

    Returns:
        The ``"epoch"`` value stored in the checkpoint, or ``0`` when the
        checkpoint has no such entry.

    Raises:
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    # SECURITY NOTE: depending on the installed torch version, torch.load may
    # unpickle arbitrary objects. Only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])

    # Test against None explicitly rather than relying on object truthiness.
    if optimizer is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    start_epoch = checkpoint.get("epoch", 0)
    print(f"Loaded checkpoint from epoch {start_epoch}")
    return start_epoch


def train_one_epoch(device) -> None:
    """Put the module-level ``model`` into training mode.

    NOTE(review): this looks like an unfinished stub. ``device`` is accepted
    but not used, and ``running_loss`` is initialised but never updated or
    returned, so the function currently returns ``None``.
    """
    # Operates on the global `model`, not on an argument.
    model.train()
    # Accumulator with no consumer yet - TODO: add the batch loop that uses it.
    running_loss = 0.0




def train_model(model):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if (device == 'cuda'):
        print(f"Device name: {torch.cuda.get_device_name(device.index)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB")
    
    Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)
    train_dataloader,valid_dataloader,test_dataloader,tokenizer_src,tokenizer_tgt = create_resources()