File size: 3,279 Bytes
fab18b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
Data loading utilities.

Supports:
  - Stanford Cars dataset (download via Kaggle or torchvision)
  - Any folder of images (ImageFolder-style)
"""

from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def get_transforms(image_size: int = 64) -> transforms.Compose:
    """Standard augmentation pipeline for GAN training.

    Resizes slightly above the target size, random-crops back down, then
    applies flip/color-jitter augmentation and maps pixels into [-1, 1].

    Args:
        image_size: Side length (pixels) of the square output images.

    Returns:
        A ``transforms.Compose`` producing normalized 3-channel tensors.
    """
    # Over-resize by ~12% so RandomCrop has room to shift the frame.
    oversized = int(image_size * 1.12)
    half = [0.5, 0.5, 0.5]
    steps = [
        transforms.Resize(oversized),
        transforms.RandomCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 rescales [0, 1] tensors to the GAN's [-1, 1] range.
        transforms.Normalize(half, half),
    ]
    return transforms.Compose(steps)


def get_inference_transforms(image_size: int = 64) -> transforms.Compose:
    """Deterministic (augmentation-free) pipeline for evaluation images.

    Args:
        image_size: Side length (pixels) of the square output images.

    Returns:
        A ``transforms.Compose`` producing normalized 3-channel tensors
        in [-1, 1], with a plain resize + center crop instead of random
        augmentation.
    """
    half = [0.5, 0.5, 0.5]
    return transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(half, half),  # [0, 1] → [-1, 1]
    ])


def build_dataloader(
    data_path: str | Path,
    image_size: int = 64,
    batch_size: int = 64,
    num_workers: int = 4,
    pin_memory: bool = True,
) -> DataLoader:
    """
    Build a shuffled training DataLoader from an ImageFolder-style directory.

    Expected structure:
        data_path/
          class_a/  ← can be a single dummy class "cars/"
            img001.jpg
            img002.jpg
            ...

    Stanford Cars tip: point data_path at the 'cars_train' directory;
    it already has per-class sub-folders which are ignored by GAN training.

    Args:
        data_path: Root directory containing one or more class sub-folders.
        image_size: Target square image size fed to the augmentation pipeline.
        batch_size: Samples per batch.
        num_workers: Worker processes for data loading.
        pin_memory: Whether to pin host memory for faster GPU transfer.

    Returns:
        A DataLoader yielding augmented, normalized image batches.
    """
    image_set = datasets.ImageFolder(
        root=str(data_path),
        transform=get_transforms(image_size),
    )
    batches = DataLoader(
        image_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # Dropping the ragged final batch avoids batch-norm issues with
        # very small batches.
        drop_last=True,
    )
    print(f"[Data] {len(image_set):,} images | {len(batches):,} batches | batch_size={batch_size}")
    return batches


# ──────────────────────────────────────────
# Helper: download Stanford Cars via Kaggle
# ──────────────────────────────────────────
def download_stanford_cars(dest: str = "./data") -> None:
    """
    Downloads the Stanford Cars dataset using the Kaggle API.

    Prerequisites:
        pip install kaggle
        Place kaggle.json in ~/.kaggle/
    Run:
        python -c "from src.data.dataset import download_stanford_cars; download_stanford_cars()"

    Args:
        dest: Directory to download and unzip the dataset into
              (created if missing).

    Raises:
        subprocess.CalledProcessError: If the Kaggle CLI exits non-zero
            (e.g. missing credentials or no network).
    """
    # Local imports keep the Kaggle dependency optional for normal training.
    # Fix: the previous version also imported `os`, which was never used,
    # and put multiple imports on one line (PEP 8).
    import subprocess
    import sys

    dest_path = Path(dest)
    dest_path.mkdir(parents=True, exist_ok=True)

    # `sys.executable -m kaggle` ensures we use the kaggle package from the
    # same interpreter/venv that is running this code.
    cmd = [
        sys.executable, "-m", "kaggle", "datasets", "download",
        "-d", "jutrera/stanford-car-dataset-by-classes-folder",
        "-p", str(dest_path), "--unzip",
    ]
    print("[Data] Downloading Stanford Cars dataset via Kaggle...")
    subprocess.run(cmd, check=True)  # raises on failure instead of failing silently
    print(f"[Data] Done. Dataset saved to {dest_path}")