Image Classification
torch
DanielCruz09 commited on
Commit
d8a0b9e
·
verified ·
1 Parent(s): 495cf68

Delete natural_disaster_dataset.py

Browse files
Files changed (1) hide show
  1. natural_disaster_dataset.py +0 -107
natural_disaster_dataset.py DELETED
@@ -1,107 +0,0 @@
1
- from PIL import Image
2
- import os
3
- from torch.utils.data import DataLoader, Dataset
4
- import torch
5
- from skimage import transform
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
- import torchvision.transforms as transforms
9
- import torchvision.transforms.functional as TF
10
- import streamlit as st
11
-
12
- class NaturalDisasterDataset(Dataset):
13
- """
14
- A custom PyTorch Dataset that contains images of several types of natural disasters,
15
- including earthquakes, fires, and floods.
16
- """
17
- def __init__(self, root:str, transform:any=None) -> None:
18
- """
19
- Creates a custom PyTorch dataset of natural disasters.
20
-
21
- Args:
22
- root (str): A path containing the images.
23
- transform (any): A type of transformation from the scikit-image library.
24
-
25
- Returns:
26
- None
27
- """
28
- self.root = root
29
- self.transform = transform
30
-
31
- self.image_paths = []
32
- self.labels = []
33
-
34
- for label in os.listdir(root):
35
- folder = os.path.join(root, label)
36
- for file in os.listdir(folder):
37
- self.image_paths.append(os.path.join(folder, file))
38
- self.labels.append(label)
39
-
40
- def __len__(self) -> int:
41
- """
42
- Returns the length/size of the dataset.
43
-
44
- Args:
45
- None
46
-
47
- Returns:
48
- length (int): The length of the dataset.
49
- """
50
- return len(self.image_paths)
51
-
52
- def __getitem__(self, idx:int) -> dict:
53
- """
54
- Iterates through the dataset and returns a sample image.
55
-
56
- Args:
57
- idx (int): An index to the dataset.
58
-
59
- Returns:
60
- sample (dict): A dictionary containing the image and its label.
61
- """
62
- img_path = self.image_paths[idx]
63
- label = self.labels[idx]
64
- image = Image.open(img_path).convert("RGB")
65
-
66
- if self.transform:
67
- image = self.transform(image)
68
-
69
- image = transforms.PILToTensor()(image)
70
- sample = {"image": image, "category": label}
71
- return sample
72
-
73
-
74
- def load_sample(self) -> None:
75
- """
76
- Displays four sample images, one of each type of disaster.
77
-
78
- Args:
79
- None
80
-
81
- Returns:
82
- None
83
- """
84
-
85
- categories_needed = {"Normal", "Earthquake", "Fire", "Flood"}
86
- shown = {}
87
-
88
- fig = plt.figure(figsize=(10, 3))
89
-
90
- for sample in self:
91
- category = sample["category"]
92
-
93
- # If we still need this category
94
- if category in categories_needed and category not in shown:
95
- shown[category] = sample["image"]
96
-
97
- # Stop if we have all 4 categories
98
- if len(shown) == len(categories_needed):
99
- break
100
-
101
- for i, (category, image) in enumerate(shown.items()):
102
- ax = plt.subplot(1, 4, i + 1)
103
- ax.imshow(image)
104
- ax.set_title(category)
105
- ax.axis("off")
106
-
107
- plt.tight_layout()