Spaces:
Sleeping
Sleeping
| import torch | |
| import datasets | |
| from torch.utils.data import DataLoader, WeightedRandomSampler, BatchSampler | |
| from src.DataLoader.plantvillage_dataset import PlantVillageDataset | |
| from src.DataLoader.utils import calc_class_dist | |
def create_dataloader(dataset: datasets.Dataset, batch_size: int, samples_per_epoch: int, is_training_set: bool = True) -> DataLoader:
    """
    Create a torch DataLoader that draws class-balanced batches from a dataset.

    Each example is sampled (with replacement) with a probability inversely
    proportional to the frequency of its class, so rare classes are drawn
    about as often as common ones.

    Args:
        dataset (datasets.Dataset): Dataset loaded using the huggingface
            datasets library. Its ``'label'`` column is used to look up the
            per-class weight of every example.
        batch_size (int): Number of examples to sample per batch.
        samples_per_epoch (int): Total number of examples to sample in one
            epoch. Incomplete trailing batches are dropped, so one epoch
            yields ``samples_per_epoch // batch_size`` batches.
        is_training_set (bool): Forwarded to ``PlantVillageDataset``; decides
            whether the dataset should provide augmented images
            (Default is True).

    Returns:
        torch.utils.data.DataLoader: Returns a newly created dataloader.

    Example:
        `loader = create_dataloader(train_ds, 32, 1000)`
    """
    # Wrap the hf dataset in the torch-compatible dataset class
    torch_dataset = PlantVillageDataset(dataset, is_training_set)

    # Retrieve class percentages.
    # NOTE(review): assumes calc_class_dist returns one entry per class id,
    # indexable by the integer values in dataset['label'] -- confirm.
    class_percent = torch.tensor(calc_class_dist(dataset))

    # Inverse-frequency weight per class. A class with a frequency of 0
    # (e.g. absent from this split) would give an infinite weight, and the
    # normalisation below would then turn every weight into 0 or NaN, which
    # the sampler cannot draw from. Such classes have no examples to sample
    # anyway, so their weight is set to 0 instead.
    class_weights = 1.0 / class_percent
    class_weights = torch.where(
        torch.isfinite(class_weights),
        class_weights,
        torch.zeros_like(class_weights),
    )
    class_weights = class_weights / class_weights.sum()

    # Assign the weight of its class to each sample for the weighted sampler
    labels = torch.tensor(dataset['label'])
    sample_weights = class_weights[labels]

    # Create sampler and dataloader; batching is delegated to BatchSampler,
    # so batch_size is not passed to DataLoader directly.
    sampler = WeightedRandomSampler(sample_weights, replacement=True, num_samples=samples_per_epoch)
    loader = DataLoader(torch_dataset, batch_sampler=BatchSampler(sampler, batch_size=batch_size, drop_last=True))
    return loader