|
|
import os |
|
|
import torch |
|
|
from torchvision.datasets import Caltech101 |
|
|
from torchvision import transforms |
|
|
from torch.utils.data import Dataset |
|
|
from sklearn.model_selection import train_test_split |
|
|
from PIL import Image |
|
|
from typing import Optional, Callable, Tuple |
|
|
|
|
|
class CustomCaltech101(Dataset):
    """Caltech101 with a reproducible, stratified train/test split.

    The wrapped ``Caltech101`` dataset is iterated once at construction time,
    every image is converted to RGB mode and kept in memory, and the samples
    are then partitioned with ``train_test_split`` using the labels for
    stratification. Only the requested partition is retained on the instance.

    Attributes are plain lists so the instance can be indexed cheaply:
    ``data`` holds the RGB images and ``targets`` the matching labels.
    """

    def __init__(self, root: str, train: bool = True, transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None, download: bool = False,
                 test_size: float = 0.2, random_state: int = 42):
        """Load the dataset and keep either the train or the test partition.

        Args:
            root: Directory handed to ``Caltech101`` as its root.
            train: Keep the training partition when true, otherwise the
                test partition.
            transform: Optional callable applied to each image in
                ``__getitem__``.
            target_transform: Optional callable applied to each label in
                ``__getitem__``.
            download: Forwarded to ``Caltech101``.
            test_size: Forwarded to ``train_test_split``.
            random_state: Seed forwarded to ``train_test_split``; two
                instances built with the same seed and ``test_size`` see the
                same partitioning.
        """
        super().__init__()

        # The underlying dataset is created without transforms; the ones given
        # here are applied lazily in __getitem__ instead.
        self.dataset = Caltech101(root, download=download)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        # NOTE: every image is decoded and held in memory up front, so
        # construction cost and peak memory scale with the full dataset even
        # though only one partition is kept afterwards.
        self.data = []
        self.targets = []
        for image, target in self.dataset:
            self.data.append(image.convert('RGB'))
            self.targets.append(target)

        # Stratify on the labels so both partitions keep the class balance.
        train_data, test_data, train_targets, test_targets = train_test_split(
            self.data, self.targets, test_size=test_size,
            stratify=self.targets, random_state=random_state,
        )

        if self.train:
            self.data, self.targets = train_data, train_targets
        else:
            self.data, self.targets = test_data, test_targets

    def __len__(self) -> int:
        """Return the number of samples in the retained partition."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return the ``(image, target)`` pair stored at ``index``.

        The image is passed through ``transform`` and the label through
        ``target_transform`` when they are set. The image is only a tensor
        if ``transform`` produces one; without a transform the stored RGB
        image object is returned unchanged.
        """
        img = self.data[index]
        target = self.targets[index]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
|
|
|
|
|
|
|
|
def main() -> None:
    """Build the train and test partitions and print their sizes."""
    root_dir = "/data/datasets"

    # Resize every image to a fixed 224x224 size and convert it to a tensor.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Both instances rely on the default test_size and random_state, so they
    # are built from the same split and differ only in which side they keep.
    train_dataset = CustomCaltech101(root=root_dir, train=True, transform=transform, download=True)
    test_dataset = CustomCaltech101(root=root_dir, train=False, transform=transform, download=True)

    print(f"Number of training samples: {len(train_dataset)}")
    print(f"Number of testing samples: {len(test_dataset)}")


if __name__ == "__main__":
    main()