File size: 1,811 Bytes
97fcc90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
import datasets
from torch.utils.data import DataLoader, WeightedRandomSampler, BatchSampler

from src.DataLoader.plantvillage_dataset import PlantVillageDataset
from src.DataLoader.utils import calc_class_dist

def create_dataloader(dataset: datasets.Dataset, batch_size: int, samples_per_epoch: int, is_training_set: bool = True) -> DataLoader:
    """
    Create a class-balanced torch DataLoader from a Hugging Face dataset.

    Every example is drawn (with replacement) with probability proportional
    to the inverse of its class frequency, so under-represented classes are
    over-sampled. Classes whose frequency is zero get weight 0 instead of
    producing ``inf``/``nan`` weights.

    Args:
        dataset (datasets.Dataset): Dataset loaded using the Hugging Face
            ``datasets`` library; must provide a ``'label'`` column.
        batch_size (int): Number of examples to sample per batch.
        samples_per_epoch (int): Total number of examples to sample in one
            epoch. Incomplete batches are dropped, so the loader yields
            ``samples_per_epoch // batch_size`` batches (zero batches when
            ``samples_per_epoch < batch_size``).
        is_training_set (bool): Forwarded to ``PlantVillageDataset``; decides
            whether the dataset should provide augmented images
            (default is True).

    Returns:
        torch.utils.data.DataLoader: The newly created dataloader.

    Raises:
        ValueError: If no class has a positive frequency, which would make
            the sampling distribution undefined.

    Example:
        `loader = create_dataloader(train_ds, 32, 1000)`
    """
    # Wrap the HF dataset in a torch-compatible dataset.
    torch_dataset = PlantVillageDataset(dataset, is_training_set)

    # Per-class frequencies. Assumes calc_class_dist returns one value per
    # class id, so that class_percent[label] is that label's frequency
    # (the indexing below relies on this) — TODO confirm against utils.
    class_percent = torch.as_tensor(calc_class_dist(dataset), dtype=torch.float64)

    # Inverse-frequency weights. Zero-frequency classes would give 1/0 = inf,
    # and normalizing by an infinite sum would zero out or NaN every weight,
    # so such classes get weight 0 instead.
    present = class_percent > 0
    if not bool(present.any()):
        raise ValueError("create_dataloader: no class has a positive frequency")
    class_weights = torch.where(
        present,
        1.0 / class_percent.clamp_min(torch.finfo(class_percent.dtype).tiny),
        torch.zeros_like(class_percent),
    )
    # Normalization is not required by WeightedRandomSampler, but it keeps the
    # weights readable as a per-class probability mass.
    class_weights = class_weights / class_weights.sum()

    # One weight per sample, looked up through its integer label.
    # list(...) materializes the column whatever sequence type it is returned as.
    labels = torch.as_tensor(list(dataset['label']), dtype=torch.long)
    sample_weights = class_weights[labels]

    # Sample with replacement so samples_per_epoch may exceed len(dataset).
    sampler = WeightedRandomSampler(sample_weights, replacement=True, num_samples=samples_per_epoch)
    loader = DataLoader(torch_dataset, batch_sampler=BatchSampler(sampler, batch_size=batch_size, drop_last=True))

    return loader