r""" Dataloader builder for few-shot semantic segmentation dataset """
from torch.utils.data.distributed import DistributedSampler as Sampler
from torch.utils.data import DataLoader
from torchvision import transforms

from data.pascal import DatasetPASCAL
from data.coco import DatasetCOCO
from data.fss import DatasetFSS
from data.deepglobe import DatasetDeepglobe
from data.isic import DatasetISIC
from data.lung import DatasetLung
from data.fss import DatasetFSS
from data.suim import DatasetSUIM


class FSSDataset:
    r""" Class-level factory that maps a benchmark name to a ready DataLoader.

    All state lives on the class itself: `initialize` stores the dataset
    registry, normalization statistics, data root and image transform, and
    `build_dataloader` reads them back. `initialize` must therefore be called
    before `build_dataloader`.
    """

    @classmethod
    def initialize(cls, img_size, datapath):
        r""" Store the shared configuration on the class.

        Args:
            img_size: Side length; images are resized to (img_size, img_size).
            datapath: Dataset root, later handed to the dataset constructor.
        """
        # Benchmark name -> dataset class used by `build_dataloader`.
        cls.datasets = {
            'pascal': DatasetPASCAL,
            'coco': DatasetCOCO,
            'fss': DatasetFSS,
            'deepglobe': DatasetDeepglobe,
            'isic': DatasetISIC,
            'lung': DatasetLung,
            'suim': DatasetSUIM,
        }

        # Per-channel statistics given to `transforms.Normalize` below
        # (presumably the usual ImageNet values -- confirm if it matters).
        cls.img_mean = mean = [0.485, 0.456, 0.406]
        cls.img_std = std = [0.229, 0.224, 0.225]
        cls.datapath = datapath

        # Resize -> tensor conversion -> normalization, applied in this order.
        pipeline = [
            transforms.Resize(size=(img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        cls.transform = transforms.Compose(pipeline)

    @classmethod
    def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1):
        r""" Instantiate the dataset for `benchmark` and wrap it in a DataLoader.

        Args:
            benchmark: Key into `cls.datasets` selecting the dataset class.
            bsz: Batch size given to the DataLoader.
            nworker: Worker count; honoured only when `split == 'trn'`,
                every other split is loaded with 0 workers.
            fold: Forwarded to the dataset constructor.
            split: Forwarded to the dataset constructor; the value 'trn'
                additionally switches on shuffling and worker processes.
            shot: Forwarded to the dataset constructor (default 1).

        Returns:
            The constructed `DataLoader` (memory pinning is always enabled).
        """
        is_training = split == 'trn'
        workers = nworker if is_training else 0

        dataset_cls = cls.datasets[benchmark]
        dataset = dataset_cls(
            cls.datapath,
            fold=fold,
            transform=cls.transform,
            split=split,
            shot=shot,
        )

        # Training shuffles so that episode combinations vary between epochs;
        # other splits keep a fixed order so evaluation is reproducible.
        # NOTE: no sampler is attached (the imported `Sampler` is unused here);
        # ordering is controlled solely through `shuffle`.
        return DataLoader(
            dataset,
            batch_size=bsz,
            shuffle=is_training,
            num_workers=workers,
            pin_memory=True,
        )