"""
Code adapted from https://github.com/mlfoundations/wise-ft/blob/master/src/datasets/objectnet.py
Thanks to the authors of wise-ft
"""
import os
import json
from pathlib import Path
import PIL
import numpy as np
import torch
from torchvision import datasets
from torchvision.transforms import Compose
from pathlib import Path

# Length of the boolean class mask returned by get_metadata().
_NUM_IMAGENET_CLASSES = 1000


def get_metadata(folder):
    """Build the ObjectNet/ImageNet class mappings from the metadata files.

    Four files are read from ``folder``: ``folder_to_objectnet_label.json``,
    ``objectnet_to_imagenet_1k.json``, ``pytorch_to_imagenet_2012_id.json``
    and ``imagenet_to_label_2012_v2`` (plain text, one name per line).

    Args:
        folder: Directory that contains the four metadata files.

    Returns:
        A tuple ``(class_sublist, class_sublist_mask, folder_to_ids,
        classname_map)`` where

        * ``class_sublist`` is the sorted list of every integer class id
          referenced by ``objectnet_to_imagenet_1k.json`` (an id referenced
          by several entries appears several times),
        * ``class_sublist_mask`` is a list of 1000 booleans, ``True`` at the
          indices contained in ``class_sublist``,
        * ``folder_to_ids`` maps an image-folder name to the list of class
          ids of that folder,
        * ``classname_map`` maps an image-folder name to its ObjectNet label.

    Raises:
        OSError: If one of the metadata files cannot be opened.
        KeyError: If a name used in one file is missing from the file it is
            looked up in.
    """
    metadata = Path(folder)

    with open(metadata / 'folder_to_objectnet_label.json', 'r', encoding='utf-8') as f:
        # Inverted on load: the rest of this function looks folders up by
        # the ObjectNet label, i.e. by the JSON file's values.
        folder_map = {v: k for k, v in json.load(f).items()}

    with open(metadata / 'objectnet_to_imagenet_1k.json', 'r', encoding='utf-8') as f:
        objectnet_map = json.load(f)

    with open(metadata / 'pytorch_to_imagenet_2012_id.json', 'r', encoding='utf-8') as f:
        # Inverted on load so it can be indexed by the JSON file's values
        # (the line numbers of the label file read below).
        pytorch_map = {v: k for k, v in json.load(f).items()}

    with open(metadata / 'imagenet_to_label_2012_v2', 'r', encoding='utf-8') as f:
        # Line i of the label file (0-based) is translated through
        # pytorch_map; the id is kept as a string and converted below.
        imagenet_map = {v.strip(): str(pytorch_map[i]) for i, v in enumerate(f)}

    folder_to_ids, class_sublist = {}, []
    for objectnet_name, imagenet_names in objectnet_map.items():
        # One ObjectNet class may correspond to several '; '-separated
        # ImageNet names.
        imagenet_ids = [int(imagenet_map[name]) for name in imagenet_names.split('; ')]
        class_sublist.extend(imagenet_ids)
        folder_to_ids[folder_map[objectnet_name]] = imagenet_ids

    class_sublist = sorted(class_sublist)
    # Set membership keeps the mask construction linear in the mask length.
    selected_ids = set(class_sublist)
    class_sublist_mask = [i in selected_ids for i in range(_NUM_IMAGENET_CLASSES)]
    classname_map = {v: k for k, v in folder_map.items()}
    return class_sublist, class_sublist_mask, folder_to_ids, classname_map


class ObjectNetDataset(datasets.ImageFolder):
    """``ImageFolder`` over ``<root>/objectnet-1.0/images`` restricted to mapped folders.

    Only images whose parent folder appears in the metadata mapping returned
    by ``get_metadata(root)`` are kept. The label of an image is the position
    of its folder name in the sorted list of mapped folder names.
    """

    def __init__(self, root, transform):
        """Load the metadata from ``root`` and index the image folders.

        Args:
            root: Directory holding the metadata files and the
                ``objectnet-1.0/images`` sub-directory.
            transform: Transform forwarded to ``ImageFolder``; applied to
                every loaded sample when it is not ``None``.
        """
        (self._class_sublist,
         self.class_sublist_mask,
         self.folders_to_ids,
         self.classname_map) = get_metadata(root)

        subdir = os.path.join(root, "objectnet-1.0", "images")
        # Sorted once; defines both the label indices and the class order.
        folders = sorted(self.folders_to_ids)
        self.label_map = {name: idx for idx, name in enumerate(folders)}

        super().__init__(subdir, transform=transform)

        # Drop every image whose folder has no entry in the metadata.
        self.samples = [
            d for d in self.samples
            if os.path.basename(os.path.dirname(d[0])) in self.label_map
        ]
        self.imgs = self.samples
        # NOTE(review): the inherited ``targets`` and ``class_to_idx``
        # attributes are not refreshed after this filtering; labels should be
        # taken from __getitem__ / ``label_map`` - confirm no caller reads them.

        # Lower-cased ObjectNet labels, in label-index order.
        self.classes = [self.classname_map[c].lower() for c in folders]

    def __len__(self):
        """Return the number of images kept after filtering."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return ``(sample, label)`` for the image at ``index``.

        The label is looked up from the image's parent folder name through
        ``label_map``; the target stored in ``samples`` is ignored.
        """
        path, _ = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        label = os.path.basename(os.path.dirname(path))
        return sample, self.label_map[label]