Spaces:
Build error
Build error
| import torch | |
| import numpy as np | |
| import torch.utils | |
| import torch.utils.data | |
| from torch.utils.data.sampler import WeightedRandomSampler | |
| import torch.distributed as dist | |
| from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize | |
| from .datasets import RealFakeDataset | |
def get_bal_sampler(dataset):
    """Build a sampler that weights each sample by the inverse of its class frequency.

    ``dataset`` must expose a ``datasets`` attribute (as a
    ``torch.utils.data.ConcatDataset`` does) whose members each provide a
    ``targets`` sequence of non-negative integer class labels. Labels from all
    members are concatenated in order, counted with ``np.bincount``, and every
    sample receives weight ``1 / count(its class)``.

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(labels)`` indices per epoch
        (with replacement, the sampler's default).
    """
    all_labels = [label for member in dataset.datasets for label in member.targets]
    class_counts = np.bincount(all_labels)
    inverse_freq = 1. / torch.tensor(class_counts, dtype=torch.float)
    per_sample_weight = inverse_freq[all_labels]
    return WeightedRandomSampler(weights=per_sample_weight,
                                 num_samples=len(per_sample_weight))
def create_train_val_dataloader(opt, clip_model, transform, k_split: float,
                                num_workers: int = 16, generator=None):
    """Split one RealFakeDataset into train/val subsets and wrap each in a DataLoader.

    Args:
        opt: options object; this function reads ``isTrain``, ``class_bal``,
            ``serial_batches`` and ``batch_size`` from it.
        clip_model: forwarded unchanged to ``RealFakeDataset``.
        transform: forwarded unchanged to ``RealFakeDataset``.
        k_split: fraction of samples in ``[0, 1]`` assigned to the training
            split (``int(len(dataset) * k_split)``); the rest go to validation.
        num_workers: worker processes for each loader (previously hard-coded 16).
        generator: optional ``torch.Generator`` passed to ``random_split`` so the
            split can be made reproducible; defaults to torch's global generator.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``k_split`` is not within ``[0, 1]`` (or is NaN).
    """
    if not 0.0 <= k_split <= 1.0:
        raise ValueError(f"k_split must be in [0, 1], got {k_split!r}")

    # Shuffle the training split every epoch unless serial batches were requested
    # or class balancing is enabled. (This flag used to be computed but ignored.)
    shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False
    # NOTE(review): when opt.class_bal is set no balanced sampler is attached here;
    # get_bal_sampler needs a `.datasets` attribute, which a random_split Subset lacks.
    dataset = RealFakeDataset(opt, clip_model, transform)

    # Split into training and validation sets.
    dataset_size = len(dataset)
    train_size = int(dataset_size * k_split)
    val_size = dataset_size - train_size
    if generator is None:
        generator = torch.default_generator
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], generator=generator)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batch_size,
                                               shuffle=shuffle,
                                               num_workers=num_workers)
    # Validation order is kept fixed across epochs.
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=False,
                                             num_workers=num_workers)
    return train_loader, val_loader
def create_test_dataloader(opt, clip_model, transform):
    """Wrap a full RealFakeDataset in a single DataLoader.

    When ``opt.class_bal`` is set, the loader draws through the class-balancing
    sampler from ``get_bal_sampler`` and does not shuffle. Otherwise no sampler
    is used, and the loader shuffles only when ``opt.isTrain`` is true and
    ``opt.serial_batches`` is false. Batches come from ``opt.batch_size``;
    16 worker processes are used.
    """
    # Decide on shuffling before building the dataset (same order as before).
    if opt.isTrain and not opt.class_bal:
        shuffle = not opt.serial_batches
    else:
        shuffle = False

    full_dataset = RealFakeDataset(opt, clip_model, transform)

    if opt.class_bal:
        sampler = get_bal_sampler(full_dataset)
    else:
        sampler = None

    loader = torch.utils.data.DataLoader(
        full_dataset,
        batch_size=opt.batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=16,
    )
    return loader