Spaces:
Sleeping
Sleeping
| import random | |
| from collections import Counter | |
| from typing import Dict, List, Tuple | |
| import pandas as pd | |
| import torch | |
| from PIL import Image | |
| from datasets import load_dataset, DatasetDict | |
| from torch.utils.data import Dataset, DataLoader, Subset | |
| from torchvision import transforms | |
| from config import HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE, RANDOM_SEED | |
# Module-level caches so the dataset is downloaded, class names are derived,
# and splits are computed at most once per process. Populated lazily by
# load_raw_dataset() and prepare_splits().
_RAW_DATASET = None  # cached result of datasets.load_dataset()
_CLASS_NAMES = None  # cached list of class-name strings
_SPLITS = None  # cached dict with "train" / "validation" / "test" splits
class HFDatasetWrapper(Dataset):
    """Adapt a Hugging Face dataset split to the torch ``Dataset`` protocol.

    Each item is returned as an ``(image, label)`` pair; the image is
    normalized to RGB PIL and run through ``transform`` when one is given.
    """

    def __init__(self, hf_dataset, transform):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        img = record["image"]
        # The split may yield either a ready PIL image or a path-like
        # object — handle both, then force RGB.
        if not isinstance(img, Image.Image):
            img = Image.open(img)
        img = img.convert("RGB")
        target = int(record["label"])
        if self.transform:
            img = self.transform(img)
        return img, target
def get_train_transform():
    """Build the augmentation + normalization pipeline used for training."""
    normalize = transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    steps = [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        # Light augmentation: flips plus a small rotation.
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=5),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)
def get_eval_transform():
    """Build the deterministic resize + normalization pipeline for eval."""
    normalize = transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    steps = [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)
def load_raw_dataset():
    """Download (once) and cache the raw Hugging Face dataset.

    Returns:
        A ``(dataset, class_names)`` tuple.

    Raises:
        RuntimeError: when ``HF_TOKEN`` is missing or the dataset has no
            ``train`` split.
    """
    global _RAW_DATASET, _CLASS_NAMES
    if _RAW_DATASET is None:
        if not HF_TOKEN:
            raise RuntimeError(
                "HF_TOKEN est manquant. Ajoutez-le dans les Secrets du Space Hugging Face."
            )
        raw = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
        if "train" not in raw:
            raise RuntimeError("Le dataset Hugging Face doit contenir au moins un split 'train'.")
        feature = raw["train"].features["label"]
        # Prefer ClassLabel names when present; otherwise derive names from
        # the distinct label values seen in the train split.
        if hasattr(feature, "names") and feature.names:
            names = feature.names
        else:
            names = [str(value) for value in sorted(set(raw["train"]["label"]))]
        _RAW_DATASET = raw
        _CLASS_NAMES = names
    return _RAW_DATASET, _CLASS_NAMES
def _stratified_split(dataset, test_size: float):
    """Split *dataset* via train_test_split, stratified on "label" when possible.

    Stratification requires the label column to be a ClassLabel feature;
    when it is not (or stratification fails for any other reason), fall back
    to a plain seeded random split, matching the original behavior.
    """
    try:
        return dataset.train_test_split(
            test_size=test_size,
            seed=RANDOM_SEED,
            stratify_by_column="label",
        )
    except Exception:
        return dataset.train_test_split(test_size=test_size, seed=RANDOM_SEED)


def prepare_splits(
    train_ratio: float = 0.70,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
):
    """Return cached train/validation/test splits, creating them if needed.

    Strategy, depending on what the raw dataset already provides:
      * "validation" and "test" splits exist -> use them as-is;
      * only "test" exists -> carve validation out of "train";
      * only "train" exists -> carve both validation and test out of it.

    Args:
        train_ratio / val_ratio / test_ratio: target proportions, used only
            when splits must be derived from "train".

    Returns:
        Dict with keys "train", "validation", "test".
    """
    global _SPLITS
    if _SPLITS is not None:
        return _SPLITS
    raw, _ = load_raw_dataset()
    if "validation" in raw and "test" in raw:
        _SPLITS = {
            "train": raw["train"],
            "validation": raw["validation"],
            "test": raw["test"],
        }
        return _SPLITS
    if "test" in raw:
        # Only validation is missing: split it off the train portion, scaled
        # so the overall proportions stay close to the requested ratios.
        split = _stratified_split(raw["train"], val_ratio / (train_ratio + val_ratio))
        _SPLITS = {
            "train": split["train"],
            "validation": split["test"],
            "test": raw["test"],
        }
        return _SPLITS
    # Only "train" exists: first peel off val+test together, then divide
    # that remainder between validation and test.
    first_split = _stratified_split(raw["train"], val_ratio + test_ratio)
    second_split = _stratified_split(
        first_split["test"], test_ratio / (val_ratio + test_ratio)
    )
    _SPLITS = {
        "train": first_split["train"],
        "validation": second_split["train"],
        "test": second_split["test"],
    }
    return _SPLITS
def get_class_names() -> List[str]:
    """Return the ordered list of class names for the dataset."""
    names = load_raw_dataset()[1]
    return names
def make_loaders(batch_size: int):
    """Build the three DataLoaders (train shuffled, val/test not).

    Args:
        batch_size: batch size shared by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``.
    """
    splits = prepare_splits()
    names = get_class_names()
    eval_tf = get_eval_transform()
    loaders = {}
    for split_key, tf, shuffle in (
        ("train", get_train_transform(), True),
        ("validation", eval_tf, False),
        ("test", eval_tf, False),
    ):
        wrapped = HFDatasetWrapper(splits[split_key], tf)
        loaders[split_key] = DataLoader(wrapped, batch_size=batch_size, shuffle=shuffle)
    return loaders["train"], loaders["validation"], loaders["test"], names
def dataset_overview() -> Tuple[dict, pd.DataFrame]:
    """Summarize the prepared splits.

    Returns:
        ``(summary, df)`` where ``summary`` holds global counts and ``df``
        has one row per (split, class) with the image count.
    """
    splits = prepare_splits()
    class_names = get_class_names()
    records = []
    grand_total = 0
    for split_name, split_data in splits.items():
        label_list = list(split_data["label"])
        grand_total += len(label_list)
        # One row per class, in ascending label-id order.
        for label_id, count in sorted(Counter(label_list).items()):
            records.append(
                {
                    "split": split_name,
                    "classe": class_names[int(label_id)],
                    "nombre_images": count,
                }
            )
    summary = {
        "dataset": HF_DATASET_REPO,
        "nombre_total_images": grand_total,
        "nombre_classes": len(class_names),
        "train": len(splits["train"]),
        "validation": len(splits["validation"]),
        "test": len(splits["test"]),
    }
    return summary, pd.DataFrame(records)
def get_images_for_gallery(split_name: str, class_name: str, max_images: int = 24):
    """Sample up to *max_images* ``(PIL image, class name)`` pairs for a gallery.

    Args:
        split_name: split to draw from; unknown values fall back to "train".
        class_name: restrict to one class; ``"Toutes les classes"``, empty,
            or an unknown name selects every class. (Previously an unknown
            class name raised ValueError from ``list.index`` — now it falls
            back gracefully, consistent with the split-name handling.)
        max_images: maximum number of images to return; values <= 0 yield [].

    Returns:
        List of ``(image, label_name)`` tuples; empty when nothing matches.
    """
    splits = prepare_splits()
    class_names = get_class_names()
    if split_name not in splits:
        split_name = "train"
    dataset = splits[split_name]
    if class_name and class_name != "Toutes les classes" and class_name in class_names:
        class_id = class_names.index(class_name)
        indices = [i for i, x in enumerate(dataset["label"]) if int(x) == class_id]
    else:
        indices = list(range(len(dataset)))
    if not indices or max_images <= 0:
        return []
    sample_indices = random.sample(indices, min(max_images, len(indices)))
    gallery = []
    for idx in sample_indices:
        item = dataset[idx]
        image = item["image"]
        # Normalize to an RGB PIL image (the split may yield path-like items).
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert("RGB")
        gallery.append((image, class_names[int(item["label"])]))
    return gallery