CircleStar commited on
Commit
c3905ef
·
verified ·
1 Parent(s): c47b92a

Update data_utils.py

Browse files
Files changed (1) hide show
  1. data_utils.py +206 -42
data_utils.py CHANGED
@@ -1,16 +1,20 @@
1
- from typing import List, Tuple
 
 
2
 
 
3
  import torch
4
  from PIL import Image
5
- from datasets import load_dataset
6
- from torch.utils.data import Dataset, DataLoader, random_split
7
  from torchvision import transforms
8
 
9
  from config import HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE, RANDOM_SEED
10
 
11
 
 
12
  _CLASS_NAMES = None
13
- _HF_DATASET_CACHE = None
14
 
15
 
16
  class HFDatasetWrapper(Dataset):
@@ -31,13 +35,24 @@ class HFDatasetWrapper(Dataset):
31
  image = image.convert("RGB")
32
  label = int(item["label"])
33
 
34
- return self.transform(image), label
 
35
 
 
36
 
37
- def get_transform():
 
38
  return transforms.Compose(
39
  [
40
- transforms.Resize((224, 224)),
 
 
 
 
 
 
 
 
41
  transforms.ToTensor(),
42
  transforms.Normalize(
43
  mean=(0.485, 0.456, 0.406),
@@ -47,70 +62,219 @@ def get_transform():
47
  )
48
 
49
 
50
- def load_charcoal_dataset():
51
- global _CLASS_NAMES, _HF_DATASET_CACHE
 
 
 
 
 
 
 
 
 
 
52
 
53
- if _HF_DATASET_CACHE is not None:
54
- return _HF_DATASET_CACHE, _CLASS_NAMES
 
 
 
55
 
56
  if not HF_TOKEN:
57
  raise RuntimeError(
58
- "HF_TOKEN is missing. Please add it in the Space secrets."
59
  )
60
 
61
  raw = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
62
 
 
 
 
63
  label_feature = raw["train"].features["label"]
64
- if hasattr(label_feature, "names"):
65
- _CLASS_NAMES = label_feature.names
 
66
  else:
67
- _CLASS_NAMES = sorted(list(set(raw["train"]["label"])))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- if "test" not in raw:
70
  try:
71
- split = raw["train"].train_test_split(
72
- test_size=0.2,
73
  seed=RANDOM_SEED,
74
  stratify_by_column="label",
75
  )
76
  except Exception:
77
- split = raw["train"].train_test_split(
78
- test_size=0.2,
79
  seed=RANDOM_SEED,
80
  )
81
 
82
- raw = {
83
- "train": split["train"],
84
- "test": split["test"],
 
85
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- _HF_DATASET_CACHE = raw
88
- return _HF_DATASET_CACHE, _CLASS_NAMES
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
 
91
  def get_class_names() -> List[str]:
92
- _, class_names = load_charcoal_dataset()
93
  return class_names
94
 
95
 
96
- def make_loaders(batch_size: int, val_ratio: float = 0.1):
97
- raw, class_names = load_charcoal_dataset()
98
- transform = get_transform()
99
 
100
- train_dataset = HFDatasetWrapper(raw["train"], transform)
101
- test_dataset = HFDatasetWrapper(raw["test"], transform)
 
102
 
103
- val_size = int(len(train_dataset) * val_ratio)
104
- train_size = len(train_dataset) - val_size
 
105
 
106
- train_subset, val_subset = random_split(
107
- train_dataset,
108
- [train_size, val_size],
109
- generator=torch.Generator().manual_seed(RANDOM_SEED),
110
- )
111
 
112
- train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
113
- val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
114
- test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
115
 
116
- return train_loader, val_loader, test_loader, class_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from collections import Counter
3
+ from typing import Dict, List, Tuple
4
 
5
+ import pandas as pd
6
  import torch
7
  from PIL import Image
8
+ from datasets import load_dataset, DatasetDict
9
+ from torch.utils.data import Dataset, DataLoader, Subset
10
  from torchvision import transforms
11
 
12
  from config import HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE, RANDOM_SEED
13
 
14
 
15
+ _RAW_DATASET = None
16
  _CLASS_NAMES = None
17
+ _SPLITS = None
18
 
19
 
20
  class HFDatasetWrapper(Dataset):
 
35
  image = image.convert("RGB")
36
  label = int(item["label"])
37
 
38
+ if self.transform:
39
+ image = self.transform(image)
40
 
41
+ return image, label
42
 
43
+
44
+ def get_train_transform():
45
  return transforms.Compose(
46
  [
47
+ transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.65, 1.0)),
48
+ transforms.RandomHorizontalFlip(p=0.5),
49
+ transforms.RandomVerticalFlip(p=0.2),
50
+ transforms.RandomRotation(degrees=15),
51
+ transforms.ColorJitter(
52
+ brightness=0.2,
53
+ contrast=0.2,
54
+ saturation=0.1,
55
+ ),
56
  transforms.ToTensor(),
57
  transforms.Normalize(
58
  mean=(0.485, 0.456, 0.406),
 
62
  )
63
 
64
 
65
+ def get_eval_transform():
66
+ return transforms.Compose(
67
+ [
68
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize(
71
+ mean=(0.485, 0.456, 0.406),
72
+ std=(0.229, 0.224, 0.225),
73
+ ),
74
+ ]
75
+ )
76
+
77
 
78
+ def load_raw_dataset():
79
+ global _RAW_DATASET, _CLASS_NAMES
80
+
81
+ if _RAW_DATASET is not None:
82
+ return _RAW_DATASET, _CLASS_NAMES
83
 
84
  if not HF_TOKEN:
85
  raise RuntimeError(
86
+ "HF_TOKEN est manquant. Ajoutez-le dans les Secrets du Space Hugging Face."
87
  )
88
 
89
  raw = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
90
 
91
+ if "train" not in raw:
92
+ raise RuntimeError("Le dataset Hugging Face doit contenir au moins un split 'train'.")
93
+
94
  label_feature = raw["train"].features["label"]
95
+
96
+ if hasattr(label_feature, "names") and label_feature.names:
97
+ class_names = label_feature.names
98
  else:
99
+ labels = list(raw["train"]["label"])
100
+ class_names = [str(x) for x in sorted(set(labels))]
101
+
102
+ _RAW_DATASET = raw
103
+ _CLASS_NAMES = class_names
104
+
105
+ return _RAW_DATASET, _CLASS_NAMES
106
+
107
+
108
+ def prepare_splits(
109
+ train_ratio: float = 0.80,
110
+ val_ratio: float = 0.10,
111
+ test_ratio: float = 0.10,
112
+ ):
113
+ global _SPLITS
114
+
115
+ if _SPLITS is not None:
116
+ return _SPLITS
117
+
118
+ raw, class_names = load_raw_dataset()
119
+
120
+ if "validation" in raw and "test" in raw:
121
+ _SPLITS = {
122
+ "train": raw["train"],
123
+ "validation": raw["validation"],
124
+ "test": raw["test"],
125
+ }
126
+ return _SPLITS
127
+
128
+ if "test" in raw:
129
+ train_val = raw["train"]
130
+ test = raw["test"]
131
+
132
+ relative_val_ratio = val_ratio / (train_ratio + val_ratio)
133
 
 
134
  try:
135
+ split_train_val = train_val.train_test_split(
136
+ test_size=relative_val_ratio,
137
  seed=RANDOM_SEED,
138
  stratify_by_column="label",
139
  )
140
  except Exception:
141
+ split_train_val = train_val.train_test_split(
142
+ test_size=relative_val_ratio,
143
  seed=RANDOM_SEED,
144
  )
145
 
146
+ _SPLITS = {
147
+ "train": split_train_val["train"],
148
+ "validation": split_train_val["test"],
149
+ "test": test,
150
  }
151
+ return _SPLITS
152
+
153
+ full = raw["train"]
154
+
155
+ try:
156
+ first_split = full.train_test_split(
157
+ test_size=(val_ratio + test_ratio),
158
+ seed=RANDOM_SEED,
159
+ stratify_by_column="label",
160
+ )
161
+ except Exception:
162
+ first_split = full.train_test_split(
163
+ test_size=(val_ratio + test_ratio),
164
+ seed=RANDOM_SEED,
165
+ )
166
+
167
+ temp = first_split["test"]
168
+ relative_test_ratio = test_ratio / (val_ratio + test_ratio)
169
 
170
+ try:
171
+ second_split = temp.train_test_split(
172
+ test_size=relative_test_ratio,
173
+ seed=RANDOM_SEED,
174
+ stratify_by_column="label",
175
+ )
176
+ except Exception:
177
+ second_split = temp.train_test_split(
178
+ test_size=relative_test_ratio,
179
+ seed=RANDOM_SEED,
180
+ )
181
+
182
+ _SPLITS = {
183
+ "train": first_split["train"],
184
+ "validation": second_split["train"],
185
+ "test": second_split["test"],
186
+ }
187
+
188
+ return _SPLITS
189
 
190
 
191
  def get_class_names() -> List[str]:
192
+ _, class_names = load_raw_dataset()
193
  return class_names
194
 
195
 
196
+ def make_loaders(batch_size: int):
197
+ splits = prepare_splits()
198
+ class_names = get_class_names()
199
 
200
+ train_dataset = HFDatasetWrapper(splits["train"], get_train_transform())
201
+ val_dataset = HFDatasetWrapper(splits["validation"], get_eval_transform())
202
+ test_dataset = HFDatasetWrapper(splits["test"], get_eval_transform())
203
 
204
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
205
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
206
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
207
 
208
+ return train_loader, val_loader, test_loader, class_names
 
 
 
 
209
 
 
 
 
210
 
211
+ def dataset_overview() -> Tuple[dict, pd.DataFrame]:
212
+ splits = prepare_splits()
213
+ class_names = get_class_names()
214
+
215
+ rows = []
216
+ total = 0
217
+
218
+ for split_name, split_data in splits.items():
219
+ labels = list(split_data["label"])
220
+ counter = Counter(labels)
221
+ split_total = len(labels)
222
+ total += split_total
223
+
224
+ for label_id, count in sorted(counter.items()):
225
+ rows.append(
226
+ {
227
+ "split": split_name,
228
+ "classe": class_names[int(label_id)],
229
+ "nombre_images": count,
230
+ }
231
+ )
232
+
233
+ df = pd.DataFrame(rows)
234
+
235
+ summary = {
236
+ "dataset": HF_DATASET_REPO,
237
+ "nombre_total_images": total,
238
+ "nombre_classes": len(class_names),
239
+ "train": len(splits["train"]),
240
+ "validation": len(splits["validation"]),
241
+ "test": len(splits["test"]),
242
+ }
243
+
244
+ return summary, df
245
+
246
+
247
+ def get_images_for_gallery(split_name: str, class_name: str, max_images: int = 24):
248
+ splits = prepare_splits()
249
+ class_names = get_class_names()
250
+
251
+ if split_name not in splits:
252
+ split_name = "train"
253
+
254
+ dataset = splits[split_name]
255
+
256
+ if class_name and class_name != "Toutes les classes":
257
+ class_id = class_names.index(class_name)
258
+ indices = [i for i, x in enumerate(dataset["label"]) if int(x) == class_id]
259
+ else:
260
+ indices = list(range(len(dataset)))
261
+
262
+ if not indices:
263
+ return []
264
+
265
+ sample_indices = random.sample(indices, min(max_images, len(indices)))
266
+
267
+ gallery = []
268
+ for idx in sample_indices:
269
+ item = dataset[idx]
270
+ image = item["image"]
271
+ if not isinstance(image, Image.Image):
272
+ image = Image.open(image)
273
+ image = image.convert("RGB")
274
+
275
+ label_id = int(item["label"])
276
+ label_name = class_names[label_id]
277
+
278
+ gallery.append((image, label_name))
279
+
280
+ return gallery