import os
from PIL import Image
from torch.utils.data import Dataset


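class ClassificationDataset(Dataset):
    """Image classification dataset read from one sub-folder per class.

    Expects ``root_dir/<class_name>/<image file>``. Each class folder is
    split on its own, in sorted filename order, so the same files always
    land in the same split.

    Args:
        root_dir: Directory containing one sub-directory per class.
        class_to_idx: Mapping from class folder name to integer label.
        split: "train", "val", or "test".
        transform: Optional callable applied to each PIL image.
        split_ratio: (train, val, test) fractions applied per class.
    """

    def __init__(
        self,
        root_dir,
        class_to_idx,
        split="train",
        transform=None,
        split_ratio=(0.7, 0.15, 0.15),
    ):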
        self.transform = transform
        self.samples = []

        for class_name in sorted(os.listdir(root_dir)):
            class_path = os.path.join(
                root_dir,
                class_name
            )

            if not os.path.isdir(class_path):
                continue

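            # Sort so the split is deterministic across runs and platforms.
            # Every entry in the folder is treated as an image.
            images = sorted(os.listdir(class_path))
            total = len(images)

            # int() truncates, so files left over after rounding down the
            # train and val counts fall into the final (test) slice.
            train_end = int(total * split_ratio[0])
            val_end = train_end + int(total * split_ratio[1])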

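            if split == "train":
                split_images = images[:train_end]
            elif split == "val":
                split_images = images[train_end:val_end]
            elif split == "test":
                split_images = images[val_end:]
            else:
                # Reject typos such as "valid" rather than silently
                # returning the test split.
                raise ValueError(f"Unknown split: {split!r}")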

            label = class_to_idx[class_name]

            for image_name in split_images:
                image_path = os.path.join(
                    class_path,
                    image_name
                )

                self.samples.append(
                    (image_path, label)
                )

    def __len__(self):
        return len(self.samples)

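    def __getitem__(self, index):
        image_path, label = self.samples[index]

        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        # The path is returned as well so callers can trace each sample.
        return image, label, image_path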
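

# A minimal sketch, not part of the original class: the per-class split
# arithmetic used in ClassificationDataset.__init__, pulled out into a
# standalone helper so it can be checked in isolation. The name
# `split_bounds` is hypothetical and introduced here for illustration only.
def split_bounds(total, split_ratio=(0.7, 0.15, 0.15)):
    """Return (train_end, val_end) slice boundaries for `total` items.

    Train is items[:train_end], val is items[train_end:val_end], and test
    is items[val_end:]. Both counts are truncated with int(), so any
    leftover items end up in the test slice.
    """
    train_end = int(total * split_ratio[0])
    val_end = train_end + int(total * split_ratio[1])
    return train_end, val_end


# Usage (ratios that are exact in binary keep the arithmetic easy to check):
#   train_end, val_end = split_bounds(5, (0.5, 0.25, 0.25))
#   train_end == 2 and val_end == 3, so the test slice gets the other 2 items.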