import torch


def prepare_for_mask(kpt_mask, *, num_group=50, group_bbox_kpt=69, num_heads=8):
    """Build the boolean group/keypoint attention mask for the decoder queries.

    The queries are laid out as ``num_group`` consecutive groups of
    ``group_bbox_kpt`` entries each.  The returned mask is ``True`` wherever
    two queries belong to different groups, and, inside the diagonal blocks,
    wherever the corresponding ``kpt_mask`` entries differ.
    (Presumably ``True`` means "attention not allowed", following the usual
    PyTorch boolean ``attn_mask`` convention -- confirm against the caller.)

    Args:
        kpt_mask: 2-D tensor of shape ``(batch, length)``.  Its entries are
            only compared with each other for equality.  The diagonal blocks
            are written at stride ``length``, so they line up with the groups
            only when ``length == group_bbox_kpt``; a ``length`` larger than
            ``group_bbox_kpt`` makes the block assignment fail with a shape
            mismatch error.
        num_group: Number of query groups (default 50).
        group_bbox_kpt: Number of queries per group (default 69).
        num_heads: Number of attention heads the mask is replicated for
            (default 8).

    Returns:
        A 5-tuple ``(input_query_label, input_query_bbox, attn_mask,
        attn_mask2, dn_meta)`` in which every element except ``attn_mask2``
        is ``None``.  ``attn_mask2`` is a bool tensor of shape
        ``(batch * num_heads, T, T)`` with ``T = num_group * group_bbox_kpt``,
        allocated on the same device as ``kpt_mask``.
    """
    bs, length = kpt_mask.shape
    # Follow the input's device instead of hard-coding 'cuda', so the function
    # also works on CPU-only hosts and on non-default GPUs.
    device = kpt_mask.device

    # Group index of every query position: [0]*K + [1]*K + ...
    group_ids = torch.arange(num_group, device=device).repeat_interleave(group_bbox_kpt)
    # True wherever the row query and the column query are in different groups.
    cross_group = group_ids[:, None] != group_ids[None, :]
    # Materialize one writable copy per (batch, head).
    attn_mask2 = cross_group[None, None].repeat(bs, num_heads, 1, 1)

    # (batch, 1, length, length): True where two kpt_mask entries differ.
    # The singleton head axis broadcasts over all heads on assignment.
    differs = kpt_mask[:, None, :, None] != kpt_mask[:, None, None, :]
    for idx in range(num_group):
        start = idx * length
        end = start + length
        # Overwrite the diagonal block: equal entries -> False, others -> True.
        attn_mask2[:, :, start:end, start:end] = differs

    input_query_label = None
    input_query_bbox = None
    attn_mask = None
    dn_meta = None
    return input_query_label, input_query_bbox, attn_mask, attn_mask2.flatten(0, 1), dn_meta


def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
    """Split the leading ``pad_size`` queries off the per-layer outputs.

    When ``dn_meta`` is truthy and ``dn_meta['pad_size'] > 0``, the first
    ``pad_size`` entries along dimension 1 of every tensor are collected into
    ``dn_meta['output_known_lbs_bboxes']`` and the remaining entries are
    returned.  Otherwise the inputs are returned unchanged.

    Args:
        outputs_class: Iterable of tensors, sliced along dimension 1 (indexed
            with three indices, so presumably 3-D -- confirm with the caller).
        outputs_coord: Iterable of tensors, sliced the same way.
        dn_meta: ``None``/empty, or a dict with a ``'pad_size'`` entry.  It is
            updated in place with the ``'output_known_lbs_bboxes'`` key.
        aux_loss: If truthy, ``'aux_outputs'`` is added to the stored dict.
        _set_aux_loss: Callable invoked as
            ``_set_aux_loss(known_class_list, known_coord_list)``.

    Returns:
        ``(outputs_class, outputs_coord)``: the original objects when nothing
        was split off, otherwise lists holding the entries after ``pad_size``.
    """
    if not (dn_meta and dn_meta['pad_size'] > 0):
        return outputs_class, outputs_coord

    pad_size = dn_meta['pad_size']
    known_class = [layer_out[:, :pad_size, :] for layer_out in outputs_class]
    known_coord = [layer_out[:, :pad_size, :] for layer_out in outputs_coord]
    outputs_class = [layer_out[:, pad_size:, :] for layer_out in outputs_class]
    outputs_coord = [layer_out[:, pad_size:, :] for layer_out in outputs_coord]

    # The last list element supplies the main prediction entries.
    out = {'pred_logits': known_class[-1], 'pred_boxes': known_coord[-1]}
    if aux_loss:
        out['aux_outputs'] = _set_aux_loss(known_class, known_coord)
    dn_meta['output_known_lbs_bboxes'] = out
    return outputs_class, outputs_coord