Image_Classification / data_utils.py
CircleStar's picture
Update data_utils.py
2a15372 verified
raw
history blame
3.16 kB
from typing import List, Tuple
import torch
from PIL import Image
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from config import HF_DATASET_REPO, HF_TOKEN, IMAGE_SIZE, RANDOM_SEED
# Module-level memo used by load_charcoal_dataset() so the dataset is only
# fetched once per process. Both names are None until that function fills
# them in.
# Class names for the "label" column (set by load_charcoal_dataset()).
_CLASS_NAMES = None
# Mapping holding the "train" and "test" splits (set by load_charcoal_dataset()).
_HF_DATASET_CACHE = None
class HFDatasetWrapper(Dataset):
    """Expose an indexable dataset of image/label records as a torch ``Dataset``.

    Each record returned by ``hf_dataset[idx]`` must provide an ``"image"``
    entry and a ``"label"`` entry. ``__getitem__`` yields
    ``(transform(rgb_image), int(label))``.

    Args:
        hf_dataset: Any object supporting ``len()`` and integer indexing that
            returns mappings with ``"image"`` and ``"label"`` keys.
        transform: Callable applied to the RGB image before it is returned.
    """

    def __init__(self, hf_dataset, transform):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the record at ``idx``."""
        item = self.dataset[idx]
        image = item["image"]

        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            # Anything that is not already a PIL image is handed to
            # Image.open(). Convert while the handle is open and release it
            # right away so long-running loaders do not accumulate open
            # files; convert() returns an independent in-memory image.
            with Image.open(image) as opened:
                image = opened.convert("RGB")

        label = int(item["label"])
        return self.transform(image), label
def get_transform():
    """Build the preprocessing pipeline applied to every image.

    The pipeline resizes to 224x224, converts to a tensor and normalizes
    each channel with a fixed mean and standard deviation.

    Returns:
        A ``transforms.Compose`` instance chaining the three steps.
    """
    # NOTE(review): IMAGE_SIZE is imported from config but the size is
    # hard-coded here - confirm whether the config value should be used.
    target_size = (224, 224)
    # Per-channel statistics; these are the widely used ImageNet values.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)

    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def load_charcoal_dataset():
    """Load the dataset once and return ``(splits, class_names)``.

    The result is memoized in the module globals ``_HF_DATASET_CACHE`` and
    ``_CLASS_NAMES``; later calls return the cached objects without touching
    the network.

    Class names come from the ``names`` attribute of the ``"label"`` feature
    when it has one; otherwise they are the sorted distinct values of the
    training ``"label"`` column. If the loaded dataset has no ``"test"``
    split, 20% of ``"train"`` is held out (seeded with ``RANDOM_SEED``),
    stratified by label when possible.

    Returns:
        A tuple ``(splits, class_names)`` where ``splits`` can be indexed
        with ``"train"`` and ``"test"``.

    Raises:
        RuntimeError: If ``HF_TOKEN`` is empty or unset.

    Warns:
        RuntimeWarning: If the stratified split fails and a plain random
            split is used instead.
    """
    import warnings

    global _CLASS_NAMES, _HF_DATASET_CACHE

    if _HF_DATASET_CACHE is not None:
        return _HF_DATASET_CACHE, _CLASS_NAMES

    if not HF_TOKEN:
        raise RuntimeError(
            "HF_TOKEN is missing. Please add it in the Space secrets."
        )

    raw = load_dataset(HF_DATASET_REPO, token=HF_TOKEN)
    train_split = raw["train"]

    label_feature = train_split.features["label"]
    if hasattr(label_feature, "names"):
        class_names = label_feature.names
    else:
        class_names = sorted(set(train_split["label"]))

    if "test" not in raw:
        try:
            split = train_split.train_test_split(
                test_size=0.2,
                seed=RANDOM_SEED,
                stratify_by_column="label",
            )
        except Exception as exc:
            # Best-effort fallback is deliberate, but say so: a non-stratified
            # split can skew per-class counts in the held-out data.
            warnings.warn(
                f"Stratified train/test split failed ({exc}); "
                "falling back to a non-stratified split.",
                RuntimeWarning,
                stacklevel=2,
            )
            split = train_split.train_test_split(
                test_size=0.2,
                seed=RANDOM_SEED,
            )
        raw = {
            "train": split["train"],
            "test": split["test"],
        }

    # Publish both globals together so a failure above never leaves the
    # class names set while the dataset cache is still empty.
    _CLASS_NAMES = class_names
    _HF_DATASET_CACHE = raw
    return _HF_DATASET_CACHE, _CLASS_NAMES
def get_class_names() -> List[str]:
    """Return the class names reported by ``load_charcoal_dataset()``."""
    # The loader returns (splits, class_names); only the second item is needed.
    return load_charcoal_dataset()[1]
def make_loaders(
    batch_size: int, val_ratio: float = 0.1
) -> Tuple[DataLoader, DataLoader, DataLoader, List[str]]:
    """Build train, validation and test ``DataLoader`` objects.

    The ``"train"`` split from ``load_charcoal_dataset()`` is divided into
    training and validation subsets with a generator seeded by
    ``RANDOM_SEED``; the ``"test"`` split is used as-is. All three share the
    transform from ``get_transform()``. Only the training loader shuffles.

    Args:
        batch_size: Batch size passed to every loader.
        val_ratio: Fraction of the training split held out for validation;
            must satisfy ``0 <= val_ratio < 1``.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1)``.
    """
    # Out-of-range ratios yield negative split lengths that still sum to the
    # dataset size, so random_split would accept them and return nonsensical
    # subsets. Reject them before doing any loading work.
    if not 0 <= val_ratio < 1:
        raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio!r}")

    raw, class_names = load_charcoal_dataset()
    transform = get_transform()

    train_dataset = HFDatasetWrapper(raw["train"], transform)
    test_dataset = HFDatasetWrapper(raw["test"], transform)

    val_size = int(len(train_dataset) * val_ratio)
    train_size = len(train_dataset) - val_size
    train_subset, val_subset = random_split(
        train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(RANDOM_SEED),
    )

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader, class_names