xmutly's picture
Upload 294 files
e1aaaac verified
raw
history blame contribute delete
889 Bytes
import torchvision
"""
BabelImageNet from https://arxiv.org/pdf/2306.08658.pdf
Adapted from https://github.com/gregor-ge/Babel-ImageNet, thanks to the authors
"""
class BabelImageNet(torchvision.datasets.ImageNet):
    """ImageNet restricted to a chosen subset of classes, relabelled ``0..len(idxs)-1``.

    Only the classes listed in ``idxs`` are kept, in the order given, and the
    kept class ``idxs[k]`` is relabelled ``k``.

    Args:
        root: Root directory passed through to ``torchvision.datasets.ImageNet``.
        idxs: Sequence of original ImageNet class indices to keep. Position ``k``
            in this sequence becomes the new label ``k``.
        split: Dataset split passed to the parent class (default ``"val"``).
        download: Accepted for call-site compatibility and ignored; it is not
            forwarded to the parent constructor.
        **kwargs: Forwarded to ``torchvision.datasets.ImageNet`` (e.g.
            ``transform``, ``target_transform``).

    Raises:
        IndexError: If an entry of ``idxs`` is not a class index present in the split.
    """

    def __init__(self, root: str, idxs, split: str = "val", download=None, **kwargs) -> None:
        super().__init__(root, split, **kwargs)

        # Positions of every sample, grouped by original class index, in dataset order.
        # Working from the real labels (instead of assuming 1000 classes with an equal,
        # contiguous number of images each) keeps the selection correct for any split.
        positions_by_class = {}
        for position, class_idx in enumerate(self.targets):
            positions_by_class.setdefault(class_idx, []).append(position)

        select_idxs = []
        targets = []
        # Original class index -> new label. setdefault keeps the first occurrence,
        # matching the previous list.index() behaviour if idxs contains duplicates.
        self._local_targets = {}
        for local_target, class_idx in enumerate(idxs):
            if class_idx not in positions_by_class:
                raise IndexError(
                    f"class index {class_idx!r} is not present in the {split!r} split"
                )
            positions = positions_by_class[class_idx]
            select_idxs.extend(positions)
            targets.extend([local_target] * len(positions))
            self._local_targets.setdefault(class_idx, local_target)

        self.targets = targets
        # samples/imgs keep their original (path, original_class_index) entries;
        # the remapping to new labels happens in __getitem__.
        self.imgs = [self.imgs[i] for i in select_idxs]
        self.samples = [self.samples[i] for i in select_idxs]
        self.idxs = idxs

    def __getitem__(self, i):
        """Return ``(image, new_label)`` for sample ``i``.

        The original class index stored in ``self.samples`` is mapped to its
        position in ``idxs`` *before* ``target_transform`` runs, so a target
        transform receives the new label rather than the original ImageNet index.
        """
        path, class_idx = self.samples[i]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        target = self._local_targets[class_idx]
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target