Spaces:
Build error
Build error
| from .transform import data_transform | |
| from torch.utils.data import Dataset | |
| import os | |
| from PIL import Image | |
class CustomDataset(Dataset):
    """Image dataset that reads shape images from a flat folder.

    Each regular file in ``data_folder`` is one sample. The class label is
    derived from the filename: the last ``_SUFFIX_LEN`` characters are
    dropped and the remaining prefix is looked up in ``CLASS_TO_IDX``.

    Args:
        data_folder: Directory containing the image files.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.
    """

    # Filename prefix -> integer class index.
    CLASS_TO_IDX = {"circle": 0, "square": 1, "triangle": 2}

    # Number of trailing characters stripped from a filename to get its label.
    # NOTE(review): assumes names like "circle_001.png" (8-char suffix
    # "_001.png") — TODO confirm the dataset's naming scheme.
    _SUFFIX_LEN = 8

    def __init__(self, data_folder, transform=None):
        self.data_folder = data_folder
        # Sort so sample indices are reproducible (os.listdir order is
        # arbitrary), and skip sub-directories, which cannot be opened as
        # images.
        self.image_files = sorted(
            name
            for name in os.listdir(data_folder)
            if os.path.isfile(os.path.join(data_folder, name))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``data_folder``."""
        return len(self.image_files)

    @classmethod
    def _label_index(cls, image_name):
        """Return the integer class index encoded in *image_name*.

        Raises:
            ValueError: if the filename prefix is not a key of
                ``CLASS_TO_IDX``, so a mislabeled file fails loudly instead
                of yielding a non-integer label.
        """
        label = image_name[: len(image_name) - cls._SUFFIX_LEN]
        try:
            return cls.CLASS_TO_IDX[label]
        except KeyError:
            raise ValueError(
                f"Unrecognised label {label!r} derived from file {image_name!r}; "
                f"expected one of {sorted(cls.CLASS_TO_IDX)}"
            ) from None

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        ``image`` is the RGB image after ``transform`` (if one was given);
        ``label`` is the integer class index from ``CLASS_TO_IDX``.
        """
        image_name = self.image_files[idx]
        label = self._label_index(image_name)
        image_path = os.path.join(self.data_folder, image_name)
        # The context manager closes the file handle promptly; convert()
        # forces the pixel data to load before the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")  # Ensure images are RGB
        if self.transform:
            image = self.transform(image)
        return image, label