File size: 1,600 Bytes
f4bee9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
Dataset Utilities - Bridge to new datasets module
"""
import sys
import os
sys.path.insert(0, os.path.abspath('.'))

def load_mnist(data_dir="data/raw/mnist", cache=True, augment=False):
    """Load the MNIST dataset.

    Thin forwarding shim kept for backward compatibility; the real
    implementation lives in the ``datasets.mnist`` module.

    Args:
        data_dir: Directory holding the raw MNIST files.
        cache: Whether to cache the loaded data.
        augment: Whether to apply data augmentation.

    Returns:
        Whatever ``datasets.mnist.load_mnist`` returns.
    """
    # Imported lazily so this shim stays importable even when the
    # datasets package is not needed.
    import datasets.mnist as _mnist
    return _mnist.load_mnist(root=data_dir, cache=cache, augment=augment)

def get_dataset_stats(dataset):
    """Return statistics for *dataset*.

    Forwarding shim: delegates to ``datasets.mnist.get_mnist_stats``
    (currently MNIST-only, despite the generic name).
    """
    # Lazy import keeps module import cheap and avoids a hard
    # dependency until the function is actually called.
    import datasets.mnist as _mnist
    return _mnist.get_mnist_stats(dataset)

def create_dataloaders(train_set, test_set, batch_size=64, val_split=0.1):
    """
    Create train/validation/test dataloaders.

    Args:
        train_set: Training dataset (any object valid for
            ``torch.utils.data.random_split`` / ``DataLoader``)
        test_set: Test dataset
        batch_size: Batch size for all three loaders
        val_split: Fraction of training data held out for validation

    Returns:
        train_loader, val_loader, test_loader
    """
    # BUG FIX: `torch` was referenced below but never imported anywhere
    # in this module (only `sys` and `os` are imported at the top), so
    # every call raised NameError. Imported here lazily, matching this
    # module's style of deferring heavy imports to call time.
    import torch.utils.data

    # Carve the validation set out of the training set; train_size is
    # computed as the remainder so the two always sum to len(train_set).
    val_size = int(len(train_set) * val_split)
    train_size = len(train_set) - val_size

    train_subset, val_subset = torch.utils.data.random_split(
        train_set, [train_size, val_size]
    )

    # Only the training loader shuffles; num_workers=0 keeps loading
    # in-process (no worker subprocesses).
    train_loader = torch.utils.data.DataLoader(
        train_subset, batch_size=batch_size, shuffle=True, num_workers=0
    )

    val_loader = torch.utils.data.DataLoader(
        val_subset, batch_size=batch_size, shuffle=False, num_workers=0
    )

    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=0
    )

    return train_loader, val_loader, test_loader