Image Classification
torch
File size: 3,679 Bytes
fe5ea14
 
 
 
 
 
 
 
 
 
e6bc347
 
fe5ea14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6bc347
fe5ea14
 
 
 
 
 
 
 
 
 
e6bc347
 
fe5ea14
 
 
 
 
 
 
 
 
 
e6bc347
fe5ea14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6bc347
 
fe5ea14
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import os
import pandas as pd

def create_indices(labels, mapping):
    """Translate every label through ``mapping``, preserving input order.

    Args:
        labels: Iterable of raw labels used as lookup keys.
        mapping: Any subscriptable lookup; each label is resolved with
            ``mapping[label]``.

    Returns:
        A list holding ``mapping[label]`` for each label in ``labels``.

    Raises:
        Whatever ``mapping[label]`` raises for an unknown label
        (``KeyError`` for a plain dict).
    """
    lookup = mapping.__getitem__
    return list(map(lookup, labels))

def write_to_csv(predicted, actual, probs, write_path, header):
    """Write per-sample predictions and class scores as CSV rows.

    Each row has the form ``<predicted label>,<true label>,<score 0>,...,<score 3>``,
    which matches the header row ``Predicted,True,...``.

    Args:
        predicted: Sequence of predicted class indices. Each element must
            support ``.item()``, e.g. a 1-D integer tensor.
        actual: Sequence of ground-truth class indices, in the same form as
            ``predicted``.
        probs: 2-D score container indexed as ``row[k]`` per sample. Column
            ``k`` is the score for ``label_names[k]``.
        write_path: Path of the CSV file.
        header: If true, truncate the file and write the header row first.
            Otherwise, append rows to the existing file.
    """
    label_names = ["Non-Damage", "Earthquake", "Fire", "Flood"]

    # "w" starts a fresh file (header + rows); "a" appends a further batch.
    mode = "w" if header else "a"
    with open(write_path, mode) as file:
        if header:
            file.write("Predicted,True,Non_Damage_Score,Earthquake_Score,Fire_Score,Flood_Score\n")
        for pred, true, row in zip(predicted, actual, probs):
            scores = ",".join(f"{row[k].item()}" for k in range(len(label_names)))
            # Column order must match the header: predicted label first, then the true label.
            file.write(
                f"{label_names[pred.item()]},"
                f"{label_names[true.item()]},"
                f"{scores}\n"
            )


class ResNet50:
    """Fine-tunable wrapper around torchvision's pretrained ResNet-50.

    The final fully-connected layer is replaced by a fresh ``nn.Linear``
    sized for ``num_classes``. Training uses SGD with momentum and
    cross-entropy loss with label smoothing 0.1.

    Batches from the data loaders are expected to be mappings with exactly
    two keys: the first key holds the images and the second holds the
    labels. Each batch is unpacked into its keys. Labels are converted to
    class indices through ``mapping``.
    """

    def __init__(self, num_classes, lr=0.01, momentum=0.9, mapping=None):
        """Build the model, loss and optimizer.

        Args:
            num_classes: Number of outputs of the replacement ``fc`` layer.
            lr: SGD learning rate.
            momentum: SGD momentum.
            mapping: Lookup from raw label to class index, used as
                ``mapping[label]`` by ``train`` and ``eval``. Those methods
                cannot run with the default ``None``.
        """
        self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        self.num_classes = num_classes
        self.lr = lr
        self.momentum = momentum
        # Swap the pretrained classification head for one with num_classes outputs.
        self.num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(self.num_features, self.num_classes)

        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum)

        self.mapping = mapping

    def train(self, epochs, train_loader, save_path="model_weights.pth"):
        """Train for ``epochs`` passes over ``train_loader`` and save a checkpoint.

        Args:
            epochs: Number of passes over ``train_loader``.
            train_loader: Iterable of mapping batches (image key first,
                label key second).
            save_path: Destination of the ``torch.save`` checkpoint. Defaults
                to ``"model_weights.pth"``, the previously hard-coded path.

        Returns:
            List of the mean training loss for each epoch.

        Raises:
            ValueError: If ``train_loader`` yields no batches in an epoch.

        The checkpoint stores the model and optimizer state dicts, the list
        of 1-based epoch numbers, and the per-epoch mean loss.
        """
        loss_over_time = []
        for epoch in range(1, epochs + 1):
            self.model.train()
            running_loss = 0.0
            num_batches = 0
            for batch in train_loader:
                # Unpacking a mapping yields its keys: (image_key, label_key).
                image_key, label_key = batch
                # NOTE(review): images are only cast to float here; any
                # normalization is assumed to happen in the dataset — confirm.
                inputs = batch[image_key].float()
                target = torch.tensor(create_indices(batch[label_key], self.mapping))

                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, target)
                loss.backward()
                self.optimizer.step()

                running_loss += loss.item()
                num_batches += 1

            if num_batches == 0:
                raise ValueError("train_loader yielded no batches")
            avg_loss = running_loss / num_batches
            loss_over_time.append(avg_loss)
            print(f"Epoch: {epoch} \t Loss: {avg_loss}")

        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "epochs": list(range(1, epochs + 1)),
            "loss": loss_over_time
        }, save_path)
        return loss_over_time

    def eval(self, test_loader, write_path=None):
        """Report accuracy on ``test_loader`` and optionally dump predictions.

        Args:
            test_loader: Iterable of mapping batches (image key first,
                label key second).
            write_path: If given, per-sample predicted label, true label and
                softmax scores are written there via ``write_to_csv``. The
                file is truncated on the first batch and appended to on the
                following batches.

        Returns:
            Accuracy as a percentage (0-100), or ``None`` if the loader
            yielded no samples.
        """
        self.model.eval()
        correct = 0
        total = 0
        write_header = True

        with torch.no_grad():
            for batch in test_loader:
                # Unpacking a mapping yields its keys: (image_key, label_key).
                image_key, label_key = batch
                images = batch[image_key].float()
                labels = torch.tensor(create_indices(batch[label_key], self.mapping))

                outputs = self.model(images)
                predicted = outputs.argmax(dim=1)
                probs = torch.softmax(outputs, dim=1)

                total += len(labels)
                correct += (predicted == labels).sum().item()
                if write_path:
                    write_to_csv(predicted, labels, probs, write_path=write_path, header=write_header)
                    write_header = False

        # Guard against an empty loader instead of dividing by zero.
        if total == 0:
            print("Accuracy of the network on the test images: n/a (no samples)")
            return None

        accuracy = 100 * correct / total
        print(f'Accuracy of the network on the test images: {round(accuracy, 3)}%')
        return accuracy