Image Classification
torch
DanielCruz09 commited on
Commit
da82593
·
1 Parent(s): 72ea5ef

Added data loader

Browse files
Files changed (1) hide show
  1. natural_disaster_dataset.py +107 -0
natural_disaster_dataset.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import os
3
+ from torch.utils.data import DataLoader, Dataset
4
+ import torch
5
+ from skimage import transform
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import torchvision.transforms as transforms
9
+ import torchvision.transforms.functional as TF
10
+ import streamlit as st
11
+
12
+ class NaturalDisasterDataset(Dataset):
13
+ """
14
+ A custom PyTorch Dataset that contains images of several types of natural disasters,
15
+ including earthquakes, fires, and floods.
16
+ """
17
+ def __init__(self, root:str, transform:any=None) -> None:
18
+ """
19
+ Creates a custom PyTorch dataset of natural disasters.
20
+
21
+ Args:
22
+ root (str): A path containing the images.
23
+ transform (any): A type of transformation from the scikit-image library.
24
+
25
+ Returns:
26
+ None
27
+ """
28
+ self.root = root
29
+ self.transform = transform
30
+
31
+ self.image_paths = []
32
+ self.labels = []
33
+
34
+ for label in os.listdir(root):
35
+ folder = os.path.join(root, label)
36
+ for file in os.listdir(folder):
37
+ self.image_paths.append(os.path.join(folder, file))
38
+ self.labels.append(label)
39
+
40
+ def __len__(self) -> int:
41
+ """
42
+ Returns the length/size of the dataset.
43
+
44
+ Args:
45
+ None
46
+
47
+ Returns:
48
+ length (int): The length of the dataset.
49
+ """
50
+ return len(self.image_paths)
51
+
52
+ def __getitem__(self, idx:int) -> dict:
53
+ """
54
+ Iterates through the dataset and returns a sample image.
55
+
56
+ Args:
57
+ idx (int): An index to the dataset.
58
+
59
+ Returns:
60
+ sample (dict): A dictionary containing the image and its label.
61
+ """
62
+ img_path = self.image_paths[idx]
63
+ label = self.labels[idx]
64
+ image = Image.open(img_path).convert("RGB")
65
+
66
+ if self.transform:
67
+ image = self.transform(image)
68
+
69
+ image = transforms.PILToTensor()(image)
70
+ sample = {"image": image, "category": label}
71
+ return sample
72
+
73
+
74
+ def load_sample(self) -> None:
75
+ """
76
+ Displays four sample images, one of each type of disaster.
77
+
78
+ Args:
79
+ None
80
+
81
+ Returns:
82
+ None
83
+ """
84
+
85
+ categories_needed = {"Normal", "Earthquake", "Fire", "Flood"}
86
+ shown = {}
87
+
88
+ fig = plt.figure(figsize=(10, 3))
89
+
90
+ for sample in self:
91
+ category = sample["category"]
92
+
93
+ # If we still need this category
94
+ if category in categories_needed and category not in shown:
95
+ shown[category] = sample["image"]
96
+
97
+ # Stop if we have all 4 categories
98
+ if len(shown) == len(categories_needed):
99
+ break
100
+
101
+ for i, (category, image) in enumerate(shown.items()):
102
+ ax = plt.subplot(1, 4, i + 1)
103
+ ax.imshow(image)
104
+ ax.set_title(category)
105
+ ax.axis("off")
106
+
107
+ plt.tight_layout()