Spaces:
Runtime error
Runtime error
| # ------------------------------------------------------------------------ | |
| # HOTR official code : main.py | |
| # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved | |
| # ------------------------------------------------------------------------ | |
| # Modified from DETR (https://github.com/facebookresearch/detr) | |
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| # ------------------------------------------------------------------------ | |
| import torch | |
| import torch.nn.functional as F | |
| import copy | |
| import numpy as np | |
| import itertools | |
| from torch import nn | |
| from hotr.util import box_ops | |
| from hotr.util.misc import (accuracy, get_world_size, is_dist_avail_and_initialized) | |
class SetCriterion(nn.Module):
    """Compute the DETR detection losses and the HOTR interaction (HOI) losses.

    The process happens in two steps:
        1) compute a Hungarian assignment between ground-truth boxes / interaction
           pairs and the outputs of the model;
        2) supervise each matched ground-truth / prediction pair (class, box,
           pointer, action and interaction-target losses).

    When consistency is enabled (``args.use_consis`` together with at least one
    entry in ``args.augpath_name``), each HOI loss also returns a consistency
    term. That term pulls together the predictions that the ``num_path``
    prediction paths of one image make for the same ground-truth interaction.
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
                 num_actions=None, HOI_losses=None, HOI_matcher=None, args=None):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category.
            matcher: module able to compute a matching between targets and proposals.
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category.
            losses: list of all the detection losses to be applied. See get_loss for the available ones.
            num_actions: number of action classes; used to build ``valid_ids`` for 'hico-det'.
            HOI_losses: list of HOI losses to apply (see get_HOI_loss), or None to skip HOI supervision.
            HOI_matcher: module matching interaction predictions to ground-truth pairs.
            args: config object. It reads ``use_consis``, ``augpath_name``, ``hoi_eos_coef`` and
                ``dataset_file``. It also reads ``invalid_ids``/``valid_ids`` for 'vcoco' and
                ``valid_obj_ids`` for 'hico-det'. When None, a single path without any
                consistency loss is assumed.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.losses = losses
        self.eos_coef = eos_coef
        self.HOI_losses = HOI_losses
        self.HOI_matcher = HOI_matcher
        # Without a config there is only the main path and nothing to be consistent with.
        self.use_consis = False
        self.num_path = 1
        if args is not None:
            num_aug_paths = len(args.augpath_name)
            # Spelled with `and` on purpose. `use_consis & n > 0` parses as
            # `(use_consis & n) > 0`, which is False for e.g. n == 2.
            self.use_consis = bool(args.use_consis) and num_aug_paths > 0
            self.num_path = 1 + num_aug_paths
            self.HOI_eos_coef = args.hoi_eos_coef
            if args.dataset_file == 'vcoco':
                self.invalid_ids = args.invalid_ids
                self.valid_ids = np.concatenate((args.valid_ids, [-1]), axis=0)  # -1: no-interaction column
            elif args.dataset_file == 'hico-det':
                self.invalid_ids = []
                self.valid_ids = list(range(num_actions)) + [-1]
                # Class weights for the interaction-target classifier (only used by the
                # hico-det-only 'pair_targets' loss). The last class is "no object".
                self.num_tgt_classes = len(args.valid_obj_ids)
                tgt_empty_weight = torch.ones(self.num_tgt_classes + 1)
                tgt_empty_weight[-1] = self.HOI_eos_coef
                self.register_buffer('tgt_empty_weight', tgt_empty_weight)
            self.dataset_file = args.dataset_file

        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = eos_coef
        self.register_buffer('empty_weight', empty_weight)

    #######################################################################################################################
    # * DETR Losses
    #######################################################################################################################
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL).

        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes].
        Unmatched queries are supervised towards the no-object class ``self.num_classes``.
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes.

        This is not really a loss; it is intended for logging purposes only (argmax is not differentiable).
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        return {'cardinality_error': card_err}

    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the L1 regression loss and the GIoU loss for the matched boxes.

        targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4].
        The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)

        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses = {'loss_bbox': loss_bbox.sum() / num_boxes}

        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses

    #######################################################################################################################
    # * HOTR Losses
    #######################################################################################################################
    # >>> HOI Losses 1 : HO Pointer
    def loss_pair_labels(self, outputs, targets, hoi_indices, num_boxes, use_consis, log=False):
        """Cross-entropy loss on the human / object pointers (plus optional consistency terms).

        The leading dim of ``outputs['pred_hidx']``/``['pred_oidx']`` is reshaped with
        ``view(num_path, nu // num_path, ...).transpose(0, 1)``. The dim is therefore treated
        as path-major and reordered to image-major, which lines up with the flattened
        ``hoi_indices`` (image -> path -> (src, tgt)). TODO confirm against the model's layout.
        Unmatched queries get target -1 and are ignored.
        """
        assert 'pred_hidx' in outputs and 'pred_oidx' in outputs
        nu, q, hd = outputs['pred_hidx'].shape
        # (num_images, num_path, q, hd) views, indexed per image by the consistency pairs.
        hidx_per_image = outputs['pred_hidx'].view(self.num_path, nu // self.num_path, q, -1).transpose(0, 1)
        oidx_per_image = outputs['pred_oidx'].view(self.num_path, nu // self.num_path, q, -1).transpose(0, 1)
        src_hidx = hidx_per_image.flatten(0, 1)
        src_oidx = oidx_per_image.flatten(0, 1)

        idx = self._get_src_permutation_idx(list(itertools.chain.from_iterable(hoi_indices)))

        target_hidx_classes = torch.full(src_hidx.shape[:2], -1, dtype=torch.int64, device=src_hidx.device)
        target_oidx_classes = torch.full(src_oidx.shape[:2], -1, dtype=torch.int64, device=src_oidx.device)
        target_hidx_classes[idx] = self._gather_matched(targets, hoi_indices, "h_labels")  # H pointer
        target_oidx_classes[idx] = self._gather_matched(targets, hoi_indices, "o_labels")  # O pointer

        loss_h = F.cross_entropy(src_hidx.transpose(1, 2), target_hidx_classes, ignore_index=-1)
        loss_o = F.cross_entropy(src_oidx.transpose(1, 2), target_oidx_classes, ignore_index=-1)
        losses = {'loss_hidx': loss_h, 'loss_oidx': loss_o}

        if use_consis:
            pairs = [self._get_consistency_src_permutation_idx(per_path) for per_path in hoi_indices]
            losses['loss_h_consistency'] = self._mean_or_zero(
                [self._symmetric_kl(hidx_per_image[i][a], hidx_per_image[i][b])
                 for i, (a, b) in enumerate(pairs) if a[0].numel() > 0], src_hidx)
            losses['loss_o_consistency'] = self._mean_or_zero(
                [self._symmetric_kl(oidx_per_image[i][a], oidx_per_image[i][b])
                 for i, (a, b) in enumerate(pairs) if a[0].numel() > 0], src_oidx)
        return losses

    # >>> HOI Losses 2 : pair actions
    def loss_pair_actions(self, outputs, targets, hoi_indices, num_boxes, use_consis):
        """Focal (sigmoid) loss on the multi-label action predictions (plus optional consistency).

        ``outputs['pred_actions'][i]`` is indexed with (path, query) pairs. So the tensor is
        presumably laid out as (num_images, num_path, num_queries, num_actions + 1), with the
        last column being "no interaction". TODO confirm against the model.
        Only the columns in ``self.valid_ids`` are supervised.
        """
        assert 'pred_actions' in outputs
        pred_actions = outputs['pred_actions']
        src_actions = pred_actions.flatten(end_dim=1)
        idx = self._get_src_permutation_idx(list(itertools.chain.from_iterable(hoi_indices)))

        # Construct target: unmatched queries -> only "no interaction" is on;
        # matched queries -> their GT actions, "no interaction" off.
        target_classes_o = self._gather_matched(targets, hoi_indices, "pair_actions")
        target_classes = torch.zeros(src_actions.shape, dtype=torch.float32, device=src_actions.device)
        target_classes[..., -1] = 1
        pos_classes = torch.zeros(target_classes[idx].shape, dtype=torch.float32, device=src_actions.device)
        pos_classes[:, :-1] = target_classes_o.float()
        target_classes[idx] = pos_classes

        # Focal BCE (gamma=2, alpha=0.25), normalized by the number of positive action labels.
        valid = self.valid_ids
        probs = src_actions.sigmoid()[..., valid]
        tgt = target_classes[..., valid]
        loss_bce = F.binary_cross_entropy(probs, tgt, reduction='none')
        p_t = probs * tgt + (1 - probs) * (1 - tgt)
        alpha_t = 0.25 * tgt + (1 - 0.25) * (1 - tgt)
        loss_focal = alpha_t * ((1 - p_t) ** 2 * loss_bce)
        loss_act = loss_focal.sum() / max(target_classes[..., valid[:-1]].sum(), 1)
        losses = {'loss_act': loss_act}

        if use_consis:
            pairs = [self._get_consistency_src_permutation_idx(per_path) for per_path in hoi_indices]
            losses['loss_act_consistency'] = self._mean_or_zero(
                [F.mse_loss(F.logsigmoid(pred_actions[i][a]), F.logsigmoid(pred_actions[i][b]))
                 for i, (a, b) in enumerate(pairs) if a[0].numel() > 0], src_actions)
        return losses

    # HOI Losses 3 : action targets
    def loss_pair_targets(self, outputs, targets, hoi_indices, num_interactions, use_consis, log=True):
        """Cross-entropy loss on the interaction-target object class (hico-det only).

        ``num_interactions`` is accepted for interface compatibility and is unused.
        ``outputs['pred_obj_logits']`` is viewed as (-1, num_path, q, C) for consistency.
        Its leading dim is therefore treated as image-major. TODO confirm against the model.
        Targets equal to -1 are ignored.
        """
        assert 'pred_obj_logits' in outputs
        src_logits = outputs['pred_obj_logits']
        nu, q, num_cls = src_logits.shape

        idx = self._get_src_permutation_idx(list(itertools.chain.from_iterable(hoi_indices)))
        target_classes_o = self._gather_matched(targets, hoi_indices, 'pair_targets')
        target_classes = torch.full(src_logits.shape[:2], -1, dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_obj_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes,
                                      self.tgt_empty_weight, ignore_index=-1)
        losses = {'loss_tgt': loss_obj_ce}

        if use_consis:
            logits_per_image = src_logits.view(-1, self.num_path, q, num_cls)
            pairs = [self._get_consistency_src_permutation_idx(per_path) for per_path in hoi_indices]
            losses["loss_tgt_label_consistency"] = self._mean_or_zero(
                [self._symmetric_kl(logits_per_image[i][a], logits_per_image[i][b])
                 for i, (a, b) in enumerate(pairs) if a[0].numel() > 0], src_logits)

        if log:
            keep = target_classes_o != -1
            losses['obj_class_error'] = 100 - accuracy(src_logits[idx][keep, :-1], target_classes_o[keep])[0]
        return losses

    # >>> Index / reduction helpers
    @staticmethod
    def _gather_matched(targets, hoi_indices, key):
        """Concatenate ``t[key][tgt]`` over images, then paths (image-major, like the flattened hoi_indices)."""
        return torch.cat([t[key][tgt] for t, per_path in zip(targets, hoi_indices) for (_, tgt) in per_path])

    @staticmethod
    def _symmetric_kl(logits_a, logits_b):
        """Symmetric KL between two row-wise categorical distributions given as logits.

        Each direction uses the other side's probabilities, detached, as the target.
        'batchmean' averages over rows. log_softmax avoids the -inf that softmax().log()
        produces on underflow.
        """
        log_p_a = F.log_softmax(logits_a, -1)
        log_p_b = F.log_softmax(logits_b, -1)
        return 0.5 * (F.kl_div(log_p_a, log_p_b.exp().detach(), reduction='batchmean')
                      + F.kl_div(log_p_b, log_p_a.exp().detach(), reduction='batchmean'))

    @staticmethod
    def _mean_or_zero(values, like):
        """Mean of a list of scalar losses, or a scalar zero (device/dtype of ``like``) if empty."""
        if not values:
            return like.new_zeros(())
        return torch.stack(values).mean()

    def _get_src_permutation_idx(self, indices):
        """Permute predictions following indices: returns (batch_idx, src_idx)."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_consistency_src_permutation_idx(self, indices):
        """Pair up, across paths, the queries matched to the same ground-truth interaction.

        ``indices`` holds, for ONE image, one ``(src, tgt)`` pair of index tensors per path.
        For every ground-truth id matched in more than one path, all 2-combinations of
        those paths are emitted.

        Returns ``((path_a, query_a), (path_b, query_b))``: four 1-D int64 tensors of equal
        length, usable as advanced indices into a ``(num_path, num_queries, ...)`` tensor.
        All four are empty when no ground truth is shared by two paths. Assumes each
        ground-truth id is matched at most once per path (one-to-one matching), so the path
        and query lists line up.
        """
        path_pairs, query_pairs = [], []
        shared_gts = torch.cat([tgt for (_, tgt) in indices]).unique() if indices else []
        for gt in shared_gts:
            paths = torch.tensor([p for p, (_, tgt) in enumerate(indices) if (tgt == gt).any()],
                                 dtype=torch.int64)
            if len(paths) < 2:
                continue
            queries = torch.cat([src[tgt == gt] for (src, tgt) in indices])
            path_pairs.append(torch.combinations(paths))
            query_pairs.append(torch.combinations(queries))
        if not path_pairs:
            empty = torch.empty(0, dtype=torch.int64)
            return (empty, empty), (empty, empty)
        path_pairs = torch.cat(path_pairs)
        query_pairs = torch.cat(query_pairs)
        return (path_pairs[:, 0], query_pairs[:, 0]), (path_pairs[:, 1], query_pairs[:, 1])

    def _get_tgt_permutation_idx(self, indices):
        """Permute targets following indices: returns (batch_idx, tgt_idx)."""
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    # *****************************************************************************
    # >>> DETR Losses
    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        """Dispatch one detection loss by name ('labels', 'cardinality' or 'boxes')."""
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

    # >>> HOTR Losses
    def get_HOI_loss(self, loss, outputs, targets, indices, num_boxes, use_consis, **kwargs):
        """Dispatch one HOI loss by name ('pair_labels', 'pair_actions', and 'pair_targets' for hico-det)."""
        loss_map = {
            'pair_labels': self.loss_pair_labels,
            'pair_actions': self.loss_pair_actions,
        }
        if self.dataset_file == 'hico-det':
            loss_map['pair_targets'] = self.loss_pair_targets
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, use_consis, **kwargs)

    # *****************************************************************************
    def forward(self, outputs, targets, log=False):
        """Perform the loss computation.

        Parameters:
            outputs: dict of tensors, see the output specification of the model for the format.
            targets: list of dicts, such that len(targets) == batch_size.
                The expected keys in each dict depend on the losses applied, see each loss' doc.
            log: forwarded to the HOI matcher.

        Returns a dict of losses. Auxiliary-layer entries are suffixed with ``_{i}``.
        """
        outputs_without_aux = {k: v for k, v in outputs.items()
                               if k not in ('aux_outputs', 'hoi_aux_outputs')}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)
        if self.HOI_losses is not None:
            input_targets = [copy.deepcopy(target) for target in targets]
            hoi_indices, hoi_targets = self.HOI_matcher(outputs_without_aux, input_targets, indices, log)

        # Compute the average number of target boxes across all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))

        # Auxiliary detection losses. Use a separate name for the per-layer matching so the
        # final-layer `indices` stays available for the HOI matcher below.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                aux_indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    # Logging is enabled only for the last layer
                    kwargs = {'log': False} if loss == 'labels' else {}
                    l_dict = self.get_loss(loss, aux_outputs, targets, aux_indices, num_boxes, **kwargs)
                    losses.update({f'{k}_{i}': v for k, v in l_dict.items()})

        # HOI detection losses
        if self.HOI_losses is not None:
            for loss in self.HOI_losses:
                losses.update(self.get_HOI_loss(loss, outputs, hoi_targets, hoi_indices,
                                                num_boxes, self.use_consis))

            if 'hoi_aux_outputs' in outputs:
                for i, aux_outputs in enumerate(outputs['hoi_aux_outputs']):
                    input_targets = [copy.deepcopy(target) for target in targets]
                    # Pair each aux layer's HOI matching with the targets from the same matcher call.
                    aux_hoi_indices, aux_hoi_targets = self.HOI_matcher(aux_outputs, input_targets, indices, log)
                    for loss in self.HOI_losses:
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False} if loss == 'pair_targets' else {}
                        l_dict = self.get_HOI_loss(loss, aux_outputs, aux_hoi_targets, aux_hoi_indices,
                                                   num_boxes, self.use_consis, **kwargs)
                        losses.update({f'{k}_{i}': v for k, v in l_dict.items()})
        return losses