Spaces:
Runtime error
Runtime error
# ------------------------------------------------------------------------
# HOTR official code : hotr/models/hotr.py
# Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import copy | |
| import time | |
| import datetime | |
| from hotr.util.misc import NestedTensor, nested_tensor_from_tensor_list | |
| from .feed_forward import MLP | |
class HOTR(nn.Module):
    """DETR-based human-object interaction (HOI) detector with augmented paths.

    A DETR model produces instance (object) representations, and an
    interaction transformer produces interaction representations.  From the
    interaction representations the model predicts, per interaction query,
    a pointer to the human instance, a pointer to the object instance and
    action logits.  Pointers are scored by the dot product between the
    L2-normalised pointer embedding and the L2-normalised DETR instance
    representations, divided by ``temperature``.

    Besides the main path, up to three augmented two-stage decoding paths
    can be enabled through ``augpath_name``:

    * ``'p2'`` (x->HO->I): pointers from stage 1, actions from stage 2.
    * ``'p3'`` (x->HI->O): human pointer and actions from stage 1,
      object pointer from stage 2.
    * ``'p4'`` (x->OI->H): object pointer and actions from stage 1,
      human pointer from stage 2.

    Args:
        detr: DETR model; ``backbone``, ``input_proj``, ``transformer``,
            ``query_embed``, ``class_embed`` and ``bbox_embed`` are used.
        num_hoi_queries: Number of interaction queries.
        num_actions: Number of action classes; the action heads output
            ``num_actions + 1`` logits.
        interaction_transformer: Transformer producing the interaction
            representations of the main path.  Its ``decoder`` is the
            template for the augmented-path decoders.
        augpath_name: Collection of enabled augmented path names
            (``'p2'``, ``'p3'``, ``'p4'``); may be empty.
        share_dec_param: If true, the augmented-path decoders are the very
            same module as the interaction decoder; otherwise deep copies.
        stop_grad_stage: If true, the stage-1 output is detached before it
            is fed to stage 2 of an augmented path.
        freeze_detr: If true, all DETR parameters get
            ``requires_grad=False``.
        share_enc: If true, the interaction transformer reuses DETR's
            encoder module.
        pretrained_dec: If true, the interaction decoder is replaced by a
            trainable deep copy of DETR's decoder.
        temperature: Divisor applied to the pointer similarity scores.
        hoi_aux_loss: If true, ``forward`` also returns per-layer auxiliary
            outputs under ``'hoi_aux_outputs'``.
        return_obj_class: Optional list of valid object class ids.  When
            given (and non-empty), object-class logits are gathered for
            every interaction query; one extra id, ``last id + 1``, is
            appended (presumably the no-object class -- confirm against
            the DETR head).
    """

    def __init__(self, detr,
                 num_hoi_queries,
                 num_actions,
                 interaction_transformer,
                 augpath_name,
                 share_dec_param,
                 stop_grad_stage,
                 freeze_detr,
                 share_enc,
                 pretrained_dec,
                 temperature,
                 hoi_aux_loss,
                 return_obj_class=None):
        super().__init__()

        # * Instance Transformer ---------------
        self.detr = detr
        if freeze_detr:
            # if this flag is given, freeze the object detection related parameters of DETR
            for p in self.detr.parameters():
                p.requires_grad_(False)
        hidden_dim = detr.transformer.d_model
        # --------------------------------------

        # * Interaction Transformer -----------------------------------------
        self.num_queries = num_hoi_queries
        self.query_embed = nn.Embedding(self.num_queries, hidden_dim)
        self.H_Pointer_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
        self.O_Pointer_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
        self.action_embed = nn.Linear(hidden_dim, num_actions + 1)
        # --------------------------------------------------------------------

        # * HICO-DET FFN heads ---------------------------------------------
        # The flag and `_valid_obj_ids` are driven by the same condition so
        # that `forward` never sees the flag set without the id list
        # (an empty list disables object-class output).
        self.return_obj_class = bool(return_obj_class)
        if self.return_obj_class:
            self._valid_obj_ids = return_obj_class + [return_obj_class[-1] + 1]
        # ------------------------------------------------------------------

        # * Transformer Options ---------------------------------------------
        self.interaction_transformer = interaction_transformer

        if share_enc:  # share encoder
            self.interaction_transformer.encoder = detr.transformer.encoder

        if pretrained_dec:  # free variables for interaction decoder
            self.interaction_transformer.decoder = copy.deepcopy(detr.transformer.decoder)
            for p in self.interaction_transformer.decoder.parameters():
                p.requires_grad_(True)
        # ---------------------------------------------------------------------

        # * Augmented paths --------------------------------------------------
        self.aug_paths = augpath_name
        # Read the template decoder only now: `pretrained_dec` may have
        # replaced it just above.
        base_decoder = self.interaction_transformer.decoder

        def stage_decoder():
            # Either the shared decoder itself or an independent copy of it.
            return base_decoder if share_dec_param else copy.deepcopy(base_decoder)

        if 'p2' in augpath_name:
            self.xtoHO_interaction_decoder = stage_decoder()
            self.HOtoI_interaction_decoder = stage_decoder()
            self.query_embed_HOtoI = nn.Embedding(self.num_queries, hidden_dim)
            self.query_embed_HOtoI2 = nn.Embedding(self.num_queries, hidden_dim)
            self.H_Pointer_embed_HOtoI = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.O_Pointer_embed_HOtoI = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.action_embed_HOtoI = nn.Linear(hidden_dim, num_actions + 1)

        if 'p3' in augpath_name:
            self.xtoHI_interaction_decoder = stage_decoder()
            self.HItoO_interaction_decoder = stage_decoder()
            self.query_embed_HItoO = nn.Embedding(self.num_queries, hidden_dim)
            self.query_embed_HItoO2 = nn.Embedding(self.num_queries, hidden_dim)
            self.H_Pointer_embed_HItoO = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.O_Pointer_embed_HItoO = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.action_embed_HItoO = nn.Linear(hidden_dim, num_actions + 1)

        if 'p4' in augpath_name:
            self.xtoOI_interaction_decoder = stage_decoder()
            self.OItoH_interaction_decoder = stage_decoder()
            self.query_embed_OItoH = nn.Embedding(self.num_queries, hidden_dim)
            self.query_embed_OItoH2 = nn.Embedding(self.num_queries, hidden_dim)
            self.H_Pointer_embed_OItoH = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.O_Pointer_embed_OItoH = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
            self.action_embed_OItoH = nn.Linear(hidden_dim, num_actions + 1)

        self.stop_grad_stage = stop_grad_stage
        # ---------------------------------------------------------------------

        # * Loss Options -------------------
        self.tau = temperature
        self.hoi_aux_loss = hoi_aux_loss
        # ----------------------------------

    def _two_stage_decode(self, first_decoder, second_decoder,
                          first_query_embed, second_query_embed,
                          memory, mask_aug, pos_aug, bs):
        """Run one augmented path: decode from zeros, then decode again.

        Stage 1 starts from an all-zero target; stage 2 starts from the last
        layer of the stage-1 output (detached when ``self.stop_grad_stage``
        is set).  Each stage uses its own query embedding as ``query_pos``.

        Returns:
            ``(first_hs, second_hs)``: the outputs of both stages with
            dimensions 1 and 2 swapped (assumed to become
            ``(dec_layers, bs, queries, hidden)`` for a DETR-style decoder
            that returns all intermediate layers -- TODO confirm).
        """
        first_query_pos = first_query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        second_query_pos = second_query_embed.weight.unsqueeze(1).repeat(1, bs, 1)

        # Every path builds its own zero target here, so no path can pick up
        # another path's tensor by mistake.
        first_hs = first_decoder(torch.zeros_like(first_query_pos), memory,
                                 memory_key_padding_mask=mask_aug,
                                 pos=pos_aug,
                                 query_pos=first_query_pos).transpose(1, 2)

        stage_input = first_hs.clone().detach() if self.stop_grad_stage else first_hs
        second_tgt = stage_input.transpose(1, 2)[-1]  # last layer, decoder layout
        second_hs = second_decoder(second_tgt, memory,
                                   memory_key_padding_mask=mask_aug,
                                   pos=pos_aug,
                                   query_pos=second_query_pos).transpose(1, 2)
        return first_hs, second_hs

    def forward(self, samples: "NestedTensor"):
        """Run object detection and HOI detection on a batch.

        Args:
            samples: A ``NestedTensor``; a list of tensors or a single
                tensor is converted with ``nested_tensor_from_tensor_list``.

        Returns:
            A dict with:

            * ``'pred_logits'`` / ``'pred_boxes'``: last-layer DETR class
              logits and sigmoid box outputs.
            * ``'pred_hidx'`` / ``'pred_oidx'``: last-layer human / object
              pointer scores; the main path and all enabled augmented paths
              are concatenated along the batch dimension.
            * ``'pred_actions'``: last-layer action logits with the paths
              stacked on a separate dimension.
            * ``'hoi_recognition_time'``: wall-clock time of the HOI stage
              minus that of the object-detection stage, clamped at 0.
            * ``'pred_obj_logits'``: only when object classes are requested.
            * ``'hoi_aux_outputs'``: only when ``hoi_aux_loss`` is set;
              one dict per non-final decoder layer.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)

        # >>>>>>>>>>>>  BACKBONE LAYERS  <<<<<<<<<<<<<<<
        features, pos = self.detr.backbone(samples)
        bs = features[-1].tensors.shape[0]
        src, mask = features[-1].decompose()
        assert mask is not None
        # ----------------------------------------------

        # >>>>>>>>>>>> OBJECT DETECTION LAYERS <<<<<<<<<<
        start_time = time.time()
        hs, memory = self.detr.transformer(self.detr.input_proj(src), mask,
                                           self.detr.query_embed.weight, pos[-1])
        inst_repr = F.normalize(hs[-1], p=2, dim=2)  # instance representations

        # Prediction Heads for Object Detection
        outputs_class = self.detr.class_embed(hs)
        outputs_coord = self.detr.bbox_embed(hs).sigmoid()
        object_detection_time = time.time() - start_time
        # -----------------------------------------------

        # >>>>>>>>>>>> HOI DETECTION LAYERS <<<<<<<<<<<<<<<
        start_time = time.time()
        assert hasattr(self, 'interaction_transformer'), "Missing Interaction Transformer."
        H_Pointer_reprs_bag, O_Pointer_reprs_bag, outputs_action = [], [], []

        # main path P1
        interaction_hs = self.interaction_transformer(
            self.detr.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]  # interaction representations
        H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed(interaction_hs), p=2, dim=-1))
        O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed(interaction_hs), p=2, dim=-1))
        outputs_action.append(self.action_embed(interaction_hs))

        if self.aug_paths:
            pos_aug = pos[-1].flatten(2).permute(2, 0, 1)
            mask_aug = mask.flatten(1)

            # P2 (x->HO->I): both pointers from stage 1, actions from stage 2
            if 'p2' in self.aug_paths:
                hs_HOtoI, hs2_HOtoI = self._two_stage_decode(
                    self.xtoHO_interaction_decoder, self.HOtoI_interaction_decoder,
                    self.query_embed_HOtoI, self.query_embed_HOtoI2,
                    memory, mask_aug, pos_aug, bs)
                H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed_HOtoI(hs_HOtoI), p=2, dim=-1))
                O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed_HOtoI(hs_HOtoI), p=2, dim=-1))
                outputs_action.append(self.action_embed_HOtoI(hs2_HOtoI))

            # P3 (x->HI->O): human pointer and actions from stage 1, object pointer from stage 2
            if 'p3' in self.aug_paths:
                hs_HItoO, hs2_HItoO = self._two_stage_decode(
                    self.xtoHI_interaction_decoder, self.HItoO_interaction_decoder,
                    self.query_embed_HItoO, self.query_embed_HItoO2,
                    memory, mask_aug, pos_aug, bs)
                H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed_HItoO(hs_HItoO), p=2, dim=-1))
                O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed_HItoO(hs2_HItoO), p=2, dim=-1))
                outputs_action.append(self.action_embed_HItoO(hs_HItoO))

            # P4 (x->OI->H): object pointer and actions from stage 1, human pointer from stage 2
            if 'p4' in self.aug_paths:
                hs_OItoH, hs2_OItoH = self._two_stage_decode(
                    self.xtoOI_interaction_decoder, self.OItoH_interaction_decoder,
                    self.query_embed_OItoH, self.query_embed_OItoH2,
                    memory, mask_aug, pos_aug, bs)
                H_Pointer_reprs_bag.append(F.normalize(self.H_Pointer_embed_OItoH(hs2_OItoH), p=2, dim=-1))
                O_Pointer_reprs_bag.append(F.normalize(self.O_Pointer_embed_OItoH(hs_OItoH), p=2, dim=-1))
                outputs_action.append(self.action_embed_OItoH(hs_OItoH))

        # Number of paths that actually ran (main path + enabled augmented
        # paths); used instead of `1 + len(self.aug_paths)` so the tiling
        # below always matches the collected outputs.
        n_paths = len(H_Pointer_reprs_bag)

        # Tile the instance representations once per path along the batch
        # dimension to line up with the concatenated pointer representations.
        inst_repr_all = inst_repr.transpose(1, 2).repeat(n_paths, 1, 1)
        H_Pointer_reprs_bag = torch.cat(H_Pointer_reprs_bag, 1)
        O_Pointer_reprs_bag = torch.cat(O_Pointer_reprs_bag, 1)

        # One score tensor per decoder layer (iteration runs over dim 0).
        outputs_hidx = [torch.bmm(H_Pointer_repr, inst_repr_all) / self.tau
                        for H_Pointer_repr in H_Pointer_reprs_bag]
        outputs_oidx = [torch.bmm(O_Pointer_repr, inst_repr_all) / self.tau
                        for O_Pointer_repr in O_Pointer_reprs_bag]
        outputs_action = torch.stack(outputs_action, dim=2)  # paths stacked on dim 2
        # --------------------------------------------------
        hoi_detection_time = time.time() - start_time
        hoi_recognition_time = max(hoi_detection_time - object_detection_time, 0)
        # -------------------------------------------------------------------

        # [Target Classification]
        if self.return_obj_class:
            detr_logits = outputs_class[-1, ..., self._valid_obj_ids]
            outputs_obj_class = []
            for output_oidx in outputs_oidx:
                # Index of the best-scoring instance per interaction query,
                # regrouped from (paths * bs, queries) to (bs, paths, queries).
                o_indice = output_oidx.max(-1)[-1].view(n_paths, bs, self.num_queries).transpose(0, 1)
                outputs_obj_class.append(torch.stack(
                    [detr_logits[batch_, o_idx, :]
                     for batch_, o_idc in enumerate(o_indice)
                     for o_idx in o_idc], 0))

        out = {
            "pred_logits": outputs_class[-1],
            "pred_boxes": outputs_coord[-1],
            "pred_hidx": outputs_hidx[-1],
            "pred_oidx": outputs_oidx[-1],
            "pred_actions": outputs_action[-1],
            "hoi_recognition_time": hoi_recognition_time,
        }
        if self.return_obj_class:
            out["pred_obj_logits"] = outputs_obj_class[-1]

        if self.hoi_aux_loss:  # auxiliary loss
            if self.return_obj_class:
                out['hoi_aux_outputs'] = self._set_aux_loss_with_tgt(
                    outputs_class, outputs_coord, outputs_hidx, outputs_oidx,
                    outputs_action, outputs_obj_class)
            else:
                out['hoi_aux_outputs'] = self._set_aux_loss(
                    outputs_class, outputs_coord, outputs_hidx, outputs_oidx,
                    outputs_action)

        return out

    def _set_aux_loss(self, outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action):
        """Build the auxiliary outputs for every decoder layer but the last.

        The last-layer detection outputs (``outputs_class[-1]``,
        ``outputs_coord[-1]``) are reused for every entry; the pointer scores
        and action logits come from the respective intermediate layer.

        Returns:
            A list of dicts, one per non-final layer, with keys
            ``pred_logits``, ``pred_boxes``, ``pred_hidx``, ``pred_oidx``
            and ``pred_actions``.
        """
        num_layers = outputs_action.shape[0]
        last_class = outputs_class[-1:].repeat((num_layers, 1, 1, 1))
        last_coord = outputs_coord[-1:].repeat((num_layers, 1, 1, 1))
        # zip() stops at the shortest input, i.e. after the non-final layers.
        return [{'pred_logits': a, 'pred_boxes': b, 'pred_hidx': c, 'pred_oidx': d, 'pred_actions': e}
                for a, b, c, d, e in zip(
                    last_class,
                    last_coord,
                    outputs_hidx[:-1],
                    outputs_oidx[:-1],
                    outputs_action[:-1])]

    def _set_aux_loss_with_tgt(self, outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action, outputs_tgt):
        """Like ``_set_aux_loss`` but also attaches per-layer object logits.

        Returns:
            A list of dicts, one per non-final layer, with the keys of
            ``_set_aux_loss`` plus ``pred_obj_logits`` taken from
            ``outputs_tgt``.
        """
        num_layers = outputs_action.shape[0]
        last_class = outputs_class[-1:].repeat((num_layers, 1, 1, 1))
        last_coord = outputs_coord[-1:].repeat((num_layers, 1, 1, 1))
        # zip() stops at the shortest input, i.e. after the non-final layers.
        return [{'pred_logits': a, 'pred_boxes': b, 'pred_hidx': c, 'pred_oidx': d, 'pred_actions': e, 'pred_obj_logits': f}
                for a, b, c, d, e, f in zip(
                    last_class,
                    last_coord,
                    outputs_hidx[:-1],
                    outputs_oidx[:-1],
                    outputs_action[:-1],
                    outputs_tgt[:-1])]