Spaces:
Sleeping
Sleeping
File size: 1,811 Bytes
97fcc90 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 | import torch
import datasets
from torch.utils.data import DataLoader, WeightedRandomSampler, BatchSampler
from src.DataLoader.plantvillage_dataset import PlantVillageDataset
from src.DataLoader.utils import calc_class_dist
def create_dataloader(dataset: datasets.Dataset, batch_size: int, samples_per_epoch: int, is_training_set: bool = True) -> DataLoader:
    """
    Creates a new torch dataloader that samples the given dataset with class-balancing weights.

    Every example is assigned a sampling weight proportional to the inverse of its
    class' share of the dataset, so rare classes are drawn about as often as common
    ones. Examples are drawn with replacement, hence one "epoch" is defined by
    `samples_per_epoch` rather than by the dataset length.

    Args:
        dataset (datasets.Dataset): Dataset loaded using huggingface datasets library.
            Must expose a 'label' column holding integer class indices.
        batch_size (int): Number of examples to sample per batch.
        samples_per_epoch (int): Total number of examples to sample in one epoch.
            Incomplete trailing batches are dropped, so the loader yields
            `samples_per_epoch // batch_size` batches per epoch.
        is_training_set (bool): decides whether the given dataset should provide
            augmented images (Default is True). Forwarded to PlantVillageDataset.

    Returns:
        torch.utils.data.DataLoader: Returns a newly created dataloader.

    Example:
        `loader = create_dataloader(train_ds, 32, 1000)`
    """
    # Wrap the hf dataset in the project's torch-compatible dataset
    torch_dataset = PlantVillageDataset(dataset, is_training_set)

    # Retrieve class percentages
    # NOTE(review): assumes calc_class_dist returns one entry per class, indexed by label id - confirm
    class_percent = torch.tensor(calc_class_dist(dataset))

    # Inverse-frequency weight per class. A class with a share of 0 (e.g. absent
    # from this split) would produce inf here and turn every normalized weight into
    # 0/NaN, so such classes get weight 0 instead; they have no samples to draw anyway.
    inverse_freq = 1.0 / class_percent
    inverse_freq = torch.where(torch.isfinite(inverse_freq), inverse_freq, torch.zeros_like(inverse_freq))
    class_weights = inverse_freq / inverse_freq.sum()

    # Look up the class weight of every sample for the weighted sampler
    labels = torch.tensor(dataset['label'])
    sample_weights = class_weights[labels]

    # Create sampler and dataloader
    sampler = WeightedRandomSampler(sample_weights, num_samples=samples_per_epoch, replacement=True)
    batch_sampler = BatchSampler(sampler, batch_size=batch_size, drop_last=True)
    return DataLoader(torch_dataset, batch_sampler=batch_sampler)
|