File size: 3,828 Bytes
0c7049d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111

import torch
from torch import nn
from tqdm.auto import tqdm

def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device):
  """Run one training epoch of a single-output binary classifier.

  For every batch the labels are reshaped to ``(batch, 1)`` floats, the loss is
  computed against the raw model output, and one optimizer step is taken.
  Accuracy is measured by rounding ``sigmoid(output)`` to hard 0/1 predictions.

  Args:
    model: Network producing one output (logit) per sample; put in train mode.
    dataloader: Yields ``(X, y)`` batches. ``y`` is presumably a 1-D tensor of
      0/1 labels (it is unsqueezed to ``(batch, 1)``) -- confirm with caller.
    loss_fn: Loss applied as ``loss_fn(output, target)``; assumed to accept raw
      logits (e.g. ``nn.BCEWithLogitsLoss``) -- confirm with caller.
    optimizer: Optimizer that updates ``model``'s parameters.
    device: Device each batch is moved to.

  Returns:
    ``(train_loss, train_acc)``: per-batch loss and accuracy averaged over the
    number of batches.

  Raises:
    ZeroDivisionError: If ``dataloader`` yields no batches.
  """
  model.train()
  train_loss, train_acc = 0.0, 0.0
  for X, y in dataloader:
    X, y = X.to(device), y.to(device)
    y_pred = model(X)
    # Match the (batch, 1) output shape and the float dtype the loss expects.
    y = y.unsqueeze(dim=1).float()
    loss = loss_fn(y_pred, y)
    train_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # sigmoid gives probabilities in (0, 1) that practically never equal the
    # 0/1 labels exactly; round to hard predictions before comparing. No
    # gradient is needed for this metric.
    with torch.no_grad():
      y_pred_class = torch.round(torch.sigmoid(y_pred))
    train_acc += (y_pred_class == y).sum().item() / len(y_pred)

  train_loss = train_loss / len(dataloader)
  train_acc = train_acc / len(dataloader)
  return train_loss, train_acc

def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device: torch.device):
  model.eval()
  test_loss, test_acc = 0, 0

  with torch.inference_mode():
    for batch, (X, y) in enumerate(dataloader):
      X, y = X.to(device), y.to(device)
      y_pred = model(X)
      y = y.unsqueeze(dim = 1).float()
      loss = loss_fn(y_pred, y)
      test_loss = test_loss + loss.item()

      y_pred_class = y_pred.sigmoid()
      acc = (y_pred_class == y).sum().item() / len(y_pred)
      test_acc = test_acc + acc
  test_loss = test_loss / len(dataloader)
  test_acc = test_acc / len(dataloader)
  return test_loss, test_acc
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device,
          writer: "torch.utils.tensorboard.SummaryWriter",
          graph_input_shape: tuple = (32, 3, 224, 224)):
  """Train ``model`` for ``epochs`` epochs, evaluating and logging after each.

  Each epoch runs ``train_step`` followed by ``test_step``, prints a one-line
  summary, appends the metrics to the returned history, and writes them to
  TensorBoard via ``writer``. The model graph is written once, before
  training, and ``writer`` is closed when this function exits (also on error).

  Args:
    model: Model to train; moved to ``device`` first.
    train_dataloader: Batches used for the training step.
    test_dataloader: Batches used for the evaluation step.
    optimizer: Optimizer updating ``model``'s parameters.
    loss_fn: Loss passed through to ``train_step``/``test_step``.
    epochs: Number of epochs to run.
    device: Device for the model and data.
    writer: TensorBoard ``SummaryWriter``. The annotation is a string so this
      module can be imported without ``torch.utils.tensorboard`` having been
      imported first.
    graph_input_shape: Shape of the random example tensor used to trace the
      model graph. The default matches the previously hard-coded
      ``(32, 3, 224, 224)``.

  Returns:
    Dict with keys ``"train_loss"``, ``"train_acc"``, ``"test_loss"`` and
    ``"test_acc"``, each a list holding one value per epoch.
  """
  results = {"train_loss": [],
             "train_acc": [],
             "test_loss": [],
             "test_acc": []}
  model.to(device)
  try:
    # The architecture does not change between epochs, so trace it only once
    # instead of re-tracing the model on every epoch.
    writer.add_graph(model=model,
                     input_to_model=torch.randn(graph_input_shape).to(device))

    for epoch in tqdm(range(epochs)):
      train_loss, train_acc = train_step(model=model,
                                         dataloader=train_dataloader,
                                         loss_fn=loss_fn,
                                         optimizer=optimizer,
                                         device=device)
      test_loss, test_acc = test_step(model=model,
                                      dataloader=test_dataloader,
                                      loss_fn=loss_fn,
                                      device=device)

      print(
          f"| Epoch: {epoch+1} | "
          f"train_loss: {train_loss:.4f} | "
          f"train_acc: {train_acc:.4f} | "
          f"test_loss: {test_loss:.4f} | "
          f"test_acc: {test_acc:.4f} |"
      )

      results['train_loss'].append(train_loss)
      results['train_acc'].append(train_acc)
      results['test_loss'].append(test_loss)
      results['test_acc'].append(test_acc)

      writer.add_scalars(main_tag="Loss",
                         tag_scalar_dict={"train_loss": train_loss,
                                          "test_loss": test_loss},
                         global_step=epoch)
      writer.add_scalars(main_tag="Accuracy",
                         tag_scalar_dict={"train_acc": train_acc,
                                          "test_acc": test_acc},
                         global_step=epoch)
  finally:
    # Close once, after all epochs (previously it was closed inside the loop,
    # forcing the writer to reopen its event files every epoch).
    writer.close()
  return results