CircleStar committed on
Commit
586661b
·
verified ·
1 Parent(s): 01ce719

Create data_utils.py

Browse files
Files changed (1) hide show
  1. data_utils.py +116 -0
data_utils.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from datasets import load_dataset
6
+ from torch.utils.data import Dataset, DataLoader, random_split
7
+ from torchvision import transforms
8
+
9
+ from config import HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE, RANDOM_SEED
10
+
11
+
12
# Class names of the loaded dataset; assigned by load_charcoal_dataset().
_CLASS_NAMES = None
# Memoized dataset splits; stays None until load_charcoal_dataset() has
# loaded the dataset, after which later calls return this object directly.
_HF_DATASET_CACHE = None
14
+
15
+
16
class HFDatasetWrapper(Dataset):
    """Adapt an indexable image dataset to the PyTorch ``Dataset`` protocol.

    Every item of the wrapped dataset is read through its ``"image"`` and
    ``"label"`` keys. The image is converted to RGB and passed through
    ``transform``; the label is coerced to ``int``.
    """

    def __init__(self, hf_dataset, transform):
        """Store the wrapped dataset and the per-image transform.

        Args:
            hf_dataset: Sized, indexable collection whose items expose
                ``"image"`` and ``"label"`` keys.
            transform: Callable applied to the RGB ``PIL.Image.Image``.
        """
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for the sample at ``idx``.

        Raises:
            ValueError / TypeError: If the stored label cannot be converted
                with ``int()``.
        """
        item = self.dataset[idx]

        image = item["image"]
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            # The value is not a decoded image (presumably a path or file
            # object), so open it here. The context manager releases the
            # underlying file as soon as the RGB copy has been made instead
            # of leaving the handle to the garbage collector.
            with Image.open(image) as opened:
                image = opened.convert("RGB")

        label = int(item["label"])

        return self.transform(image), label
35
+
36
+
37
def get_transform():
    """Build the preprocessing pipeline applied to every image.

    Returns:
        A ``transforms.Compose`` that resizes to ``IMAGE_SIZE`` x
        ``IMAGE_SIZE``, converts the image to a tensor, and normalizes each
        of the three channels with mean 0.5 and std 0.5.
    """
    channel_stats = (0.5, 0.5, 0.5)
    steps = [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_stats, std=channel_stats),
    ]
    return transforms.Compose(steps)
48
+
49
+
50
def load_charcoal_dataset():
    """Load the dataset from the Hub once and return ``(splits, class_names)``.

    The result is memoized in the module-level ``_HF_DATASET_CACHE`` and
    ``_CLASS_NAMES``; later calls return the cached objects without
    reloading.

    Returns:
        A pair ``(splits, class_names)``. ``splits`` can be indexed with
        ``"train"`` and ``"test"``. When the repository has no ``"test"``
        split, 20% of ``"train"`` is held out with ``RANDOM_SEED``.
        ``class_names`` comes from the label feature's ``names`` attribute
        when it has one, otherwise from the sorted distinct label values.

    Raises:
        RuntimeError: If ``HF_TOKEN`` is empty or unset.
    """
    import logging

    global _CLASS_NAMES, _HF_DATASET_CACHE

    if _HF_DATASET_CACHE is not None:
        return _HF_DATASET_CACHE, _CLASS_NAMES

    if not HF_TOKEN:
        raise RuntimeError(
            "HF_TOKEN is missing. Please add it in the Space secrets."
        )

    raw = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
    train_split = raw["train"]

    label_feature = train_split.features["label"]
    if hasattr(label_feature, "names"):
        class_names = label_feature.names
    else:
        class_names = sorted(set(train_split["label"]))

    if "test" not in raw:
        try:
            split = train_split.train_test_split(
                test_size=0.2,
                seed=RANDOM_SEED,
                stratify_by_column="label",
            )
        except Exception as exc:
            # Stratification is best-effort: keep the fallback to a plain
            # random split, but record why the stratified split was skipped
            # instead of swallowing the error silently.
            logging.getLogger(__name__).warning(
                "Stratified train/test split failed (%s); "
                "falling back to a non-stratified split.",
                exc,
            )
            split = train_split.train_test_split(
                test_size=0.2,
                seed=RANDOM_SEED,
            )

        raw = {
            "train": split["train"],
            "test": split["test"],
        }

    # Publish both globals together, only after everything succeeded, so a
    # failure above never leaves class names set without a cached dataset.
    _CLASS_NAMES = class_names
    _HF_DATASET_CACHE = raw
    return _HF_DATASET_CACHE, _CLASS_NAMES
89
+
90
+
91
def get_class_names() -> List[str]:
    """Return the dataset's class names, loading the dataset if necessary."""
    return load_charcoal_dataset()[1]
94
+
95
+
96
def make_loaders(
    batch_size: int, val_ratio: float = 0.1
) -> Tuple[DataLoader, DataLoader, DataLoader, List[str]]:
    """Create train, validation and test ``DataLoader`` objects.

    The validation set is carved out of the training split with
    ``random_split``, seeded with ``RANDOM_SEED`` so the partition is
    reproducible. Only the training loader shuffles.

    Args:
        batch_size: Batch size used for all three loaders.
        val_ratio: Fraction of the training split held out for validation;
            must satisfy ``0 <= val_ratio < 1``.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1)``.
    """
    # Validate before any loading: a negative ratio would otherwise yield a
    # silently wrong split, and a ratio >= 1 fails later with an obscure
    # error from the sampler / random_split.
    if not 0.0 <= val_ratio < 1.0:
        raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio!r}")

    raw, class_names = load_charcoal_dataset()
    transform = get_transform()

    train_dataset = HFDatasetWrapper(raw["train"], transform)
    test_dataset = HFDatasetWrapper(raw["test"], transform)

    val_size = int(len(train_dataset) * val_ratio)
    train_size = len(train_dataset) - val_size

    train_subset, val_subset = random_split(
        train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(RANDOM_SEED),
    )

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, class_names