import torch
import datasets
from torch.utils.data import DataLoader, WeightedRandomSampler, BatchSampler

from src.DataLoader.plantvillage_dataset import PlantVillageDataset
from src.DataLoader.utils import calc_class_dist


def create_dataloader(dataset: datasets.Dataset, batch_size: int, samples_per_epoch: int,
                      is_training_set: bool = True) -> DataLoader:
    """
    Creates a new torch dataloader that draws class-balanced batches from the given dataset.

    Every example is weighted by the normalised inverse frequency of its class, and
    examples are drawn with replacement, so under-represented classes are sampled more
    often than their raw share of the dataset.

    Args:
        dataset (datasets.Dataset): Dataset loaded using huggingface datasets library.
            Its 'label' column is used to index the per-class weights.
        batch_size (int): Number of examples to sample per batch.
        samples_per_epoch (int): Total number of examples to sample in one epoch.
            Incomplete trailing batches are dropped, so an epoch yields
            ``samples_per_epoch // batch_size`` batches.
        is_training_set (bool): Forwarded to PlantVillageDataset; decides whether the
            given dataset should provide augmented images (Default is True).

    Returns:
        torch.utils.data.DataLoader: Returns a newly created dataloader.

    Example:
        `loader = create_dataloader(train_ds, 32, 1000)`
    """
    # Wrap the hf dataset in a torch-compatible dataset
    torch_dataset = PlantVillageDataset(dataset, is_training_set)

    # Retrieve the class distribution
    # NOTE(review): assumes calc_class_dist returns one numeric entry per class id,
    # ordered so that it can be indexed by the values of the 'label' column - confirm.
    class_percent = torch.tensor(calc_class_dist(dataset))

    # Inverse-frequency weight per class. A class with a share of 0 (e.g. one that is
    # absent from this split) would otherwise produce inf here; normalising by a sum
    # containing inf turns every finite weight into 0 and leaves the sampler with
    # nothing valid to draw from. Such classes get weight 0 instead.
    inverse_freq = 1.0 / class_percent
    inverse_freq = torch.where(class_percent > 0, inverse_freq, torch.zeros_like(inverse_freq))
    class_weights = inverse_freq / inverse_freq.sum()

    # Assign the weight of its class to each sample for the weighted sampler
    labels = torch.tensor(dataset['label'])
    sample_weights = class_weights[labels]

    # Create sampler and dataloader; batching is delegated to BatchSampler, which
    # discards a final batch smaller than batch_size (drop_last=True)
    sampler = WeightedRandomSampler(sample_weights, replacement=True, num_samples=samples_per_epoch)
    loader = DataLoader(torch_dataset,
                        batch_sampler=BatchSampler(sampler, batch_size=batch_size, drop_last=True))

    return loader