File size: 1,378 Bytes
06c8a6d
 
 
 
 
 
 
 
f6c0b9e
06c8a6d
 
 
 
 
 
 
f6c0b9e
06c8a6d
 
 
 
 
 
 
 
 
 
 
 
 
f6c0b9e
06c8a6d
 
 
 
 
 
 
f6c0b9e
06c8a6d
 
 
c9e0c1d
06c8a6d
f6c0b9e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46


import torch
import torch.optim as optim
import torch.nn.functional as F
from src.models.model import ShapeClassifier

from src.configs.model_config import ModelConfig
from src.data.data_loader import train_loader, num_classes, val_loader
from src.utils.train import train
from src.utils.test import test
from src.utils.wandb import wandb
from src.utils.logs import logging
from src.utils.model import save_model


def main():
    """Train a ShapeClassifier for the number of epochs set in ModelConfig.

    Setup: read the config, pick CUDA when available (else CPU), build the
    model on that device, create an Adam optimizer with the configured
    learning rate, and push the config to W&B.

    Per epoch: print an epoch banner, run ``train`` on ``train_loader`` and
    ``test`` on ``val_loader`` (both with cross-entropy loss), log the value
    returned by ``train`` to W&B under ``"loss"``, and call ``save_model`` with
    the fixed path ``results/models/last.pth``.
    """
    cfg = ModelConfig().get_config()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # NOTE(review): only the model is moved to `device` here; presumably
    # train()/test() move each batch themselves — confirm in src.utils.
    model = ShapeClassifier(num_classes=num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate)

    # Record the hyper-parameters of this run in W&B.
    wandb.config.update(cfg)

    # Same loss function for the training and validation passes.
    loss_fn = F.cross_entropy

    for epoch_idx in range(cfg.epochs):
        epoch_no = epoch_idx + 1
        print(f"Epoch {epoch_no}\n-------------------------------")

        train_loss = train(
            train_loader,
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
        )
        # The return value of test() is not used here.
        test(val_loader, model=model, loss_fn=loss_fn)

        # Log the training loss so it can be plotted over epochs.
        wandb.log({"loss": train_loss})

        # The same path is used every epoch, so this is the latest model
        # (assuming save_model overwrites an existing file — TODO confirm).
        save_model(model, "results/models/last.pth")

        # NOTE(review): uploading the checkpoint as a W&B artifact
        # (wandb.log_artifact) is currently disabled.


if __name__ == "__main__":
    # Script entry point: emit a start message through the project's logging
    # module (src.utils.logs), then run the training loop.
    logging.info("Training model")
    main()