# ABCDFSS / data / dataset.py
# heyoujue's picture
# add submission code
# 322161a
r""" Dataloader builder for few-shot semantic segmentation dataset """
from torch.utils.data.distributed import DistributedSampler as Sampler
from torch.utils.data import DataLoader
from torchvision import transforms
from data.pascal import DatasetPASCAL
from data.coco import DatasetCOCO
from data.fss import DatasetFSS
from data.deepglobe import DatasetDeepglobe
from data.isic import DatasetISIC
from data.lung import DatasetLung
from data.fss import DatasetFSS
from data.suim import DatasetSUIM
class FSSDataset:
    r"""Registry and DataLoader factory for the segmentation benchmarks.

    All state lives on the class itself: call :meth:`initialize` once to set
    the benchmark registry, the data root and the image transform, then call
    :meth:`build_dataloader` for every split that is needed.
    """

    @classmethod
    def initialize(cls, img_size, datapath):
        r"""Set up the class-level registry, data root and image transform.

        Args:
            img_size: Edge length; images are resized to
                ``(img_size, img_size)``.
            datapath: Stored and later passed as the first positional
                argument to the selected dataset constructor.
        """
        # Maps the benchmark name accepted by build_dataloader to its class.
        cls.datasets = {
            'pascal': DatasetPASCAL,
            'coco': DatasetCOCO,
            'fss': DatasetFSS,
            'deepglobe': DatasetDeepglobe,
            'isic': DatasetISIC,
            'lung': DatasetLung,
            'suim': DatasetSUIM
        }

        # Per-channel normalization statistics
        # (presumably the ImageNet mean/std -- TODO confirm).
        cls.img_mean = [0.485, 0.456, 0.406]
        cls.img_std = [0.229, 0.224, 0.225]
        cls.datapath = datapath

        # Resize -> tensor conversion -> channel-wise normalization.
        cls.transform = transforms.Compose([
            transforms.Resize(size=(img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(cls.img_mean, cls.img_std),
        ])

    @classmethod
    def build_dataloader(cls, benchmark, bsz, nworker, fold, split, shot=1):
        r"""Instantiate the requested benchmark and wrap it in a DataLoader.

        :meth:`initialize` must have been called beforehand.

        Args:
            benchmark: Key into the registry set up by :meth:`initialize`.
            bsz: Batch size.
            nworker: Number of loader workers; only honoured when
                ``split == 'trn'``, every other split uses 0 workers.
            fold: Forwarded to the dataset constructor as ``fold``.
            split: Forwarded to the dataset constructor; ``'trn'`` also
                enables shuffling.
            shot: Forwarded to the dataset constructor as ``shot``.

        Returns:
            A ``DataLoader`` over the dataset with ``pin_memory=True``.

        Raises:
            KeyError: If ``benchmark`` is not a registered benchmark name.
        """
        # Fail with an actionable message instead of a bare KeyError('name').
        if benchmark not in cls.datasets:
            supported = ', '.join(sorted(cls.datasets))
            raise KeyError(
                "Unknown benchmark '{}'; supported benchmarks: {}".format(
                    benchmark, supported))

        is_training = split == 'trn'

        # Non-training splits are always loaded in the main process.
        nworker = nworker if is_training else 0

        dataset = cls.datasets[benchmark](cls.datapath, fold=fold,
                                          transform=cls.transform,
                                          split=split, shot=shot)

        # Force randomness during training for diverse episode combinations;
        # keep a fixed order during testing for reproducibility.
        # No sampler is passed: shuffling is left to the DataLoader itself.
        dataloader = DataLoader(dataset, batch_size=bsz, shuffle=is_training,
                                num_workers=nworker, pin_memory=True)

        return dataloader