ml_trainer_env / server /datasets.py
BART-ender's picture
Upload folder using huggingface_hub
8f24287 verified
"""
Dataset loading and subsetting for the ML Training Optimizer environment.
Provides deterministic subsets of standard datasets:
- MNIST: 5k samples (4k train / 1k val)
- FashionMNIST: 8k samples (6.5k train / 1.5k val)
- CIFAR-10: 10k samples (8k train / 2k val)
Small subsets intentionally make overfitting a real challenge.
"""
import os
from typing import Tuple
import torch
from torch.utils.data import DataLoader, Subset, random_split
import torchvision
import torchvision.transforms as transforms
# Cache datasets inside the container
DATA_DIR = os.environ.get("DATA_DIR", "/tmp/ml_trainer_data")
def _get_mnist_transforms(augment: bool = False, aug_strength: float = 0.5):
"""Get transforms for MNIST (28×28 grayscale)."""
base = [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
if augment:
# Scale rotation and translation by augmentation strength
max_rot = int(15 * aug_strength)
max_translate = 0.1 * aug_strength
aug = [
transforms.RandomRotation(max_rot),
transforms.RandomAffine(0, translate=(max_translate, max_translate)),
]
return transforms.Compose(aug + base)
return transforms.Compose(base)
def _get_fashion_transforms(augment: bool = False, aug_strength: float = 0.5):
"""Get transforms for FashionMNIST (28×28 grayscale)."""
base = [transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))]
if augment:
max_rot = int(10 * aug_strength)
aug = [
transforms.RandomHorizontalFlip(p=0.5 * aug_strength),
transforms.RandomRotation(max_rot),
]
return transforms.Compose(aug + base)
return transforms.Compose(base)
def _get_cifar_transforms(augment: bool = False, aug_strength: float = 0.5):
"""Get transforms for CIFAR-10 (32×32 RGB)."""
base = [
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
]
if augment:
crop_pad = max(1, int(4 * aug_strength))
aug = [
transforms.RandomCrop(32, padding=crop_pad),
transforms.RandomHorizontalFlip(p=0.5 * aug_strength),
]
if aug_strength > 0.5:
aug.append(
transforms.ColorJitter(
brightness=0.2 * aug_strength,
contrast=0.2 * aug_strength,
)
)
return transforms.Compose(aug + base)
return transforms.Compose(base)
def load_dataset(
dataset_name: str,
seed: int = 42,
augment: bool = False,
aug_strength: float = 0.5,
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset, int, int]:
"""
Load a dataset and return deterministic train/val subsets.
Args:
dataset_name: One of 'mnist', 'fashion_mnist', 'cifar10'
seed: Random seed for reproducible subsetting
augment: Whether to apply data augmentation to training set
aug_strength: Augmentation intensity (0.0 to 1.0)
Returns:
(train_dataset, val_dataset, num_train, num_val)
"""
generator = torch.Generator().manual_seed(seed)
if dataset_name == "mnist":
transform_train = _get_mnist_transforms(augment, aug_strength)
transform_val = _get_mnist_transforms(augment=False)
full_train = torchvision.datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform_train)
full_val = torchvision.datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform_val)
total_subset = 5000
train_size, val_size = 4000, 1000
elif dataset_name == "fashion_mnist":
transform_train = _get_fashion_transforms(augment, aug_strength)
transform_val = _get_fashion_transforms(augment=False)
full_train = torchvision.datasets.FashionMNIST(DATA_DIR, train=True, download=True, transform=transform_train)
full_val = torchvision.datasets.FashionMNIST(DATA_DIR, train=True, download=True, transform=transform_val)
total_subset = 8000
train_size, val_size = 6500, 1500
elif dataset_name == "cifar10":
transform_train = _get_cifar_transforms(augment, aug_strength)
transform_val = _get_cifar_transforms(augment=False)
full_train = torchvision.datasets.CIFAR10(DATA_DIR, train=True, download=True, transform=transform_train)
full_val = torchvision.datasets.CIFAR10(DATA_DIR, train=True, download=True, transform=transform_val)
total_subset = 10000
train_size, val_size = 8000, 2000
else:
raise ValueError(f"Unknown dataset: {dataset_name}. Choose from: mnist, fashion_mnist, cifar10")
# Create deterministic subset indices
all_indices = torch.randperm(len(full_train), generator=generator)[:total_subset].tolist()
train_indices = all_indices[:train_size]
val_indices = all_indices[train_size:train_size + val_size]
train_dataset = Subset(full_train, train_indices)
val_dataset = Subset(full_val, val_indices)
return train_dataset, val_dataset, train_size, val_size
def create_dataloaders(
dataset_name: str,
batch_size: int = 64,
seed: int = 42,
augment: bool = False,
aug_strength: float = 0.5,
) -> Tuple[DataLoader, DataLoader, int, int]:
"""
Create DataLoaders for a dataset.
Returns:
(train_loader, val_loader, num_train, num_val)
"""
train_ds, val_ds, n_train, n_val = load_dataset(
dataset_name, seed=seed, augment=augment, aug_strength=aug_strength
)
train_loader = DataLoader(
train_ds,
batch_size=batch_size,
shuffle=True,
num_workers=0, # Keep it simple for 2 vCPU
pin_memory=False,
generator=torch.Generator().manual_seed(seed),
)
val_loader = DataLoader(
val_ds,
batch_size=batch_size * 2, # Larger batch for eval (no grads)
shuffle=False,
num_workers=0,
)
return train_loader, val_loader, n_train, n_val
def download_all_datasets():
"""Pre-download all datasets. Called during Docker build."""
print("Downloading MNIST...")
torchvision.datasets.MNIST(DATA_DIR, train=True, download=True)
print("Downloading FashionMNIST...")
torchvision.datasets.FashionMNIST(DATA_DIR, train=True, download=True)
print("Downloading CIFAR-10...")
torchvision.datasets.CIFAR10(DATA_DIR, train=True, download=True)
print("All datasets downloaded.")
if __name__ == "__main__":
download_all_datasets()