Spaces:
Sleeping
Sleeping
| from typing import List, Tuple | |
| import torch | |
| from PIL import Image | |
| from datasets import load_dataset | |
| from torch.utils.data import Dataset, DataLoader, random_split | |
| from torchvision import transforms | |
| from config import HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE, RANDOM_SEED | |
# Module-level memo written by load_charcoal_dataset(): the loaded splits and
# their class names. Both stay None until that function has run successfully.
_CLASS_NAMES = _HF_DATASET_CACHE = None
class HFDatasetWrapper(Dataset):
    """Expose an indexable collection of image records as a torch ``Dataset``.

    Every record is read through its ``"image"`` and ``"label"`` keys and is
    returned as a ``(transformed_image, label)`` pair.
    """

    def __init__(self, hf_dataset, transform):
        """Store the wrapped dataset and the per-image transform."""
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(transform(image), int(label))`` for the record at ``idx``."""
        record = self.dataset[idx]
        picture = record["image"]
        # Anything that is not already a PIL image is handed to Image.open.
        if isinstance(picture, Image.Image):
            opened = picture
        else:
            opened = Image.open(picture)
        rgb_picture = opened.convert("RGB")
        # The label is converted before the transform runs, so a bad label
        # surfaces before any image processing work is done.
        target = int(record["label"])
        return self.transform(rgb_picture), target
def get_transform(image_size=(224, 224)):
    """Build the preprocessing pipeline applied to every image.

    Args:
        image_size: Target size handed to ``transforms.Resize``. Either an
            ``(height, width)`` pair or a single int, which is expanded to a
            square ``(size, size)`` target. Defaults to ``(224, 224)``, the
            value that used to be hard-coded, so existing callers are
            unaffected.

    Returns:
        A ``transforms.Compose`` that resizes the image, converts it to a
        tensor and normalizes each channel with the mean/std given below.
    """
    # NOTE(review): this module imports IMAGE_SIZE from config, but the resize
    # target was never wired to it. Callers can now pass it explicitly;
    # confirm whether it is an int or a pair before making it the default.
    if isinstance(image_size, int):
        # A bare int would make Resize scale only the shorter edge; expand it
        # so the output always has a fixed shape.
        image_size = (image_size, image_size)

    return transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            # Per-channel statistics; these are the values commonly quoted
            # for ImageNet.
            transforms.Normalize(
                mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225),
            ),
        ]
    )
def load_charcoal_dataset():
    """Load the dataset from the Hub once and memoize it for later calls.

    The first successful call downloads ``HF_DATASET_REPO`` with ``HF_TOKEN``,
    derives the class names from the ``"label"`` feature of the ``"train"``
    split and, when the repository has no ``"test"`` split, carves 20% of
    ``"train"`` off as the test split. Later calls return the cached result.

    Returns:
        A ``(splits, class_names)`` tuple. ``splits`` is a mapping with at
        least ``"train"`` and ``"test"`` entries. ``class_names`` is the
        label feature's ``names`` when it has that attribute, otherwise the
        sorted distinct label values.

    Raises:
        RuntimeError: if ``HF_TOKEN`` is empty.
    """
    global _CLASS_NAMES, _HF_DATASET_CACHE

    if _HF_DATASET_CACHE is not None:
        return _HF_DATASET_CACHE, _CLASS_NAMES

    if not HF_TOKEN:
        raise RuntimeError(
            "HF_TOKEN is missing. Please add it in the Space secrets."
        )

    raw = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
    train_split = raw["train"]

    label_feature = train_split.features["label"]
    if hasattr(label_feature, "names"):
        class_names = label_feature.names
    else:
        class_names = sorted(set(train_split["label"]))

    if "test" not in raw:
        try:
            split = train_split.train_test_split(
                test_size=0.2,
                seed=RANDOM_SEED,
                stratify_by_column="label",
            )
        except Exception as exc:
            # Stratification is best-effort: keep the fallback, but say why it
            # happened instead of silently producing an unstratified split.
            import warnings

            warnings.warn(
                f"Stratified train/test split failed ({exc!r}); "
                "falling back to an unstratified split.",
                RuntimeWarning,
                stacklevel=2,
            )
            split = train_split.train_test_split(
                test_size=0.2,
                seed=RANDOM_SEED,
            )
        raw = {
            "train": split["train"],
            "test": split["test"],
        }

    # Publish both globals together, only after everything succeeded, so a
    # failure above never leaves the cache half-populated.
    _CLASS_NAMES = class_names
    _HF_DATASET_CACHE = raw
    return _HF_DATASET_CACHE, _CLASS_NAMES
def get_class_names() -> List[str]:
    """Return the class names reported by ``load_charcoal_dataset``."""
    # Element 1 of the (splits, class_names) pair; the splits are not needed.
    return load_charcoal_dataset()[1]
def make_loaders(batch_size: int, val_ratio: float = 0.1):
    """Create the train, validation and test ``DataLoader`` objects.

    The validation set is split off the ``"train"`` split with a generator
    seeded by ``RANDOM_SEED``. All three loaders share the transform returned
    by ``get_transform()``.

    Args:
        batch_size: Batch size used by all three loaders.
        val_ratio: Fraction of the training split held out for validation.
            Must satisfy ``0 <= val_ratio < 1``.

    Returns:
        A ``(train_loader, val_loader, test_loader, class_names)`` tuple.
        Only the training loader shuffles.

    Raises:
        ValueError: if ``val_ratio`` is outside ``[0, 1)``.
    """
    # Reject bad ratios up front: a value outside [0, 1] would hand
    # random_split a negative length, and exactly 1 would leave the shuffling
    # train loader with an empty subset.
    if not 0.0 <= val_ratio < 1.0:
        raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio!r}")

    raw, class_names = load_charcoal_dataset()
    transform = get_transform()

    train_dataset = HFDatasetWrapper(raw["train"], transform)
    test_dataset = HFDatasetWrapper(raw["test"], transform)

    # The truncated share goes to validation; the remainder stays in train so
    # the two sizes always add up to the full training split.
    val_size = int(len(train_dataset) * val_ratio)
    train_size = len(train_dataset) - val_size

    train_subset, val_subset = random_split(
        train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(RANDOM_SEED),
    )

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader, class_names