File size: 6,537 Bytes
8f24287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""
Dataset loading and subsetting for the ML Training Optimizer environment.

Provides deterministic subsets of standard datasets:
- MNIST: 5k samples (4k train / 1k val)
- FashionMNIST: 8k samples (6.5k train / 1.5k val)
- CIFAR-10: 10k samples (8k train / 2k val)

Small subsets intentionally make overfitting a real challenge.
"""

import os
from typing import Tuple

import torch
from torch.utils.data import DataLoader, Subset, random_split
import torchvision
import torchvision.transforms as transforms


# Root directory handed to the torchvision dataset constructors, so downloads
# are cached inside the container. The DATA_DIR environment variable overrides
# the default of /tmp/ml_trainer_data.
DATA_DIR = os.environ.get("DATA_DIR", "/tmp/ml_trainer_data")


def _get_mnist_transforms(augment: bool = False, aug_strength: float = 0.5):
    """Build the preprocessing pipeline for MNIST images.

    The pipeline always ends with ``ToTensor`` followed by ``Normalize`` with
    mean 0.1307 and std 0.3081. When ``augment`` is truthy, a random rotation
    and a random translation are placed in front of those two steps.

    Args:
        augment: Prepend the random augmentations when truthy.
        aug_strength: Multiplier for the augmentation ranges. The rotation
            limit is ``int(15 * aug_strength)`` degrees and the translation
            fraction is ``0.1 * aug_strength`` on both axes.

    Returns:
        A ``transforms.Compose`` holding the steps in application order.
    """
    steps = []
    if augment:
        # Both ranges grow linearly with the requested strength.
        rotation_limit = int(15 * aug_strength)
        shift_fraction = 0.1 * aug_strength
        steps.append(transforms.RandomRotation(rotation_limit))
        steps.append(
            transforms.RandomAffine(0, translate=(shift_fraction, shift_fraction))
        )
    # Tensor conversion and normalization always come last.
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.1307,), (0.3081,)))
    return transforms.Compose(steps)


def _get_fashion_transforms(augment: bool = False, aug_strength: float = 0.5):
    """Build the preprocessing pipeline for FashionMNIST images.

    The pipeline always ends with ``ToTensor`` followed by ``Normalize`` with
    mean 0.2860 and std 0.3530. When ``augment`` is truthy, a random
    horizontal flip and a random rotation are placed in front of them.

    Args:
        augment: Prepend the random augmentations when truthy.
        aug_strength: Multiplier for the augmentation ranges. The flip
            probability is ``0.5 * aug_strength`` and the rotation limit is
            ``int(10 * aug_strength)`` degrees.

    Returns:
        A ``transforms.Compose`` holding the steps in application order.
    """
    steps = []
    if augment:
        rotation_limit = int(10 * aug_strength)
        flip_probability = 0.5 * aug_strength
        steps.append(transforms.RandomHorizontalFlip(p=flip_probability))
        steps.append(transforms.RandomRotation(rotation_limit))
    # Tensor conversion and normalization always come last.
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.2860,), (0.3530,)))
    return transforms.Compose(steps)


def _get_cifar_transforms(augment: bool = False, aug_strength: float = 0.5):
    """Build the preprocessing pipeline for CIFAR-10 images.

    The pipeline always ends with ``ToTensor`` followed by a three-channel
    ``Normalize``. When ``augment`` is truthy, a padded random 32-pixel crop
    and a random horizontal flip are placed in front of them; a brightness and
    contrast jitter is added as well once ``aug_strength`` exceeds 0.5.

    Args:
        augment: Prepend the random augmentations when truthy.
        aug_strength: Multiplier for the augmentation ranges. The crop padding
            is ``int(4 * aug_strength)`` but never below 1, the flip
            probability is ``0.5 * aug_strength`` and the jitter amounts are
            ``0.2 * aug_strength``.

    Returns:
        A ``transforms.Compose`` holding the steps in application order.
    """
    steps = []
    if augment:
        # Keep at least one pixel of padding so the random crop can still shift.
        padding = max(1, int(4 * aug_strength))
        steps.append(transforms.RandomCrop(32, padding=padding))
        steps.append(transforms.RandomHorizontalFlip(p=0.5 * aug_strength))
        if aug_strength > 0.5:
            # Color jitter is only used in the stronger half of the range.
            jitter_amount = 0.2 * aug_strength
            steps.append(
                transforms.ColorJitter(
                    brightness=jitter_amount,
                    contrast=jitter_amount,
                )
            )
    # Tensor conversion and normalization always come last.
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    )
    return transforms.Compose(steps)


def load_dataset(
    dataset_name: str,
    seed: int = 42,
    augment: bool = False,
    aug_strength: float = 0.5,
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset, int, int]:
    """
    Build reproducible train/validation subsets of one supported dataset.

    Both subsets are taken from the dataset's ``train=True`` split using
    non-overlapping slices of one seeded permutation. The validation subset
    always uses the non-augmented transform pipeline.

    Args:
        dataset_name: One of 'mnist', 'fashion_mnist', 'cifar10'
        seed: Seed for the permutation that selects the samples
        augment: Whether the training subset gets data augmentation
        aug_strength: Augmentation intensity (0.0 to 1.0)

    Returns:
        (train_dataset, val_dataset, num_train, num_val)

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    rng = torch.Generator().manual_seed(seed)

    # (name, dataset class, transform factory, train count, val count)
    catalog = (
        ("mnist", torchvision.datasets.MNIST, _get_mnist_transforms, 4000, 1000),
        ("fashion_mnist", torchvision.datasets.FashionMNIST, _get_fashion_transforms, 6500, 1500),
        ("cifar10", torchvision.datasets.CIFAR10, _get_cifar_transforms, 8000, 2000),
    )
    entry = next((row for row in catalog if dataset_name == row[0]), None)
    if entry is None:
        raise ValueError(f"Unknown dataset: {dataset_name}. Choose from: mnist, fashion_mnist, cifar10")
    _, dataset_cls, make_transforms, train_size, val_size = entry

    transform_train = make_transforms(augment, aug_strength)
    transform_val = make_transforms(augment=False)

    # Two views of the same split that differ only in the transform applied.
    full_train = dataset_cls(DATA_DIR, train=True, download=True, transform=transform_train)
    full_val = dataset_cls(DATA_DIR, train=True, download=True, transform=transform_val)

    # One seeded shuffle; the first train_size picks go to training and the
    # following val_size picks go to validation.
    subset_total = train_size + val_size
    shuffled = torch.randperm(len(full_train), generator=rng)
    chosen = shuffled[:subset_total].tolist()

    train_dataset = Subset(full_train, chosen[:train_size])
    val_dataset = Subset(full_val, chosen[train_size:train_size + val_size])

    return train_dataset, val_dataset, train_size, val_size


def create_dataloaders(
    dataset_name: str,
    batch_size: int = 64,
    seed: int = 42,
    augment: bool = False,
    aug_strength: float = 0.5,
) -> Tuple[DataLoader, DataLoader, int, int]:
    """
    Wrap the subsets returned by ``load_dataset`` in DataLoaders.

    The training loader shuffles with a generator seeded from ``seed``. The
    validation loader keeps its order and uses twice the training batch size.

    Args:
        dataset_name: Dataset identifier forwarded to ``load_dataset``
        batch_size: Training batch size
        seed: Seed for both the subset selection and the training shuffle
        augment: Whether the training subset gets data augmentation
        aug_strength: Augmentation intensity forwarded to ``load_dataset``

    Returns:
        (train_loader, val_loader, num_train, num_val)
    """
    train_subset, val_subset, num_train, num_val = load_dataset(
        dataset_name,
        seed=seed,
        augment=augment,
        aug_strength=aug_strength,
    )

    shuffle_rng = torch.Generator().manual_seed(seed)
    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,  # Keep it simple for 2 vCPU
        pin_memory=False,
        generator=shuffle_rng,
    )

    # Larger batch for eval (no grads)
    eval_batch_size = batch_size * 2
    val_loader = DataLoader(
        val_subset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=0,
    )

    return train_loader, val_loader, num_train, num_val


def download_all_datasets():
    """Fetch the training split of every supported dataset into DATA_DIR.

    Intended to be called during the Docker build. A progress line is printed
    before each download and a final line once all three have finished.
    """
    download_plan = (
        ("Downloading MNIST...", torchvision.datasets.MNIST),
        ("Downloading FashionMNIST...", torchvision.datasets.FashionMNIST),
        ("Downloading CIFAR-10...", torchvision.datasets.CIFAR10),
    )
    for progress_line, dataset_cls in download_plan:
        print(progress_line)
        # Constructing with download=True triggers the download; the dataset
        # object itself is discarded.
        dataset_cls(DATA_DIR, train=True, download=True)
    print("All datasets downloaded.")


# Script entry point: running this module directly pre-downloads every dataset.
if __name__ == "__main__":
    download_all_datasets()