File size: 3,157 Bytes
586661b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a15372
586661b
 
2a15372
 
586661b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import io
from typing import List, Tuple

import torch
from PIL import Image
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

from config import HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE, RANDOM_SEED


# Module-level cache filled in by load_charcoal_dataset() so the Hub download
# happens at most once per process. _HF_DATASET_CACHE being None means
# "not loaded yet"; _CLASS_NAMES holds the class names returned alongside it.
_CLASS_NAMES = None
_HF_DATASET_CACHE = None


class HFDatasetWrapper(Dataset):
    """Adapt a Hugging Face split into a PyTorch ``Dataset`` of ``(image, label)``.

    Each record must provide an ``"image"`` and a ``"label"`` field. The image
    may be:

    * an already-decoded ``PIL.Image.Image`` (what a ``datasets.Image``
      feature yields by default),
    * an undecoded ``{"bytes": ..., "path": ...}`` dict (what
      ``datasets.Image(decode=False)`` yields),
    * raw ``bytes``, or
    * anything ``PIL.Image.open`` accepts (a path or a file-like object).

    The image is converted to RGB and passed through ``transform``. The label
    is returned as a plain ``int``.
    """

    def __init__(self, hf_dataset, transform):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = self._to_rgb(item["image"])
        label = int(item["label"])
        return self.transform(image), label

    @staticmethod
    def _to_rgb(image) -> "Image.Image":
        """Return *image* as an RGB PIL image, decoding it first if needed."""
        if isinstance(image, Image.Image):
            return image.convert("RGB")

        if isinstance(image, dict):
            # Undecoded datasets.Image payload: prefer the in-memory bytes,
            # fall back to the on-disk path.
            payload = image.get("bytes") or image.get("path")
            if payload is None:
                raise ValueError("image dict has neither 'bytes' nor 'path'")
            image = payload

        if isinstance(image, (bytes, bytearray)):
            image = io.BytesIO(image)

        # convert() loads the pixels into a new image, so the opened handle
        # can be released immediately instead of waiting for garbage collection.
        with Image.open(image) as opened:
            return opened.convert("RGB")


def get_transform(image_size=None):
    """Build the preprocessing pipeline shared by every split.

    Args:
        image_size: Output size. An ``int`` gives a square ``(n, n)`` resize,
            and a ``(height, width)`` pair is used as-is. ``None`` (the
            default) uses ``config.IMAGE_SIZE``, so the pipeline follows the
            project configuration instead of a hard-coded value.

    Returns:
        A ``transforms.Compose`` that resizes to a fixed size (aspect ratio is
        not preserved), converts to a tensor with ``ToTensor`` and normalises
        each channel with the ImageNet mean/std.
    """
    if image_size is None:
        image_size = IMAGE_SIZE

    # A 2-tuple forces an exact output size; a bare int passed to Resize would
    # only rescale the shorter edge and leave batch shapes inconsistent.
    if isinstance(image_size, int):
        size = (image_size, image_size)
    else:
        size = tuple(image_size)

    return transforms.Compose(
        [
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )


def load_charcoal_dataset():
    """Load the charcoal dataset from the Hub once and cache it for the process.

    Returns:
        ``(splits, class_names)``.

        * ``splits`` is a mapping that contains at least ``"train"`` and
          ``"test"``. If the Hub repo has no ``"test"`` split, one is carved
          out of ``"train"`` (20%, seeded with ``RANDOM_SEED``).
        * ``class_names`` are the label feature's ``names`` when it exposes
          them (e.g. a ``ClassLabel``). Otherwise they are the sorted distinct
          raw label values of the original train split.

    Raises:
        RuntimeError: if ``HF_TOKEN`` is not configured.
    """
    global _CLASS_NAMES, _HF_DATASET_CACHE

    if _HF_DATASET_CACHE is not None:
        return _HF_DATASET_CACHE, _CLASS_NAMES

    if not HF_TOKEN:
        raise RuntimeError(
            "HF_TOKEN is missing. Please add it in the Space secrets."
        )

    raw = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
    train = raw["train"]

    label_feature = train.features["label"]
    if hasattr(label_feature, "names"):
        class_names = list(label_feature.names)
    else:
        class_names = sorted(set(train["label"]))

    if "test" not in raw:
        raw = _split_train_test(train)

    # Publish both cache entries together, only after everything succeeded,
    # so a failure part-way through never leaves the globals half-populated.
    _CLASS_NAMES = class_names
    _HF_DATASET_CACHE = raw
    return _HF_DATASET_CACHE, _CLASS_NAMES


def _split_train_test(train, test_size=0.2):
    """Split *train* into ``{"train", "test"}``, stratifying by label when possible."""
    try:
        split = train.train_test_split(
            test_size=test_size,
            seed=RANDOM_SEED,
            stratify_by_column="label",
        )
    except ValueError:
        # Stratification is rejected with ValueError when "label" is not a
        # ClassLabel column or a class is too small; fall back to a plain
        # seeded random split.
        split = train.train_test_split(
            test_size=test_size,
            seed=RANDOM_SEED,
        )

    return {
        "train": split["train"],
        "test": split["test"],
    }


def get_class_names() -> List[str]:
    """Return the dataset's class names in label-index order.

    Delegates to ``load_charcoal_dataset``, which loads and caches the data
    on first use.
    """
    return load_charcoal_dataset()[1]


def make_loaders(
    batch_size: int, val_ratio: float = 0.1
) -> Tuple[DataLoader, DataLoader, DataLoader, List[str]]:
    """Build train/validation/test ``DataLoader``s for the charcoal dataset.

    The validation set is carved out of the training split with a
    ``RANDOM_SEED``-seeded ``random_split``, so for a given dataset the
    train/val partition is reproducible. The test split is used unchanged.

    Args:
        batch_size: Batch size used by all three loaders.
        val_ratio: Fraction of the training split held out for validation,
            in ``[0, 1)``. The resulting size is truncated toward zero.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``. Only the
        training loader shuffles.

    Raises:
        ValueError: if ``val_ratio`` is outside ``[0, 1)``.
        RuntimeError: propagated from ``load_charcoal_dataset`` when
            ``HF_TOKEN`` is missing.
    """
    # Validate before the (potentially slow) dataset download. A negative
    # ratio would give random_split a negative length, and a ratio of 1
    # leaves an empty training subset.
    if not 0.0 <= val_ratio < 1.0:
        raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio!r}")

    raw, class_names = load_charcoal_dataset()
    transform = get_transform()

    train_dataset = HFDatasetWrapper(raw["train"], transform)
    test_dataset = HFDatasetWrapper(raw["test"], transform)

    val_size = int(len(train_dataset) * val_ratio)
    train_size = len(train_dataset) - val_size

    train_subset, val_subset = random_split(
        train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(RANDOM_SEED),
    )

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, class_names