# ------------------------------------------------------------------------
# HOTR official code : hotr/data/datasets/hico.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
# Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
# Copyright (c) Hitachi, Ltd. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
import torch.utils.data
import torchvision
from PIL import Image

from hotr.data.datasets import builtin_meta
import hotr.data.transforms.transforms as T
class HICODetection(torch.utils.data.Dataset):
    """HICO-DET dataset yielding (image, target) pairs for HOI detection.

    Annotations follow the QPIC JSON format: each image entry carries
    'file_name', 'annotations' (a list of {'bbox', 'category_id'} object
    boxes) and 'hoi_annotation' (a list of {'subject_id', 'object_id',
    'category_id'} triples whose subject/object ids index into
    'annotations' and whose category_id is a HICO verb id).
    """

    def __init__(self, img_set, img_folder, anno_file, action_list_file, transforms, num_queries):
        """
        Args:
            img_set: 'train', 'val' or 'test'; selects annotation filtering
                and the target format produced by __getitem__.
            img_folder: path to the split's image directory.
            anno_file: path to the split's JSON annotation file.
            action_list_file: path to HICO's list_action.txt verb table.
            transforms: callable (img, target) -> (img, target), or None.
            num_queries: max number of GT boxes kept per training image.
        """
        self.img_set = img_set
        self.img_folder = img_folder
        with open(anno_file, 'r') as f:
            self.annotations = json.load(f)
        with open(action_list_file, 'r') as f:
            self.action_lines = f.readlines()
        self._transforms = transforms
        self.num_queries = num_queries
        self.get_metadata()

        if img_set == 'train':
            # Keep only images whose HOI pairs reference in-range box indices.
            # The for/else appends idx only when the inner loop did NOT break,
            # i.e. every pair in the image is valid.
            self.ids = []
            for idx, img_anno in enumerate(self.annotations):
                for hoi in img_anno['hoi_annotation']:
                    if hoi['subject_id'] >= len(img_anno['annotations']) or hoi['object_id'] >= len(img_anno['annotations']):
                        break
                else:
                    self.ids.append(idx)
        else:
            # Evaluation uses every annotated image, unfiltered.
            self.ids = list(range(len(self.annotations)))

    ############################################################################
    # Number Method
    ############################################################################
    def get_metadata(self):
        """Load COCO object-class metadata and the HICO verb id/name table."""
        meta = builtin_meta._get_coco_instances_meta()
        self.COCO_CLASSES = meta['coco_classes']
        # Original (non-contiguous) COCO category ids considered valid objects.
        self._valid_obj_ids = [id for id in meta['thing_dataset_id_to_contiguous_id'].keys()]
        # list_action.txt: the first two lines are a header, then "<id> <name>" rows.
        self._valid_verb_ids, self._valid_verb_names = [], []
        for action_line in self.action_lines[2:]:
            act_id, act_name = action_line.split()
            self._valid_verb_ids.append(int(act_id))
            self._valid_verb_names.append(act_name)

    def get_valid_obj_ids(self):
        # Original COCO category ids accepted as HOI objects.
        return self._valid_obj_ids

    def get_actions(self):
        # Verb (action) names, aligned with self._valid_verb_ids by position.
        return self._valid_verb_names

    def num_category(self):
        # Number of object categories (length of the COCO class list).
        return len(self.COCO_CLASSES)

    def num_action(self):
        # Number of HOI verb classes.
        return len(self._valid_verb_ids)
    ############################################################################

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return (img, target).

        Train targets contain clipped boxes plus per-pair HOI supervision
        ('pair_targets', 'pair_actions', 'sub_boxes', 'obj_boxes'); eval
        targets keep the raw boxes/labels and the full 'hois' triple list.
        """
        img_anno = self.annotations[self.ids[idx]]

        img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB')
        w, h = img.size

        # cut out the GTs that exceed the number of object queries
        # NOTE(review): this truncation mutates self.annotations in place; it is
        # idempotent, but the change persists across epochs.
        if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries:
            img_anno['annotations'] = img_anno['annotations'][:self.num_queries]

        boxes = [obj['bbox'] for obj in img_anno['annotations']]
        # guard against no boxes via resizing
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        if self.img_set == 'train':
            # Add index for confirming which boxes are kept after image transformation
            classes = [(i, self._valid_obj_ids.index(obj['category_id'])) for i, obj in enumerate(img_anno['annotations'])]
        else:
            classes = [self._valid_obj_ids.index(obj['category_id']) for obj in img_anno['annotations']]
        classes = torch.tensor(classes, dtype=torch.int64)

        target = {}
        target['orig_size'] = torch.as_tensor([int(h), int(w)])
        target['size'] = torch.as_tensor([int(h), int(w)])
        if self.img_set == 'train':
            # Clip boxes to the image and drop degenerate (zero-area) ones.
            boxes[:, 0::2].clamp_(min=0, max=w)
            boxes[:, 1::2].clamp_(min=0, max=h)
            keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
            boxes = boxes[keep]
            classes = classes[keep]
            target['boxes'] = boxes
            target['labels'] = classes
            target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])])
            target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

            if self._transforms is not None:
                img, target = self._transforms(img, target)

            # Transforms (e.g. random crop) may drop boxes; column 0 of 'labels'
            # holds the original annotation index, used to remap HOI ids below.
            kept_box_indices = [label[0] for label in target['labels']]
            target['labels'] = target['labels'][:, 1]

            obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], []
            sub_obj_pairs = []
            for hoi in img_anno['hoi_annotation']:
                # Skip pairs whose subject or object box was dropped by transforms.
                if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices:
                    continue
                sub_obj_pair = (hoi['subject_id'], hoi['object_id'])
                if sub_obj_pair in sub_obj_pairs:
                    # Repeated (subject, object) pair: merge into the existing
                    # multi-hot verb vector instead of adding a new pair entry.
                    verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1
                else:
                    sub_obj_pairs.append(sub_obj_pair)
                    obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])])
                    verb_label = [0 for _ in range(len(self._valid_verb_ids))]
                    verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1
                    sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])]
                    obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])]
                    verb_labels.append(verb_label)
                    sub_boxes.append(sub_box)
                    obj_boxes.append(obj_box)

            if len(sub_obj_pairs) == 0:
                # No surviving interaction: emit empty, correctly-shaped targets.
                target['pair_targets'] = torch.zeros((0,), dtype=torch.int64)
                target['pair_actions'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32)
                target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
                target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
            else:
                target['pair_targets'] = torch.stack(obj_labels)
                target['pair_actions'] = torch.as_tensor(verb_labels, dtype=torch.float32)
                target['sub_boxes'] = torch.stack(sub_boxes)
                target['obj_boxes'] = torch.stack(obj_boxes)
        else:
            # Evaluation: raw boxes/labels plus the dataset index of this sample.
            target['boxes'] = boxes
            target['labels'] = classes
            target['id'] = idx

            if self._transforms is not None:
                img, _ = self._transforms(img, None)

            # Full HOI list as (subject_idx, object_idx, contiguous_verb_idx).
            hois = []
            for hoi in img_anno['hoi_annotation']:
                hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id'])))
            target['hois'] = torch.as_tensor(hois, dtype=torch.int64)

        return img, target

    def set_rare_hois(self, anno_file):
        """Split HOI triplets into rare (<10 train occurrences) and non-rare sets.

        Triplets are (subject_obj_idx, object_obj_idx, verb_idx), all in
        contiguous index space; counts come from the TRAIN annotation file.
        """
        with open(anno_file, 'r') as f:
            annotations = json.load(f)

        counts = defaultdict(lambda: 0)
        for img_anno in annotations:
            hois = img_anno['hoi_annotation']
            bboxes = img_anno['annotations']
            for hoi in hois:
                triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']),
                           self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']),
                           self._valid_verb_ids.index(hoi['category_id']))
                counts[triplet] += 1
        self.rare_triplets = []
        self.non_rare_triplets = []
        for triplet, count in counts.items():
            if count < 10:
                self.rare_triplets.append(triplet)
            else:
                self.non_rare_triplets.append(triplet)

    def load_correct_mat(self, path):
        # Correction matrix used at evaluation time; presumably masks impossible
        # (verb, object) combinations — confirm against the evaluator.
        self.correct_mat = np.load(path)
# Add color jitter to coco transforms
def make_hico_transforms(image_set):
    """Build the transform pipeline for a HICO-DET split.

    Args:
        image_set: 'train' (augmented: flip, color jitter, multi-scale
            resize/crop) or 'val'/'test' (deterministic single-scale resize).

    Returns:
        A T.Compose pipeline ending in ToTensor + ImageNet normalization.

    Raises:
        ValueError: if image_set is not one of the known splits.
    """
    normalize = T.Compose([
        T.ToTensor(),
        # ImageNet channel statistics.
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

    if image_set == 'train':
        return T.Compose([
            T.RandomHorizontalFlip(),
            T.ColorJitter(.4, .4, .4),
            # Either a plain multi-scale resize, or resize -> crop -> resize.
            T.RandomSelect(
                T.RandomResize(scales, max_size=1333),
                T.Compose([
                    T.RandomResize([400, 500, 600]),
                    T.RandomSizeCrop(384, 600),
                    T.RandomResize(scales, max_size=1333),
                ])
            ),
            normalize,
        ])

    # 'val' and 'test' share the identical deterministic pipeline; the two
    # duplicated branches are merged into one membership test.
    if image_set in ('val', 'test'):
        return T.Compose([
            T.RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')
def build(image_set, args):
    """Construct the HICODetection dataset for the requested split.

    Args:
        image_set: 'train', 'val' or 'test'.
        args: namespace providing data_path and num_queries.

    Returns:
        A HICODetection instance; for eval splits, rare-HOI triplets and the
        correction matrix are attached as well.
    """
    root = Path(args.data_path)
    assert root.exists(), f'provided HOI path {root} does not exist'

    # Per-split image directories and annotation files (val/test share files).
    image_dirs = {
        'train': root / 'images' / 'train2015',
        'val': root / 'images' / 'test2015',
        'test': root / 'images' / 'test2015',
    }
    anno_files = {
        'train': root / 'annotations' / 'trainval_hico.json',
        'val': root / 'annotations' / 'test_hico.json',
        'test': root / 'annotations' / 'test_hico.json',
    }
    correct_mat_path = root / 'annotations' / 'corre_hico.npy'
    action_list_file = root / 'list_action.txt'

    dataset = HICODetection(
        image_set,
        image_dirs[image_set],
        anno_files[image_set],
        action_list_file,
        transforms=make_hico_transforms(image_set),
        num_queries=args.num_queries,
    )
    if image_set in ('val', 'test'):
        # Rare/non-rare split is always computed from the TRAIN annotations.
        dataset.set_rare_hois(anno_files['train'])
        dataset.load_correct_mat(correct_mat_path)
    return dataset