import torch from torch.utils.data import DataLoader, Dataset, random_split from torchvision import transforms import os import pandas as pd from PIL import Image class KaggleCIFAR10Dataset(Dataset): """Custom dataset for CIFAR-10 in train/test + CSV format.""" def __init__(self, data_dir, is_train=True, transform=None): self.data_dir = data_dir self.is_train = is_train self.transform = transform # Class name mapping (same for train and test) self.class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] self.label_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Common root for labels and image folders labels_root = os.path.join(data_dir, 'raw') if is_train: labels_path = os.path.join(labels_root, 'trainLabels.csv') self.image_dir = os.path.join(labels_root, 'train') else: labels_path = os.path.join(labels_root, 'testLabels.csv') self.image_dir = os.path.join(labels_root, 'test') self.labels_df = pd.read_csv(labels_path) def __len__(self): return len(self.labels_df) def __getitem__(self, idx): row = self.labels_df.iloc[idx] image_id = row['id'] label_name = row['label'] image_path = os.path.join(self.image_dir, f"{image_id}.png") image = Image.open(image_path).convert('RGB') label = self.label_to_idx[label_name] if self.transform: image = self.transform(image) return image, label def get_cifar10_loaders(data_dir='./data', batch_size=64, num_workers=4, val_split=0.1): """ Load CIFAR-10 from custom directory layout with augmentation and splits. Args: data_dir: Directory containing 'raw/train', 'raw/test', and both CSV files batch_size: Batch size for loaders num_workers: Worker processes for DataLoader val_split: Fraction of training data used for validation Returns: (train_loader, val_loader, test_loader) """ CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) CIFAR10_STD = (0.2023, 0.1994, 0.2010) train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD) ]) eval_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD) ]) # Full training dataset full_train_dataset = KaggleCIFAR10Dataset( data_dir=data_dir, is_train=True, transform=train_transform ) # Split into train and validation train_size = int((1 - val_split) * len(full_train_dataset)) val_size = len(full_train_dataset) - train_size train_subset, val_subset = random_split( full_train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42) ) # Use eval transforms for validation set val_dataset = KaggleCIFAR10Dataset( data_dir=data_dir, is_train=True, transform=eval_transform ) val_subset.dataset = val_dataset # Load labeled test dataset test_dataset = KaggleCIFAR10Dataset( data_dir=data_dir, is_train=False, transform=eval_transform ) # DataLoaders train_loader = DataLoader( train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=torch.cuda.is_available(), persistent_workers=True if num_workers > 0 else False ) val_loader = DataLoader( val_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=torch.cuda.is_available(), persistent_workers=True if num_workers > 0 else False ) test_loader = DataLoader( test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=torch.cuda.is_available(), persistent_workers=True if num_workers > 0 else False ) print(f"📊 Dataset loaded successfully!") print(f" Training samples: {len(train_subset)}") print(f" Validation samples: {len(val_subset)}") print(f" Test samples: {len(test_dataset)}") return train_loader, val_loader, test_loader def get_cifar10_info(): """Return CIFAR-10 dataset information.""" return { 'num_classes': 10, 'class_names': ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'], 'input_shape': (3, 32, 32), 'mean': (0.4914, 0.4822, 0.4465), 'std': (0.2023, 0.1994, 0.2010) }