|
|
import os
|
|
|
from torch.utils.data import DataLoader, WeightedRandomSampler
|
|
|
from torchvision import datasets, transforms
|
|
|
import numpy as np
|
|
|
from collections import Counter
|
|
|
|
|
|
def get_dataloaders(data_dir, batch_size=32, image_size=224, num_workers=4):
    """Create class-balanced training and plain validation data loaders.

    Images are read with ``datasets.ImageFolder`` from ``<data_dir>/train``
    and ``<data_dir>/valid``.

    Args:
        data_dir: Directory containing the ``train`` and ``valid`` sub-folders.
        batch_size: Batch size used by both loaders.
        image_size: Side length of the square images fed to the model.
        num_workers: Number of worker processes used by each loader.

    Returns:
        A 4-tuple ``(train_loader, valid_loader, classes, train_dataset)``
        where ``classes`` is ``train_dataset.classes``.
    """
    # Per-channel normalization constants shared by both pipelines
    # (these look like the usual ImageNet statistics -- confirm against the
    # model that consumes the batches).
    norm_kwargs = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Training pipeline: random geometric and colour augmentation, then a
    # random crop resized to image_size x image_size.
    augmentations = [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=45),
        transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
        ),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.ToTensor(),
        transforms.Normalize(**norm_kwargs),
    ]
    train_transform = transforms.Compose(augmentations)

    # Validation pipeline: deterministic resize only, no augmentation.
    valid_transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(**norm_kwargs),
        ]
    )

    train_root = os.path.join(data_dir, "train")
    valid_root = os.path.join(data_dir, "valid")
    train_dataset = datasets.ImageFolder(train_root, transform=train_transform)
    valid_dataset = datasets.ImageFolder(valid_root, transform=valid_transform)

    # Weight every training sample by the inverse of its class frequency so
    # that each class carries the same total weight in the sampler.
    labels = [target for _, target in train_dataset.samples]
    per_class_count = Counter(labels)
    inverse_freq = {
        cls: 1.0 / count for cls, count in per_class_count.items()
    }
    per_sample_weight = [inverse_freq[target] for target in labels]

    # One "epoch" draws as many samples as the dataset holds, with
    # replacement, according to the weights above.
    balanced_sampler = WeightedRandomSampler(
        weights=per_sample_weight,
        num_samples=len(per_sample_weight),
        replacement=True,
    )

    # Settings common to both loaders.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )
    train_loader = DataLoader(
        train_dataset, sampler=balanced_sampler, **loader_kwargs
    )
    valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

    return train_loader, valid_loader, train_dataset.classes, train_dataset
|
|
|
|