Spaces:
Running on Zero
Running on Zero
File size: 1,465 Bytes
52d0a0e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | from torch.utils.data import Dataset
from .COCO import Coco
from .cityscapes import Cityscapes
from .sa1b import SA1B
# from featup.datasets.nyu_probe import NYU
class _SliceIndexError(IndexError, StopIteration):
    """Out-of-range index for :class:`SlicedDataset`.

    It is an ``IndexError`` so that the sequence/iteration protocol and generic
    ``except IndexError`` handlers behave correctly. It is also a
    ``StopIteration`` so callers written against the exception previously
    raised here keep working.
    """


class SlicedDataset(Dataset):
    """View of the contiguous range ``ds[start:end]`` of another dataset.

    ``start`` is clamped to be non-negative and ``end`` to at most
    ``len(ds)``. An empty or inverted range yields an empty dataset instead
    of a negative length.

    Args:
        ds: Underlying indexable dataset; must support ``len()``.
        start: First index of the slice (inclusive).
        end: End of the slice (exclusive).
    """

    def __init__(self, ds, start, end):
        self.ds = ds
        self.start = max(0, start)
        # Keep end >= start so __len__ can never go negative (len() would
        # raise ValueError otherwise).
        self.end = max(self.start, min(len(ds), end))

    def __getitem__(self, index):
        """Return element ``index`` of the slice.

        Negative indices count from the end of the slice, not of ``ds``.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
                The raised exception is also a ``StopIteration`` for
                backward compatibility.
        """
        length = self.__len__()
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise _SliceIndexError(
                f"index out of range for slice of length {length}")
        return self.ds[self.start + index]

    def __len__(self):
        return self.end - self.start
class SingleImageDataset(Dataset):
    """Dataset that yields the same element ``ds[i]`` for every valid index.

    Args:
        i: Index of the element in ``ds`` to repeat.
        ds: Underlying indexable dataset.
        l: Reported length; defaults to ``len(ds)``.
    """

    def __init__(self, i, ds, l=None):
        self.ds = ds
        self.i = i
        self.l = len(self.ds) if l is None else l

    def __len__(self):
        return self.l

    def __getitem__(self, item):
        """Return ``ds[i]`` regardless of ``item``, once ``item`` is in range.

        Raises:
            IndexError: if ``item`` is outside ``[-len(self), len(self))``.
        """
        # ``item`` only bounds the dataset; the returned element is always
        # ds[i]. The range check lets plain ``for x in ds`` iteration (which
        # falls back to __getitem__ until IndexError) terminate.
        if not -self.l <= item < self.l:
            raise IndexError(f"index {item} out of range for length {self.l}")
        return self.ds[self.i]
def get_dataset(dataroot, name, split, transform, target_transform, include_labels, sample_size=100000):
    """Construct the dataset registered under ``name``.

    Args:
        dataroot: Root directory, forwarded to the dataset constructor.
        name: One of ``'cocostuff'``, ``'cityscapes'`` or ``'sa1b'``.
        split: Forwarded unchanged to the dataset constructor.
        transform: Forwarded unchanged to the dataset constructor.
        target_transform: Forwarded unchanged to the dataset constructor.
        include_labels: Forwarded to ``Coco`` and ``Cityscapes`` only; it is
            NOT passed to ``SA1B``.
        sample_size: Forwarded to ``SA1B`` only.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: if ``name`` matches none of the supported datasets.
    """
    if name == 'cocostuff':
        return Coco(dataroot, split, transform, target_transform,
                    include_labels=include_labels)

    if name == 'cityscapes':
        return Cityscapes(dataroot, split, transform, target_transform,
                          include_labels=include_labels)

    if name == "sa1b":
        # SA1B is built with keyword arguments and takes sample_size
        # instead of include_labels.
        return SA1B(root=dataroot, split=split, transform=transform,
                    target_transform=target_transform, sample_size=sample_size)

    raise ValueError(f"Unknown dataset {name}")
|