|
|
""" |
|
|
HICODet dataset under PyTorch framework |
|
|
|
|
|
Fred Zhang <frederic.zhang@anu.edu.au> |
|
|
|
|
|
The Australian National University |
|
|
Australian Centre for Robotic Vision |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import numpy as np |
|
|
|
|
|
from typing import Optional, List, Callable, Tuple |
|
|
from pocket.data import ImageDataset, DataSubset |
|
|
|
|
|
class HICODetSubset(DataSubset):
    """
    Subset of a HICODet dataset

    Indices passed to the methods below are positions in ``self.pool``. Each
    pool entry is first mapped through ``self._idx`` before the per-image
    tables (file names, image sizes, annotations) are indexed.

    NOTE(review): ``pool``, ``_idx``, ``_filenames``, ``_image_sizes``,
    ``_anno``, ``_class_corr`` and the ``num_*_cls`` counters are not defined
    in this class -- presumably they are provided by DataSubset (e.g. forwarded
    from the wrapped dataset); confirm against pocket.data.
    """
    def __init__(self, *args) -> None:
        # All positional arguments are handed to the parent class untouched
        super(HICODetSubset, self).__init__(*args)

    def filename(self, idx: int) -> str:
        """Override: return the file name of the idx-th image in the subset"""
        dataset_pos = self.pool[idx]
        intra_idx = self._idx[dataset_pos]
        return self._filenames[intra_idx]

    def image_size(self, idx: int) -> Tuple[int, int]:
        """Override: return the size (width, height) of the idx-th image in the subset"""
        dataset_pos = self.pool[idx]
        intra_idx = self._idx[dataset_pos]
        return self._image_sizes[intra_idx]

    @property
    def anno_interaction(self) -> List[int]:
        """Override: Number of annotated box pairs for each interaction class"""
        counts = [0] * self.num_interation_cls
        # Resolve every pool entry to an annotation index before counting
        selected = [self._idx[pos] for pos in self.pool]
        for intra_idx in selected:
            hoi_labels = self._anno[intra_idx]['hoi']
            for hoi_idx in hoi_labels:
                counts[hoi_idx] += 1
        return counts

    @property
    def anno_object(self) -> List[int]:
        """Override: Number of annotated box pairs for each object class"""
        counts = [0] * self.num_object_cls
        per_interaction = self.anno_interaction
        # Each correspondence entry is indexed as [0] interaction, [1] object
        for entry in self._class_corr:
            hoi_idx, obj_idx = entry[0], entry[1]
            counts[obj_idx] += per_interaction[hoi_idx]
        return counts

    @property
    def anno_action(self) -> List[int]:
        """Override: Number of annotated box pairs for each action class"""
        counts = [0] * self.num_action_cls
        per_interaction = self.anno_interaction
        # Each correspondence entry is indexed as [0] interaction, [2] action
        for entry in self._class_corr:
            hoi_idx, verb_idx = entry[0], entry[2]
            counts[verb_idx] += per_interaction[hoi_idx]
        return counts
|
|
|
|
|
class HICODet(ImageDataset):
    """
    HICODet dataset backed by a single json annotation file

    Arguments:
        root(str): Root directory where images are downloaded to
        anno_file(str): Path to json annotation file
        transform(callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version
        target_transform(callable, optional): A function/transform that takes in the
            target and transforms it
        transforms (callable, optional): A function/transform that takes input sample
            and its target as entry and returns a transformed version.
    """
    def __init__(self, root: str, anno_file: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None) -> None:
        super(HICODet, self).__init__(root, transform, target_transform, transforms)
        with open(anno_file, 'r') as f:
            anno = json.load(f)

        # Fixed class counts used to size the statistics lists and lookup tables.
        # NOTE: "num_interation_cls" is misspelled, but the name is part of the
        # public interface (HICODetSubset reads it), so it is kept as is.
        self.num_object_cls = 80
        self.num_interation_cls = 600
        self.num_action_cls = 117
        self._anno_file = anno_file

        self._load_annotation_and_metadata(anno)

    def __len__(self) -> int:
        """Return the number of images (those not listed as empty)"""
        return len(self._idx)

    def __getitem__(self, i: int) -> tuple:
        """
        Arguments:
            i(int): Index to an image

        Returns:
            tuple[image, target]: By default, the tuple consists of a PIL image and a
                dict with the following keys:
                    "boxes_h": list[list[4]]
                    "boxes_o": list[list[4]]
                    "hoi": list[N]
                    "verb": list[N]
                    "object": list[N]
        """
        # Map the public index to the position in the raw annotation tables
        intra_idx = self._idx[i]
        image_path = os.path.join(self._root, self._filenames[intra_idx])
        return self._transforms(
            self.load_image(image_path),
            self._anno[intra_idx]
        )

    def __repr__(self) -> str:
        """Return the executable string representation"""
        return '{}(root={!r}, anno_file={!r})'.format(
            self.__class__.__name__, self._root, self._anno_file
        )

    def __str__(self) -> str:
        """Return the readable string representation"""
        reprstr = 'Dataset: ' + self.__class__.__name__ + '\n'
        reprstr += '\tNumber of images: {}\n'.format(self.__len__())
        reprstr += '\tImage directory: {}\n'.format(self._root)
        # Report the annotation file path, not the image directory
        reprstr += '\tAnnotation file: {}\n'.format(self._anno_file)
        return reprstr

    @property
    def annotations(self) -> List[dict]:
        """
        Annotations for all images, as loaded from the annotation file

        The internal list itself is returned (not a copy), and it is indexed by
        raw position, i.e. it still contains the entries of empty images.

        Returns:
            list[dict]
        """
        return self._anno

    @property
    def class_corr(self) -> List[Tuple[int, int, int]]:
        """
        Class correspondence matrix in zero-based index
        [
            [hoi_idx, obj_idx, verb_idx],
            ...
        ]

        Only the outer list is copied; the rows are shared with the dataset.

        Returns:
            list[list[3]]
        """
        return self._class_corr.copy()

    @property
    def object_n_verb_to_interaction(self) -> List[list]:
        """
        The interaction classes corresponding to an object-verb pair

        HICODet.object_n_verb_to_interaction[obj_idx][verb_idx] gives interaction class
        index if the pair is valid, None otherwise

        Returns:
            list[list[117]]
        """
        # Object-dtype table pre-filled with None so invalid pairs stay None
        lut = np.full([self.num_object_cls, self.num_action_cls], None)
        for i, j, k in self._class_corr:
            lut[j, k] = i
        return lut.tolist()

    @property
    def object_to_interaction(self) -> List[list]:
        """
        The interaction classes that involve each object type

        Returns:
            list[list]
        """
        obj_to_int = [[] for _ in range(self.num_object_cls)]
        for corr in self._class_corr:
            obj_to_int[corr[1]].append(corr[0])
        return obj_to_int

    @property
    def object_to_verb(self) -> List[list]:
        """
        The valid verbs for each object type

        Returns:
            list[list]
        """
        obj_to_verb = [[] for _ in range(self.num_object_cls)]
        for corr in self._class_corr:
            obj_to_verb[corr[1]].append(corr[2])
        return obj_to_verb

    @property
    def anno_interaction(self) -> List[int]:
        """
        Number of annotated box pairs for each interaction class

        Returns:
            list[600]
        """
        return self._num_anno.copy()

    @property
    def anno_object(self) -> List[int]:
        """
        Number of annotated box pairs for each object class

        Returns:
            list[80]
        """
        num_anno = [0 for _ in range(self.num_object_cls)]
        # Accumulate per-interaction counts onto the object of each interaction
        for corr in self._class_corr:
            num_anno[corr[1]] += self._num_anno[corr[0]]
        return num_anno

    @property
    def anno_action(self) -> List[int]:
        """
        Number of annotated box pairs for each action class

        Returns:
            list[117]
        """
        num_anno = [0 for _ in range(self.num_action_cls)]
        # Accumulate per-interaction counts onto the action of each interaction
        for corr in self._class_corr:
            num_anno[corr[2]] += self._num_anno[corr[0]]
        return num_anno

    @property
    def objects(self) -> List[str]:
        """
        Object names

        Returns:
            list[str]
        """
        return self._objects.copy()

    @property
    def verbs(self) -> List[str]:
        """
        Verb (action) names

        Returns:
            list[str]
        """
        return self._verbs.copy()

    @property
    def interactions(self) -> List[str]:
        """
        Combination of verbs and objects

        Each name is formatted as "<verb> <object>", in the order of the class
        correspondence list.

        Returns:
            list[str]
        """
        # Read the backing lists directly; going through the `objects` property
        # would copy the whole list once per interaction.
        return [self._verbs[j] + ' ' + self._objects[i]
            for _, i, j in self._class_corr]

    def split(self, ratio: float) -> Tuple[HICODetSubset, HICODetSubset]:
        """
        Split the dataset according to given ratio

        The images are shuffled with numpy's global random state, so seed it
        beforehand for a reproducible split.

        Arguments:
            ratio(float): The percentage of training set between 0 and 1
        Returns:
            train(Dataset)
            val(Dataset)
        """
        perm = np.random.permutation(len(self._idx))
        n = int(len(perm) * ratio)
        return HICODetSubset(self, perm[:n]), HICODetSubset(self, perm[n:])

    def filename(self, idx: int) -> str:
        """Return the image file name given the index"""
        return self._filenames[self._idx[idx]]

    def image_size(self, idx: int) -> Tuple[int, int]:
        """Return the size (width, height) of an image"""
        return self._image_sizes[self._idx[idx]]

    def _load_annotation_and_metadata(self, f: dict) -> None:
        """
        Populate the annotation tables and metadata of the dataset

        Arguments:
            f(dict): Dictionary loaded from {anno_file}.json. The keys read are
                'filenames', 'empty', 'annotation', 'size', 'correspondence',
                'objects' and 'verbs'.
        """
        # Keep the raw positions of all images that are not listed as empty.
        # A set makes the filter linear instead of calling list.remove per entry.
        empty = set(f['empty'])
        idx = [i for i in range(len(f['filenames'])) if i not in empty]

        # Count annotated box pairs per interaction class over all annotations
        num_anno = [0 for _ in range(self.num_interation_cls)]
        for anno in f['annotation']:
            for hoi in anno['hoi']:
                num_anno[hoi] += 1

        self._idx = idx
        self._num_anno = num_anno

        self._anno = f['annotation']
        self._filenames = f['filenames']
        self._image_sizes = f['size']
        self._class_corr = f['correspondence']
        self._empty_idx = f['empty']
        self._objects = f['objects']
        self._verbs = f['verbs']