Spaces:
Sleeping
Sleeping
File size: 1,719 Bytes
c1596ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | import os
from PIL import Image
from torch.utils.data import Dataset
class ClassificationDataset(Dataset):
    """Image-classification dataset built from a ``root_dir/<class>/<file>`` tree.

    Each immediate subdirectory of ``root_dir`` is treated as one class; other
    entries directly under ``root_dir`` are ignored. Within a class folder the
    regular files are sorted by name and partitioned into consecutive
    train / val / test slices according to ``split_ratio``. Because the
    partition is purely positional, the three splits are disjoint and
    reproducible without any random seed.

    Args:
        root_dir: Directory whose subdirectories are class folders.
        class_to_idx: Mapping from class-folder name to integer label. Every
            class folder found under ``root_dir`` must have an entry.
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        transform: Optional callable applied to each loaded RGB PIL image.
        split_ratio: ``(train, val, test)`` fractions. Only the first two are
            used to compute slice boundaries (truncated with ``int``); the
            test split receives every remaining file, including rounding
            leftovers.
        extensions: Optional suffix or iterable of suffixes (e.g.
            ``(".jpg", ".png")``), matched case-insensitively against file
            names. ``None`` (the default) keeps every regular file.

    Raises:
        ValueError: If ``split`` is not one of the recognised split names.
        KeyError: If a class folder has no entry in ``class_to_idx``.

    Each item is a ``(image, label, image_path)`` tuple.
    """

    _SPLITS = ("train", "val", "test")

    def __init__(
        self,
        root_dir,
        class_to_idx,
        split="train",
        transform=None,
        split_ratio=(0.7, 0.15, 0.15),
        extensions=None,
    ):
        # Fail loudly on typos such as "validation": previously any unknown
        # value silently fell through to the test slice.
        if split not in self._SPLITS:
            raise ValueError(
                f"split must be one of {self._SPLITS}, got {split!r}"
            )

        self.transform = transform
        self.split = split
        self.class_to_idx = class_to_idx
        self.samples = []

        # A bare string would otherwise be iterated character by character.
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = (
            None if extensions is None
            else tuple(ext.lower() for ext in extensions)
        )

        for class_name in sorted(os.listdir(root_dir)):
            class_path = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            images = self._list_images(class_path, suffixes)
            split_images = self._select_split(images, split, split_ratio)
            label = class_to_idx[class_name]
            self.samples.extend(
                (os.path.join(class_path, image_name), label)
                for image_name in split_images
            )

    @staticmethod
    def _list_images(class_path, suffixes):
        """Return sorted names of regular files in ``class_path``, optionally suffix-filtered."""
        names = []
        for name in sorted(os.listdir(class_path)):
            # Nested directories (and other non-files) cannot be opened as
            # images and would also shift the split boundaries; skip them.
            if not os.path.isfile(os.path.join(class_path, name)):
                continue
            if suffixes is not None and not name.lower().endswith(suffixes):
                continue
            names.append(name)
        return names

    @staticmethod
    def _select_split(images, split, split_ratio):
        """Return the consecutive slice of ``images`` that belongs to ``split``."""
        total = len(images)
        train_end = int(total * split_ratio[0])
        val_end = train_end + int(total * split_ratio[1])
        if split == "train":
            return images[:train_end]
        if split == "val":
            return images[train_end:val_end]
        return images[val_end:]

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label, image_path)``."""
        image_path, label = self.samples[index]
        # The context manager releases the file handle deterministically;
        # ``convert`` returns an independent in-memory image that stays valid
        # after the source file is closed.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label, image_path