Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from torch.utils.data.sampler import WeightedRandomSampler | |
| from .datasets import RealFakeDataset | |
def get_bal_sampler(dataset):
    """Build a class-balancing ``WeightedRandomSampler`` for ``dataset``.

    Every sample is weighted by the reciprocal of the number of samples that
    share its class label, so each class carries the same total weight.

    Args:
        dataset: Object exposing a ``datasets`` iterable whose members each
            provide a ``targets`` sequence of non-negative integer class
            labels (the labels are counted with ``np.bincount``).

    Returns:
        A ``WeightedRandomSampler`` holding one weight per sample and drawing
        ``len(targets)`` indices per pass.

    Raises:
        ValueError: If the sub-datasets contain no targets at all.
    """
    # Flatten the labels of every sub-dataset, keeping their original order so
    # weight ``i`` lines up with sample ``i`` of the concatenated dataset.
    targets = [label for subset in dataset.datasets for label in subset.targets]
    if not targets:
        # Fail early with a clear message instead of a less obvious error
        # surfacing later from the sampler constructor.
        raise ValueError("cannot build a balanced sampler: dataset has no targets")

    class_counts = np.bincount(targets)
    # Labels that never occur get an infinite weight here, but no sample ever
    # indexes those entries, so they do not reach the sampler.
    class_weights = 1. / torch.tensor(class_counts, dtype=torch.float)

    # Index with an explicit long tensor rather than a raw Python list.
    target_index = torch.as_tensor(targets, dtype=torch.long)
    sample_weights = class_weights[target_index]

    return WeightedRandomSampler(weights=sample_weights,
                                 num_samples=len(sample_weights))
def create_dataloader(opt, preprocess=None):
    """Create the ``DataLoader`` over a ``RealFakeDataset`` built from ``opt``.

    Args:
        opt: Options object. This function reads ``isTrain``, ``class_bal``,
            ``serial_batches``, ``arch``, ``batch_size`` and ``num_threads``;
            the whole object is also handed to ``RealFakeDataset``.
        preprocess: Value assigned to ``dataset.transform`` when the
            architecture name in ``opt.arch`` contains ``'2b'``.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the dataset.

    Side effects:
        Prints the dataset length to stdout.
    """
    # Shuffling is only enabled for training runs without class balancing;
    # with class balancing the sampler below decides the order instead.
    if opt.isTrain and not opt.class_bal:
        shuffle = not opt.serial_batches
    else:
        shuffle = False

    dataset = RealFakeDataset(opt)
    print(len(dataset))

    if '2b' in opt.arch:
        dataset.transform = preprocess

    sampler = None
    if opt.class_bal:
        sampler = get_bal_sampler(dataset)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=int(opt.num_threads),
    )