""" Dataset Utilities - Bridge to new datasets module """ import sys import os sys.path.insert(0, os.path.abspath('.')) def load_mnist(data_dir="data/raw/mnist", cache=True, augment=False): """Load MNIST dataset (redirects to datasets module)""" from datasets.mnist import load_mnist as load_mnist_impl return load_mnist_impl(root=data_dir, cache=cache, augment=augment) def get_dataset_stats(dataset): """Get dataset statistics (redirects to appropriate dataset module)""" from datasets.mnist import get_mnist_stats return get_mnist_stats(dataset) def create_dataloaders(train_set, test_set, batch_size=64, val_split=0.1): """ Create train/validation/test dataloaders Args: train_set: Training dataset test_set: Test dataset batch_size: Batch size val_split: Fraction of training data for validation Returns: train_loader, val_loader, test_loader """ val_size = int(len(train_set) * val_split) train_size = len(train_set) - val_size train_subset, val_subset = torch.utils.data.random_split( train_set, [train_size, val_size] ) train_loader = torch.utils.data.DataLoader( train_subset, batch_size=batch_size, shuffle=True, num_workers=0 ) val_loader = torch.utils.data.DataLoader( val_subset, batch_size=batch_size, shuffle=False, num_workers=0 ) test_loader = torch.utils.data.DataLoader( test_set, batch_size=batch_size, shuffle=False, num_workers=0 ) return train_loader, val_loader, test_loader