| import torchvision |
|
|
| """ |
| BabelImageNet from https://arxiv.org/pdf/2306.08658.pdf |
| Adapted from https://github.com/gregor-ge/Babel-ImageNet, thanks to the authors |
| """ |
| class BabelImageNet(torchvision.datasets.ImageNet): |
| def __init__(self, root: str, idxs, split: str = "val", download=None, **kwargs) -> None: |
| super().__init__(root, split, **kwargs) |
| examples_per_class = len(self.targets) // 1000 |
| select_idxs = [idx*examples_per_class + i for idx in idxs for i in range(examples_per_class)] |
| self.targets = [i for i in range(len(idxs)) for _ in range(examples_per_class)] |
| self.imgs = [self.imgs[i] for i in select_idxs] |
| self.samples = [self.samples[i] for i in select_idxs] |
| self.idxs = idxs |
|
|
| def __getitem__(self, i): |
| img, target = super().__getitem__(i) |
| target = self.idxs.index(target) |
| return img, target |