import random
from collections import Counter
from typing import List, Tuple

import pandas as pd
from PIL import Image
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from config import HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE, RANDOM_SEED

# Module-level caches so the dataset is downloaded and split only once per process.
_RAW_DATASET = None
_CLASS_NAMES = None
_SPLITS = None


class HFDatasetWrapper(Dataset):
    """Wraps a Hugging Face split as a torch Dataset yielding (image_tensor, label) pairs."""

    def __init__(self, hf_dataset, transform):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert("RGB")
        label = int(item["label"])
        if self.transform:
            image = self.transform(image)
        return image, label


def get_train_transform():
    """Training pipeline: resize, light augmentation, ImageNet normalization."""
    return transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(degrees=5),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )


def get_eval_transform():
    """Evaluation pipeline: resize and ImageNet normalization only, no augmentation."""
    return transforms.Compose(
        [
            transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )


def load_raw_dataset():
    """Download the dataset from the Hub (once) and resolve the class names."""
    global _RAW_DATASET, _CLASS_NAMES
    if _RAW_DATASET is not None:
        return _RAW_DATASET, _CLASS_NAMES
    if not HF_TOKEN:
        raise RuntimeError(
            "HF_TOKEN is missing. Add it to the Hugging Face Space secrets."
        )
    raw = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
    if "train" not in raw:
        raise RuntimeError("The Hugging Face dataset must contain at least a 'train' split.")
    label_feature = raw["train"].features["label"]
    if hasattr(label_feature, "names") and label_feature.names:
        # ClassLabel feature: the names are stored in the dataset metadata.
        class_names = label_feature.names
    else:
        # Fallback for plain integer labels: derive names from the observed values.
        labels = list(raw["train"]["label"])
        class_names = [str(x) for x in sorted(set(labels))]
    _RAW_DATASET = raw
    _CLASS_NAMES = class_names
    return _RAW_DATASET, _CLASS_NAMES


def prepare_splits(
    train_ratio: float = 0.70,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
):
    """Return train/validation/test splits, creating them when the repo only provides 'train'."""
    global _SPLITS
    if _SPLITS is not None:
        return _SPLITS
    raw, _ = load_raw_dataset()

    # Case 1: the repo already provides all three splits.
    if "validation" in raw and "test" in raw:
        _SPLITS = {
            "train": raw["train"],
            "validation": raw["validation"],
            "test": raw["test"],
        }
        return _SPLITS

    # Case 2: the repo provides train + test; carve the validation set out of train.
    if "test" in raw:
        train_val = raw["train"]
        test = raw["test"]
        relative_val_ratio = val_ratio / (train_ratio + val_ratio)
        try:
            split_train_val = train_val.train_test_split(
                test_size=relative_val_ratio,
                seed=RANDOM_SEED,
                stratify_by_column="label",
            )
        except Exception:
            # Stratification requires a ClassLabel column; fall back to a plain split.
            split_train_val = train_val.train_test_split(
                test_size=relative_val_ratio,
                seed=RANDOM_SEED,
            )
        _SPLITS = {
            "train": split_train_val["train"],
            "validation": split_train_val["test"],
            "test": test,
        }
        return _SPLITS

    # Case 3: only a 'train' split exists; split it twice to obtain the three sets.
    full = raw["train"]
    try:
        first_split = full.train_test_split(
            test_size=(val_ratio + test_ratio),
            seed=RANDOM_SEED,
            stratify_by_column="label",
        )
    except Exception:
        first_split = full.train_test_split(
            test_size=(val_ratio + test_ratio),
            seed=RANDOM_SEED,
        )
    temp = first_split["test"]
    relative_test_ratio = test_ratio / (val_ratio + test_ratio)
    try:
        second_split = temp.train_test_split(
            test_size=relative_test_ratio,
            seed=RANDOM_SEED,
            stratify_by_column="label",
        )
    except Exception:
        second_split = temp.train_test_split(
            test_size=relative_test_ratio,
            seed=RANDOM_SEED,
        )
    _SPLITS = {
        "train": first_split["train"],
        "validation": second_split["train"],
        "test": second_split["test"],
    }
    return _SPLITS


def get_class_names() -> List[str]:
    _, class_names = load_raw_dataset()
    return class_names


def make_loaders(batch_size: int):
    """Build the three DataLoaders (train shuffled, validation/test not) plus the class names."""
    splits = prepare_splits()
    class_names = get_class_names()
    train_dataset = HFDatasetWrapper(splits["train"], get_train_transform())
    val_dataset = HFDatasetWrapper(splits["validation"], get_eval_transform())
    test_dataset = HFDatasetWrapper(splits["test"], get_eval_transform())
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader, class_names


def dataset_overview() -> Tuple[dict, pd.DataFrame]:
    """Return a summary dict and a per-split, per-class image-count DataFrame."""
    splits = prepare_splits()
    class_names = get_class_names()
    rows = []
    total = 0
    for split_name, split_data in splits.items():
        labels = list(split_data["label"])
        counter = Counter(labels)
        split_total = len(labels)
        total += split_total
        for label_id, count in sorted(counter.items()):
            # French keys ("classe", "nombre_images") are kept as-is; the UI layer displays them.
            rows.append(
                {
                    "split": split_name,
                    "classe": class_names[int(label_id)],
                    "nombre_images": count,
                }
            )
    df = pd.DataFrame(rows)
    summary = {
        "dataset": HF_DATASET_REPO,
        "nombre_total_images": total,
        "nombre_classes": len(class_names),
        "train": len(splits["train"]),
        "validation": len(splits["validation"]),
        "test": len(splits["test"]),
    }
    return summary, df


def get_images_for_gallery(split_name: str, class_name: str, max_images: int = 24):
    """Sample up to max_images PIL images (with their label names) for the image gallery."""
    splits = prepare_splits()
    class_names = get_class_names()
    if split_name not in splits:
        split_name = "train"
    dataset = splits[split_name]
    # "Toutes les classes" ("All classes") is the sentinel value that disables class filtering.
    if class_name and class_name != "Toutes les classes":
        class_id = class_names.index(class_name)
        indices = [i for i, x in enumerate(dataset["label"]) if int(x) == class_id]
    else:
        indices = list(range(len(dataset)))
    if not indices:
        return []
    sample_indices = random.sample(indices, min(max_images, len(indices)))
    gallery = []
    for idx in sample_indices:
        item = dataset[idx]
        image = item["image"]
        if not isinstance(image, Image.Image):
            image = Image.open(image)
        image = image.convert("RGB")
        label_id = int(item["label"])
        label_name = class_names[label_id]
        gallery.append((image, label_name))
    return gallery
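

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the module's API: it assumes config.py
# defines HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE and RANDOM_SEED and that the
# Hub dataset is reachable. batch_size=32 is an arbitrary example value.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Print the split summary and the per-class distribution.
    summary, distribution = dataset_overview()
    print(summary)
    print(distribution.head())

    # Build the loaders and fetch one training batch as a smoke test.
    train_loader, val_loader, test_loader, class_names = make_loaders(batch_size=32)
    images, labels = next(iter(train_loader))
    print(f"{len(class_names)} classes; first batch shape: {tuple(images.shape)}")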