Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from torch.utils.data.sampler import WeightedRandomSampler | |
| from .datasets import dataset_folder | |
| import os | |
def get_dataset(opt):
    """Build the dataset rooted at ``opt.dataroot``.

    If the selected entries include both ``'0_real'`` and ``'1_fake'``, the
    root itself is loaded with ``dataset_folder``. Otherwise every selected
    entry is treated as its own sub-root (``<dataroot>/<entry>``), each one is
    loaded with ``dataset_folder``, and the results are concatenated.

    Args:
        opt: options object; reads ``opt.dataroot`` and ``opt.classes``.
            An empty (or ``None``) ``opt.classes`` means "use every
            sub-directory of ``opt.dataroot``".

    Returns:
        A ``torch.utils.data.ConcatDataset`` in the multi-root case,
        otherwise whatever ``dataset_folder`` returns for ``opt.dataroot``.
    """
    if opt.classes:
        classes = opt.classes
    else:
        # os.listdir order is arbitrary and it also lists plain files
        # (e.g. .DS_Store); keep only directories, sorted, so the
        # concatenation order (and thus sample indices) is reproducible.
        classes = sorted(
            entry
            for entry in os.listdir(opt.dataroot)
            if os.path.isdir(os.path.join(opt.dataroot, entry))
        )

    if '0_real' in classes and '1_fake' in classes:
        # dataroot is itself a real/fake split: load it directly.
        return dataset_folder(opt, opt.dataroot)

    dset_lst = [
        dataset_folder(opt, os.path.join(opt.dataroot, cls)) for cls in classes
    ]
    return torch.utils.data.ConcatDataset(dset_lst)
def get_bal_sampler(dataset):
    """Return a sampler that gives every present class equal total weight.

    Each sample is weighted by ``1 / (number of samples with its target)``.
    The sampler uses ``WeightedRandomSampler``'s default ``replacement=True``
    and draws one index per collected target.

    Args:
        dataset: either a ``torch.utils.data.ConcatDataset`` whose member
            datasets expose integer ``targets``, or a single dataset that
            exposes ``targets`` itself.

    Returns:
        A ``WeightedRandomSampler`` over all samples of ``dataset``.
    """
    # get_dataset returns a bare dataset (not a ConcatDataset) when dataroot
    # holds 0_real/1_fake directly; such a dataset has no ``.datasets``.
    parts = getattr(dataset, 'datasets', None)
    if parts is None:
        parts = [dataset]

    targets = []
    for d in parts:
        targets.extend(d.targets)

    ratio = np.bincount(targets)
    # Target values below the max that never occur get a count of 0 and
    # therefore weight inf, but no sample indexes them, so it is harmless.
    w = 1. / torch.tensor(ratio, dtype=torch.float)
    sample_weights = w[targets]
    return WeightedRandomSampler(weights=sample_weights,
                                 num_samples=len(sample_weights))
def create_dataloader(opt):
    """Build a ``DataLoader`` for the dataset described by ``opt``.

    Reads ``opt.isTrain``, ``opt.class_bal``, ``opt.serial_batches``,
    ``opt.batch_size`` and ``opt.num_threads`` (plus whatever ``get_dataset``
    reads).

    - Shuffling is enabled only while training, without class balancing, and
      when ``opt.serial_batches`` is falsy.
    - When ``opt.class_bal`` is set, a class-balanced sampler from
      ``get_bal_sampler`` is used instead of shuffling.
    - The last incomplete batch is dropped only while training.
    """
    balanced = opt.class_bal

    # DataLoader forbids shuffle=True together with a sampler, so shuffling
    # stays off whenever class balancing is requested.
    shuffle = False
    if opt.isTrain and not balanced:
        shuffle = not opt.serial_batches

    dataset = get_dataset(opt)

    sampler = None
    if balanced:
        sampler = get_bal_sampler(dataset)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        drop_last=bool(opt.isTrain),
        num_workers=int(opt.num_threads),
    )