# NOTE: the next three lines are page-header text captured with this file, not code.
# xmutly's picture
# Upload 294 files
# e1aaaac verified
"""
Code adapted from https://github.com/mlfoundations/wise-ft/blob/master/src/datasets/objectnet.py
Thanks to the authors of wise-ft
"""
import os
import json
from pathlib import Path
import PIL
import numpy as np
import torch
from torchvision import datasets
from torchvision.transforms import Compose
from pathlib import Path
def get_metadata(folder):
    """Load the ObjectNet-to-ImageNet class mappings stored in ``folder``.

    Four files are read from ``folder``:

    * ``folder_to_objectnet_label.json`` -- image folder name -> ObjectNet
      label.
    * ``objectnet_to_imagenet_1k.json`` -- ObjectNet label -> ``'; '``-separated
      ImageNet label names.
    * ``pytorch_to_imagenet_2012_id.json`` -- JSON object whose values are the
      0-based line numbers of ``imagenet_to_label_2012_v2``; its keys
      (presumably PyTorch class indices -- confirm against the data files) are
      parsed as integers and used as the class ids returned here.
    * ``imagenet_to_label_2012_v2`` -- text file with one ImageNet label name
      per line.

    Args:
        folder: Directory (``str`` or ``Path``) holding the metadata files.

    Returns:
        A 4-tuple ``(class_sublist, class_sublist_mask, folder_to_ids,
        classname_map)`` where

        * ``class_sublist`` is the sorted list of every class id reachable
          from an ObjectNet label (duplicates are kept),
        * ``class_sublist_mask`` is a list of 1000 booleans, ``True`` at the
          positions listed in ``class_sublist``,
        * ``folder_to_ids`` maps an image folder name to its list of class ids,
        * ``classname_map`` maps an image folder name to its ObjectNet label.

    Raises:
        OSError: If one of the metadata files cannot be opened.
        KeyError: If a label or line number has no entry in the mapping it is
            looked up in.
    """
    metadata = Path(folder)

    with open(metadata / 'folder_to_objectnet_label.json', 'r', encoding='utf-8') as f:
        # Inverted on load: ObjectNet label -> folder name.
        label_to_folder = {label: name for name, label in json.load(f).items()}

    with open(metadata / 'objectnet_to_imagenet_1k.json', 'r', encoding='utf-8') as f:
        objectnet_map = json.load(f)

    with open(metadata / 'pytorch_to_imagenet_2012_id.json', 'r', encoding='utf-8') as f:
        # Inverted on load: line number of the label file -> JSON key.
        line_to_key = {line_no: key for key, line_no in json.load(f).items()}

    with open(metadata / 'imagenet_to_label_2012_v2', 'r', encoding='utf-8') as f:
        # JSON object keys are always strings, so the value is kept as-is and
        # converted with int() only for the labels that are actually used.
        imagenet_map = {line.strip(): line_to_key[i] for i, line in enumerate(f)}

    folder_to_ids = {}
    class_sublist = []
    for objectnet_name, imagenet_names in objectnet_map.items():
        imagenet_ids = [int(imagenet_map[name]) for name in imagenet_names.split('; ')]
        class_sublist.extend(imagenet_ids)
        folder_to_ids[label_to_folder[objectnet_name]] = imagenet_ids

    class_sublist = sorted(class_sublist)

    # Set membership keeps the mask construction linear in the mask length.
    selected = set(class_sublist)
    class_sublist_mask = [(i in selected) for i in range(1000)]

    # Re-invert the (already de-duplicated) label -> folder map so that the
    # result only contains folders that survived the first inversion.
    classname_map = {name: label for label, name in label_to_folder.items()}
    return class_sublist, class_sublist_mask, folder_to_ids, classname_map
class ObjectNetDataset(datasets.ImageFolder):
    """ObjectNet images restricted to the folders listed in the metadata.

    Images are read from ``<root>/objectnet-1.0/images``; only folders that
    appear in the ``folder_to_ids`` mapping returned by ``get_metadata(root)``
    are kept. Labels are the positions of the folder names in sorted order.

    Attributes set here (in addition to those of ``ImageFolder``):
        _class_sublist, class_sublist_mask, folders_to_ids, classname_map:
            The four values returned by ``get_metadata(root)``.
        label_map: Folder name -> integer label (index in sorted folder order).
        samples / imgs: ``(path, label)`` pairs for the kept images.
        targets: Labels of ``samples``, in the same order.
        classes: Lower-cased ObjectNet label of each kept folder, ordered by
            label.

    NOTE(review): ``class_to_idx`` is left exactly as the parent constructor
    populated it and therefore does not follow ``label_map``.
    """

    def __init__(self, root, transform):
        """Index the ObjectNet image tree under ``root``.

        Args:
            root: Directory containing the metadata files read by
                ``get_metadata`` and the ``objectnet-1.0/images`` tree.
            transform: Callable applied to each loaded image, or ``None``.
        """
        (self._class_sublist,
         self.class_sublist_mask,
         self.folders_to_ids,
         self.classname_map) = get_metadata(root)

        # Sort once; a folder's position in this list is its label.
        folders = sorted(self.folders_to_ids)
        self.label_map = {name: idx for idx, name in enumerate(folders)}

        subdir = os.path.join(root, "objectnet-1.0", "images")
        super().__init__(subdir, transform=transform)

        # Drop images from unmapped folders and relabel the rest so that the
        # stored target agrees with what __getitem__ returns.
        kept = []
        for path, _ in self.samples:
            folder = os.path.basename(os.path.dirname(path))
            if folder in self.label_map:
                kept.append((path, self.label_map[folder]))
        self.samples = kept
        self.imgs = self.samples
        # Keep `targets` aligned with the filtered samples; the parent built it
        # from the unfiltered list.
        self.targets = [label for _, label in self.samples]

        self.classes = [self.classname_map[c].lower() for c in folders]

    def __len__(self):
        """Return the number of kept images."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load one image and its label.

        Args:
            index: Position in ``self.samples``.

        Returns:
            ``(sample, label)`` where ``sample`` is the loaded image after
            ``self.transform`` (if any) and ``label`` is
            ``self.label_map[<name of the image's parent folder>]``.
        """
        path, _ = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        # The label is derived from the parent folder name rather than from
        # the stored target, so it stays correct even if `samples` is replaced.
        label = os.path.basename(os.path.dirname(path))
        return sample, self.label_map[label]