npv2k1's picture
feat: ShapeClassifier
c9e0c1d verified
raw
history blame contribute delete
998 Bytes
from .transform import data_transform
from torch.utils.data import Dataset
import os
from PIL import Image
class CustomDataset(Dataset):
    """Image dataset whose class label is encoded in each file's name.

    Every regular file in ``data_folder`` is one sample. The label is the
    file name with its last ``SUFFIX_LEN`` characters removed, looked up in
    ``CLASS_TO_IDX``. NOTE(review): the 8-character strip suggests names
    like ``"circle_001.png"`` -> ``"circle"``; confirm against the data.

    Args:
        data_folder: Directory containing the image files.
        transform: Optional callable applied to each RGB ``PIL.Image``.
    """

    # Integer class index for each shape name recognized in file names.
    CLASS_TO_IDX = {"circle": 0, "square": 1, "triangle": 2}
    # Number of trailing characters stripped from a file name to get its label.
    SUFFIX_LEN = 8

    def __init__(self, data_folder, transform=None):
        self.data_folder = data_folder
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        # Skip sub-directories, which cannot be opened as images.
        self.image_files = sorted(
            name
            for name in os.listdir(data_folder)
            if os.path.isfile(os.path.join(data_folder, name))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``data_folder``."""
        return len(self.image_files)

    @classmethod
    def _label_from_filename(cls, image_name):
        """Return the integer class index encoded in ``image_name``.

        Raises:
            ValueError: if the stripped name is not a key of ``CLASS_TO_IDX``.
        """
        label_name = image_name[: len(image_name) - cls.SUFFIX_LEN]
        try:
            return cls.CLASS_TO_IDX[label_name]
        except KeyError:
            # Previously the raw string leaked out as the label, breaking
            # batching/loss computation far from the actual cause.
            raise ValueError(
                f"Cannot derive a known label from file name {image_name!r} "
                f"(got {label_name!r}; expected one of {sorted(cls.CLASS_TO_IDX)})"
            ) from None

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        ``image`` is the RGB image after ``transform`` (if one was given);
        ``label`` is the integer class index derived from the file name.

        Raises:
            ValueError: if the file name does not encode a known class.
        """
        image_name = self.image_files[idx]
        # Resolve the label first so a badly named file fails before any I/O.
        label = self._label_from_filename(image_name)
        image_path = os.path.join(self.data_folder, image_name)
        # Context manager releases the file handle; convert() returns a new,
        # loaded image that does not depend on the open file.
        with Image.open(image_path) as img:
            image = img.convert("RGB")  # ensure images are RGB
        if self.transform is not None:
            image = self.transform(image)
        return image, label