File size: 998 Bytes
c9e0c1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from .transform import data_transform
from torch.utils.data import Dataset
import os
from PIL import Image


class CustomDataset(Dataset):
    """Image dataset whose class label is encoded in each file name.

    Every entry of ``data_folder`` is treated as one sample. The label is
    the file name with its last eight characters removed; the prefixes
    ``"circle"``, ``"square"`` and ``"triangle"`` are mapped to ``0``, ``1``
    and ``2`` respectively.

    NOTE(review): the fixed eight-character suffix presumably covers an
    index plus the extension (e.g. ``circle0001.png``) -- confirm against
    the naming scheme of the data set.

    Args:
        data_folder: Directory containing the image files.
        transform: Optional callable applied to the RGB ``PIL.Image``
            before it is returned.
    """

    # Number of trailing file-name characters that are not part of the label.
    _SUFFIX_LENGTH = 8

    # Label prefixes recognised in file names and the class ids they map to.
    _LABEL_TO_INDEX = {"circle": 0, "square": 1, "triangle": 2}

    def __init__(self, data_folder, transform=None):
        self.data_folder = data_folder
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping identical across runs and machines.
        self.image_files = sorted(os.listdir(data_folder))
        self.transform = transform

    def __len__(self):
        """Return the number of entries found in ``data_folder``."""
        return len(self.image_files)

    @classmethod
    def _label_from_filename(cls, image_name):
        """Return the class id encoded in ``image_name``.

        A prefix that is not a known label is returned unchanged as a
        string, matching the historical behaviour of this dataset.
        """
        prefix = image_name[:len(image_name) - cls._SUFFIX_LENGTH]
        return cls._LABEL_TO_INDEX.get(prefix, prefix)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``(image, label)`` pair.

        The image is converted to RGB and, when a transform was supplied,
        passed through it. ``label`` is an ``int`` for the known classes and
        the raw file-name prefix otherwise.
        """
        image_name = self.image_files[idx]
        image_path = os.path.join(self.data_folder, image_name)

        # The context manager closes the underlying file handle as soon as
        # the pixel data has been copied into the converted image.
        with Image.open(image_path) as source:
            image = source.convert("RGB")  # Ensure images are RGB

        if self.transform:
            image = self.transform(image)

        return image, self._label_from_filename(image_name)