"""Data loading utilities for the charcoal image-classification dataset.

The dataset is fetched with ``datasets.load_dataset(HF_DATASET_REPO)`` using
``HF_TOKEN``, cached in-process, and exposed as PyTorch ``DataLoader``s.
"""

import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
from PIL import Image
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

from config import HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE, RANDOM_SEED

# In-process cache filled by load_charcoal_dataset() so the download and the
# train/test split happen at most once per process.
_CLASS_NAMES = None
_HF_DATASET_CACHE = None
# Raw label value -> contiguous class index. Only set when the "label" column
# is not a ClassLabel (ClassLabel values are already indices); else None.
_LABEL_TO_INDEX = None


class HFDatasetWrapper(Dataset):
    """Adapt a Hugging Face dataset split to a PyTorch ``Dataset``.

    Each item is returned as ``(transform(rgb_image), label_index)``.
    """

    def __init__(
        self,
        hf_dataset,
        transform,
        label_to_index: Optional[Dict[Any, int]] = None,
    ):
        """
        Args:
            hf_dataset: Indexable split whose rows have "image" and "label".
            transform: Callable applied to the RGB PIL image.
            label_to_index: Optional mapping from raw label values to class
                indices. When ``None`` the raw label is cast with ``int()``,
                which is what a ClassLabel column needs.
        """
        self.dataset = hf_dataset
        self.transform = transform
        self.label_to_index = label_to_index

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item["image"]
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            # Presumably a path or file-like object. convert() copies the
            # pixels, so the source handle can be closed right away.
            with Image.open(image) as opened:
                image = opened.convert("RGB")

        raw_label = item["label"]
        if self.label_to_index is None:
            label = int(raw_label)
        else:
            try:
                label = self.label_to_index[raw_label]
            except KeyError:
                raise KeyError(
                    f"Label {raw_label!r} at index {idx} is not a known class"
                ) from None
        return self.transform(image), label


def _resize_size(image_size) -> Tuple[int, int]:
    """Return ``image_size`` as a (height, width) pair; a scalar means square."""
    if isinstance(image_size, (tuple, list)):
        height, width = image_size
        return (int(height), int(width))
    side = int(image_size)
    return (side, side)


def get_transform(image_size=None):
    """Build the deterministic preprocessing pipeline used for every split.

    Resize -> tensor -> per-channel normalization with the commonly used
    ImageNet mean/std. There is no random augmentation here.

    Args:
        image_size: Target size as an int (square) or (height, width) pair.
            Defaults to ``config.IMAGE_SIZE``.

    NOTE(review): this used to be hard-coded to 224x224 while IMAGE_SIZE was
    imported but unused; confirm existing checkpoints match IMAGE_SIZE.
    """
    size = _resize_size(IMAGE_SIZE if image_size is None else image_size)
    return transforms.Compose(
        [
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )


def load_charcoal_dataset():
    """Load (once per process) the dataset splits and class names.

    Returns:
        ``(splits, class_names)``. ``splits`` is indexable by "train" and
        "test". If the repo has no "test" split, 20% of "train" is held out
        with ``RANDOM_SEED``, stratified by label when possible.

    Raises:
        RuntimeError: If ``HF_TOKEN`` is not configured.
    """
    global _CLASS_NAMES, _HF_DATASET_CACHE, _LABEL_TO_INDEX
    if _HF_DATASET_CACHE is not None:
        return _HF_DATASET_CACHE, _CLASS_NAMES

    if not HF_TOKEN:
        raise RuntimeError(
            "HF_TOKEN is missing. Please add it in the Space secrets."
        )

    raw = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)

    label_feature = raw["train"].features["label"]
    if hasattr(label_feature, "names"):
        # ClassLabel: stored labels are already indices into ``names``.
        class_names = label_feature.names
        label_to_index = None
    else:
        # Plain column: classes are the sorted distinct training values.
        # Map each raw value to its position so targets are 0..K-1 (identity
        # for labels that already are contiguous ints starting at 0).
        class_names = sorted(set(raw["train"]["label"]))
        label_to_index = {value: index for index, value in enumerate(class_names)}

    if "test" not in raw:
        try:
            split = raw["train"].train_test_split(
                test_size=0.2,
                seed=RANDOM_SEED,
                stratify_by_column="label",
            )
        except (ValueError, TypeError) as err:
            # Stratification can be rejected (e.g. non-ClassLabel column or
            # too few samples per class -> ValueError; a `datasets` version
            # without `stratify_by_column` -> TypeError). Fall back to a plain
            # seeded split, but do not hide that it happened.
            warnings.warn(
                f"Stratified split failed ({err}); using a non-stratified split.",
                RuntimeWarning,
                stacklevel=2,
            )
            split = raw["train"].train_test_split(
                test_size=0.2,
                seed=RANDOM_SEED,
            )
        raw = {
            "train": split["train"],
            "test": split["test"],
        }

    _CLASS_NAMES = class_names
    _LABEL_TO_INDEX = label_to_index
    _HF_DATASET_CACHE = raw
    return _HF_DATASET_CACHE, _CLASS_NAMES


def get_class_names() -> List[str]:
    """Return the class names in label-index order (loads the dataset if needed).

    NOTE(review): for a non-ClassLabel column these are the raw label values,
    which may be ints rather than strings.
    """
    _, class_names = load_charcoal_dataset()
    return class_names


def make_loaders(
    batch_size: int, val_ratio: float = 0.1
) -> Tuple[DataLoader, DataLoader, DataLoader, List[str]]:
    """Build train/val/test DataLoaders plus the class names.

    The validation set is a seeded random ``val_ratio`` slice of the train
    split; the test split is used as-is.

    Args:
        batch_size: Batch size for all three loaders.
        val_ratio: Fraction of the train split held out for validation,
            in ``[0, 1)``.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1)``.
    """
    if not 0.0 <= val_ratio < 1.0:
        raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio!r}")

    raw, class_names = load_charcoal_dataset()
    transform = get_transform()

    train_dataset = HFDatasetWrapper(raw["train"], transform, _LABEL_TO_INDEX)
    test_dataset = HFDatasetWrapper(raw["test"], transform, _LABEL_TO_INDEX)

    val_size = int(len(train_dataset) * val_ratio)
    train_size = len(train_dataset) - val_size
    # Seeded generator keeps the train/val partition reproducible across runs.
    train_subset, val_subset = random_split(
        train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(RANDOM_SEED),
    )

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader, class_names