File size: 3,919 Bytes
b39a019
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# This is the unquantised model, kept for comparison.

import ClassUtils
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset
import random
from torchvision import models, transforms
from torch.utils.data import DataLoader
import time

import matplotlib.pyplot as plt
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# These define the model that will be trained
num_classes = 2
batch_size = 256
epochs = 25
learning_rate = 5e-4
train_data_size = 25000
# Kept tiny on purpose: test_model() can display every test image one by one.
test_data_size = 12
saved_state_dict_path = "MobileNetV3_test.pth"
# Fraction of dataset indices (from the start) forming the training pool; the rest is the test pool.
train_split = 0.95

# Pretrained MobileNetV3-Small (torchvision DEFAULT weights) with classifier[3] replaced by a
# fresh linear layer producing `num_classes` outputs.
model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
model = model.to(device)

dataset = ClassUtils.CrosswalkDataset("zebra_annotations/classification_data")

# Split by index so the random train and test subsets can never overlap. Sample sizes are
# clamped to the pool sizes: random.sample raises ValueError when asked for more items than
# the population holds (e.g. a dataset smaller than train_data_size / train_split).
split_index = int(len(dataset) * train_split)
test_pool = range(split_index, len(dataset))
train_indices = random.sample(range(split_index), min(train_data_size, split_index))
test_indices = random.sample(test_pool, min(test_data_size, len(test_pool)))

train_loader = DataLoader(Subset(dataset, train_indices), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(Subset(dataset, test_indices), batch_size=batch_size, shuffle=False)

# The training and testing code apply torch.sigmoid to the model outputs themselves, so the
# plain BCELoss (which expects probabilities, not logits) is the matching criterion.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


# Generalised training function that uses a training-testing split defined in the training variables.
# Works best with transfer learning as is shown in the 'MobileNetV3.py' function where it is defined - check it out.
def train_model():
    model.train()
    start_time = time.time()
    for epoch in range(epochs):
        to_do = train_data_size
        running_loss = 0.0
        for inputs, labels in train_loader:
            try:
                inputs, labels = inputs.to(device), labels.to(device)
            except:
                continue
            
            optimizer.zero_grad()
            outputs = torch.sigmoid(model(inputs))
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            to_do -= batch_size

        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}, time {time.time()- start_time}")
        start_time = time.time()

# Do not use to actually evaluate performance, this is for quick checks - 'EvaluatePerformance.py' has actual quantified evaluation tools.
def test_model():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            try:
                inputs, labels = inputs.to(device), labels.to(device)
            except:
                continue
            outputs = torch.sigmoid(model(inputs))

            predicted = (outputs/100) > 0.5
            for i in range(len(inputs)):
                    plt.close()
                    plt.imshow(torch.permute(inputs[i], (1, 2, 0)).cpu().detach().numpy())
                    plt.title(f"prediction of {outputs[i].tolist()[0]:.3f}%, {100 * predicted[i].tolist()[0]:.3f}%,\nactual: {labels[i].tolist()}")
                    plt.axis("off")
                    plt.show()
            
            total += labels.size(0)
            # print(predicted, labels)

            for prediction, label in zip(predicted, labels):
                correct += ((prediction[0]>50) == label[0])
    
    print(f"Accuracy: {100 * correct / total}%")



train = True
if __name__ == "__main__":
    if train:
        train_model()
        torch.save(model.state_dict(), "mn3_vs55.pth")
    else:
        state_dictionairy = torch.load(saved_state_dict_path, weights_only=True)
        print(type(state_dictionairy))
        model.load_state_dict(state_dictionairy)

    test_model()

else:
    state_dictionairy = torch.load(saved_state_dict_path, weights_only=True)
    model.load_state_dict(state_dictionairy)
    print(f"Module: [{__name__}] has been loaded")