File size: 1,998 Bytes
c9e0c1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.optim as optim
import torch.nn.functional as F
from .models.model import ShapeClassifier

from src.configs.model_config import ModelConfig
from src.data.data_loader import train_loader, num_classes


def train():
    config = ModelConfig().get_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ShapeClassifier(num_classes=num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    log_interval = 20
    for epoch in range(config.epochs):
        model.train()
        running_loss = 0.0

        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if batch_idx % log_interval == 0:
                current_loss = running_loss / log_interval
                print(
                    f"Epoch [{epoch + 1}/{config.epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {current_loss:.4f}")
                running_loss = 0.0

                # calculate the accuracy on the test set

                with torch.no_grad():
                    model.eval()
                    correct = 0
                    total = 0
                    for inputs, labels in train_loader:
                        inputs, labels = inputs.to(device), labels.to(device)
                        outputs = model(inputs)
                        predicted = torch.argmax(outputs.data, 1)
                        total += labels.size(0)
                        correct += (predicted == labels).sum().item()
                    print(f"Accuracy of the model on the test images: {100 * correct / total} %")
                # save the model
                torch.save(model.state_dict(), "model.pth")