Image Classification
torch
DanielCruz09 commited on
Commit
fe5ea14
·
1 Parent(s): da82593

Added preprocessing script

Browse files
models/model_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cbeb72bb5c303d0cf32c7f4003990b4de68460b342d0d1cecb3e147c1f1ebd6
3
+ size 94380387
models/natural_disaster_dataset.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import os
3
+ from torch.utils.data import DataLoader, Dataset
4
+ import torch
5
+ from skimage import transform
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import torchvision.transforms as transforms
9
+ import torchvision.transforms.functional as TF
10
+ import streamlit as st
11
+
12
+ class NaturalDisasterDataset(Dataset):
13
+ """
14
+ A custom PyTorch Dataset that contains images of several types of natural disasters,
15
+ including earthquakes, fires, and floods.
16
+ """
17
+ def __init__(self, root:str, transform:any=None) -> None:
18
+ """
19
+ Creates a custom PyTorch dataset of natural disasters.
20
+
21
+ Args:
22
+ root (str): A path containing the images.
23
+ transform (any): A type of transformation from the scikit-image library.
24
+
25
+ Returns:
26
+ None
27
+ """
28
+ self.root = root
29
+ self.transform = transform
30
+
31
+ self.image_paths = []
32
+ self.labels = []
33
+
34
+ for label in os.listdir(root):
35
+ folder = os.path.join(root, label)
36
+ for file in os.listdir(folder):
37
+ self.image_paths.append(os.path.join(folder, file))
38
+ self.labels.append(label)
39
+
40
+ def __len__(self) -> int:
41
+ """
42
+ Returns the length/size of the dataset.
43
+
44
+ Args:
45
+ None
46
+
47
+ Returns:
48
+ length (int): The length of the dataset.
49
+ """
50
+ return len(self.image_paths)
51
+
52
+ def __getitem__(self, idx:int) -> dict:
53
+ """
54
+ Iterates through the dataset and returns a sample image.
55
+
56
+ Args:
57
+ idx (int): An index to the dataset.
58
+
59
+ Returns:
60
+ sample (dict): A dictionary containing the image and its label.
61
+ """
62
+ img_path = self.image_paths[idx]
63
+ label = self.labels[idx]
64
+ image = Image.open(img_path).convert("RGB")
65
+
66
+ if self.transform:
67
+ image = self.transform(image)
68
+
69
+ image = transforms.PILToTensor()(image)
70
+ sample = {"image": image, "category": label}
71
+ return sample
72
+
73
+
74
+ def load_sample(self) -> None:
75
+ """
76
+ Displays four sample images, one of each type of disaster.
77
+
78
+ Args:
79
+ None
80
+
81
+ Returns:
82
+ None
83
+ """
84
+
85
+ categories_needed = {"Normal", "Earthquake", "Fire", "Flood"}
86
+ shown = {}
87
+
88
+ fig = plt.figure(figsize=(10, 3))
89
+
90
+ for sample in self:
91
+ category = sample["category"]
92
+
93
+ # If we still need this category
94
+ if category in categories_needed and category not in shown:
95
+ shown[category] = sample["image"]
96
+
97
+ # Stop if we have all 4 categories
98
+ if len(shown) == len(categories_needed):
99
+ break
100
+
101
+ for i, (category, image) in enumerate(shown.items()):
102
+ ax = plt.subplot(1, 4, i + 1)
103
+ ax.imshow(image)
104
+ ax.set_title(category)
105
+ ax.axis("off")
106
+
107
+ plt.tight_layout()
models/resnet50.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torchvision.models as models
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torchvision.transforms as transforms
8
+ import os
9
+ import pandas as pd
10
+
11
+
12
+ def create_indices(labels):
13
+ mapping = {
14
+ "Non_Damage": 0,
15
+ "Land_Disaster": 1,
16
+ "Fire_Disaster": 2,
17
+ "Water_Disaster": 3
18
+ }
19
+
20
+ indices = list(mapping[category] for category in labels)
21
+ return indices
22
+
23
+ def write_to_csv(predicted, actual, probs, write_path, header):
24
+
25
+ label_names = ["Non-Damage", "Earthquake", "Fire", "Flood"]
26
+
27
+ if header:
28
+ with open(write_path, "w") as file:
29
+ file.write("Predicted,True,Non_Damage_Score,Earthquake_Score,Fire_Score,Flood_Score\n")
30
+
31
+ with open(write_path, "a") as file:
32
+ for i in range(len(actual)):
33
+ file.write(
34
+ f"{label_names[actual[i].item()]},"
35
+ f"{label_names[predicted[i].item()]},"
36
+ f"{probs[i, 0].item()},"
37
+ f"{probs[i, 1].item()},"
38
+ f"{probs[i, 2].item()},"
39
+ f"{probs[i, 3].item()}\n"
40
+ )
41
+
42
+
43
+ class ResNet50():
44
+
45
+ def __init__(self, num_classes, lr=0.01, momentum=0.9):
46
+ self.model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
47
+ self.num_classes = num_classes
48
+ self.lr = lr
49
+ self.momentum = momentum
50
+ self.num_features = self.model.fc.in_features
51
+ self.model.fc = nn.Linear(self.num_features, self.num_classes)
52
+
53
+ self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
54
+ self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=self.momentum)
55
+
56
+ def train(self, epochs, train_loader):
57
+ loss_over_time = []
58
+ num_epochs = list(range(1, epochs + 1))
59
+ for epoch in range(epochs):
60
+ self.model.train()
61
+ current_loss = 0.0
62
+ for i, data in enumerate(train_loader, 0):
63
+ inputs, labels = data
64
+ self.optimizer.zero_grad()
65
+ outputs = self.model(data[inputs].float())
66
+ indices = create_indices(data[labels])
67
+ target = torch.tensor(indices)
68
+ loss = self.criterion(outputs, target)
69
+ loss.backward()
70
+ self.optimizer.step()
71
+ current_loss += loss.item()
72
+ loss_over_time.append(current_loss / len(train_loader))
73
+ print(f"Epoch: {epoch + 1} \t Loss: {current_loss / len(train_loader)}")
74
+
75
+ torch.save({
76
+ "model_state_dict": self.model.state_dict(),
77
+ "optimizer_state_dict": self.optimizer.state_dict(),
78
+ "epochs": num_epochs,
79
+ "loss": loss_over_time
80
+ }, "model_weights.pth")
81
+
82
+ data = {
83
+ "Epochs": num_epochs,
84
+ "Loss": loss_over_time
85
+ }
86
+ data = pd.DataFrame(data=data)
87
+ data.to_csv("results/model_progress.csv", index=False)
88
+
89
+ def eval(self, test_loader, write_path=None):
90
+ self.model.eval()
91
+ header = True
92
+
93
+ with torch.no_grad():
94
+ correct = 0
95
+ total = 0
96
+ for data in test_loader:
97
+ images, labels = data
98
+ images = data[images].float()
99
+ labels = data[labels]
100
+ indices = create_indices(labels)
101
+ labels = torch.tensor(indices)
102
+
103
+ outputs = self.model(images)
104
+ _, predicted = torch.max(outputs.data, 1)
105
+ probs = torch.softmax(outputs, dim=1)
106
+
107
+ total += len(labels)
108
+ correct += (predicted == labels).sum().item()
109
+ if write_path:
110
+ write_to_csv(predicted, labels, probs, write_path=write_path, header=header)
111
+ header = False
112
+
113
+ print(f'Accuracy of the network on the test images: {round(100 * correct / total, 3)}%')
models/save_model_progress.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from resnet50 import ResNet50
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
+
6
+ n_categories = 4
7
+ model = ResNet50(n_categories)
8
+ weights = torch.load("model_weights.pth")
9
+ model.model.load_state_dict(weights["model_state_dict"])
10
+ epochs = weights["epochs"]
11
+ loss = weights["loss"]
12
+
13
+ data = {
14
+ "Epochs": epochs,
15
+ "Loss": loss
16
+ }
17
+
18
+ data = pd.DataFrame(data)
19
+ data.to_csv("model_progress.csv", index=False)
models/upload_model.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from resnet50 import ResNet50
3
+ from huggingface_hub import HfApi
4
+
5
+ api = HfApi()
6
+
7
+ model = ResNet50(num_classes=4)
8
+ torch.save(model.model.state_dict(), "model_weights.pth")
9
+
10
+ api.upload_file(
11
+ path_or_fileobj="model_weights.pth",
12
+ path_in_repo="model_weights.pth",
13
+ repo_id="DanielCruz09/disaster-image-classifier",
14
+ repo_type="model"
15
+ )