Ttius's picture
Upload 192 files
998bb30 verified
import os
import torch
from torchvision.datasets import Caltech101
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from PIL import Image
from typing import Optional, Callable, Tuple
class CustomCaltech101(Dataset):
def __init__(self, root: str, train: bool = True, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, download: bool = False,
test_size: float = 0.2, random_state: int = 42):
# Preload images and labels
self.dataset = Caltech101(root, download=download)
self.data = []
self.targets = []
self.transform = transform
self.target_transform = target_transform
for idx, (image, target) in enumerate(self.dataset):
self.data.append(image.convert('RGB'))
self.targets.append(target)
# Split the data into train and test sets
self.train = train
train_data, test_data, train_targets, test_targets = train_test_split(
self.data, self.targets, test_size=test_size, stratify=self.targets, random_state=random_state
)
if self.train:
self.data, self.targets = train_data, train_targets
else:
self.data, self.targets = test_data, test_targets
def __len__(self) -> int:
return len(self.data)
def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
img = self.data[index]
target = self.targets[index]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
# Example usage
if __name__ == "__main__":
# Define the root directory where the dataset will be stored
root_dir = "/data/datasets"
# Define the transform to apply to the images
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# Create an instance of the custom dataset for training
train_dataset = CustomCaltech101(root=root_dir, train=True, transform=transform, download=True)
# Create an instance of the custom dataset for testing
test_dataset = CustomCaltech101(root=root_dir, train=False, transform=transform, download=True)
# Print some information
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")