Cataract-ViT / src /dataset.py
Decoder24's picture
Upload folder using huggingface_hub
a080b32 verified
import os
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms
import numpy as np
from collections import Counter
def get_dataloaders(data_dir, batch_size=32, image_size=224, num_workers=4):
    """Build class-balanced training and plain validation DataLoaders.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``<data_dir>/train`` and ``<data_dir>/valid``.

    Args:
        data_dir: Directory containing the ``train`` and ``valid`` sub-folders.
        batch_size: Number of samples per batch for both loaders.
        image_size: Side length (in pixels) of the square images produced by
            the transforms.
        num_workers: Number of worker processes used by each DataLoader.

    Returns:
        A tuple ``(train_loader, valid_loader, classes, train_dataset)`` where
        ``classes`` is ``train_dataset.classes``.

    Raises:
        ValueError: If a class found in the validation folder maps to a
            different label index than in the training folder (or does not
            exist there), which would make validation labels meaningless.
    """
    # Channel-wise mean/std shared by both pipelines (the values commonly
    # used with ImageNet-pretrained torchvision models).
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Training augmentation.
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        # A scalar means the angle is drawn from [-45, +45] degrees.
        transforms.RandomRotation(degrees=45),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation uses a lighter, deterministic pipeline (resize only).
    valid_transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        normalize,
    ])

    # Load datasets.
    train_dataset = datasets.ImageFolder(
        os.path.join(data_dir, "train"), transform=train_transform
    )
    valid_dataset = datasets.ImageFolder(
        os.path.join(data_dir, "valid"), transform=valid_transform
    )

    # Each ImageFolder derives its own class -> index mapping from its
    # sub-folder names. Fail loudly if a validation class would get a label
    # index that disagrees with training, instead of silently scoring against
    # the wrong labels.
    train_index = train_dataset.class_to_idx
    misaligned = sorted(
        name
        for name, idx in valid_dataset.class_to_idx.items()
        if train_index.get(name) != idx
    )
    if misaligned:
        raise ValueError(
            "Validation classes are not aligned with training classes: "
            f"{misaligned} (train mapping: {train_index}, "
            f"valid mapping: {valid_dataset.class_to_idx})"
        )

    # Compute the class distribution for the WeightedRandomSampler: every
    # sample is weighted by the inverse frequency of its class so that rare
    # classes are drawn as often as common ones.
    labels = [label for _, label in train_dataset.samples]
    class_counts = Counter(labels)
    class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
    sample_weights = [class_weights[label] for label in labels]
    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
    )

    # DataLoaders. The sampler already randomizes the training order, so
    # `shuffle` is not passed for the training loader.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return train_loader, valid_loader, train_dataset.classes, train_dataset