# Source attribution (copied from the hosting page): uploaded by qqc1989,
# commit ed861ec ("Upload 114 files", verified).
import torch
def prepare_for_mask(kpt_mask, *, num_group=50, group_bbox_kpt=69, num_heads=8, device=None):
    """Build a group-wise boolean self-attention mask for the query sequence.

    The query sequence of length ``num_group * group_bbox_kpt`` is split into
    ``num_group`` consecutive groups of ``group_bbox_kpt`` queries. In the
    returned mask, ``True`` marks a blocked (query, key) pair and ``False`` an
    allowed one. This is presumably the ``torch.nn.MultiheadAttention`` boolean
    ``attn_mask`` convention; TODO confirm against the consumer.

    - Pairs in different groups are always blocked.
    - Inside each group, a pair is allowed only when the corresponding
      ``kpt_mask`` values are equal.

    Args:
        kpt_mask: 2-D tensor of shape ``(bs, length)``. Values are compared
            for equality between positions of the same row.
        num_group: Number of query groups. Defaults to 50, the previously
            hard-coded value.
        group_bbox_kpt: Queries per group. Defaults to 69, the previously
            hard-coded value.
        num_heads: Number of heads the mask is replicated for. Defaults to 8,
            the previously hard-coded value.
        device: Device for the returned mask. Defaults to ``kpt_mask.device``
            (previously always ``'cuda'``).

    Returns:
        A 5-tuple ``(input_query_label, input_query_bbox, attn_mask,
        attn_mask2, dn_meta)``. Every element except ``attn_mask2`` is
        ``None``. ``attn_mask2`` is a contiguous bool tensor of shape
        ``(bs * num_heads, num_group * group_bbox_kpt,
        num_group * group_bbox_kpt)``.

    Raises:
        ValueError: If ``num_group * length`` exceeds the total query count,
            so the per-group blocks would not fit inside the mask.
    """
    if device is None:
        device = kpt_mask.device
    bs, length = kpt_mask.shape
    total = num_group * group_bbox_kpt
    if num_group * length > total:
        raise ValueError(
            f"kpt_mask length {length} x {num_group} groups exceeds the "
            f"{total} queries covered by the mask"
        )

    # Base mask: block every pair whose queries belong to different groups.
    # Built as a bool tensor directly, instead of comparing a float32 ones
    # tensor, and vectorized instead of looping over every query row.
    group_id = torch.arange(total, device=device) // group_bbox_kpt
    cross_group = group_id[:, None] != group_id[None, :]
    # repeat (not expand) so each batch entry owns writable memory.
    mask = cross_group.unsqueeze(0).repeat(bs, 1, 1)

    # Inside a group block, allow only pairs with equal kpt_mask values.
    blocked = (kpt_mask[:, :, None] != kpt_mask[:, None, :]).to(device)
    # NOTE(review): blocks are laid out with stride `length` while groups use
    # stride `group_bbox_kpt`. They only line up when
    # length == group_bbox_kpt; confirm callers always satisfy this.
    for idx in range(num_group):
        start, end = idx * length, (idx + 1) * length
        mask[:, start:end, start:end] = blocked

    # The mask is identical for every head: broadcast it, then materialize
    # the flattened (bs * num_heads, total, total) layout.
    attn_mask2 = (
        mask.unsqueeze(1)
        .expand(bs, num_heads, total, total)
        .reshape(bs * num_heads, total, total)
    )

    input_query_label = None
    input_query_bbox = None
    attn_mask = None
    dn_meta = None
    return input_query_label, input_query_bbox, attn_mask, attn_mask2, dn_meta
def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
    """Split padded leading queries off per-layer outputs and record them in ``dn_meta``.

    When ``dn_meta`` is truthy and ``dn_meta['pad_size'] > 0``:

    - The first ``pad_size`` entries along dimension 1 of every per-layer
      class/coord output are split off. These are presumably the denoising
      queries; confirm with the caller.
    - The last layer of that split-off part is stored in
      ``dn_meta['output_known_lbs_bboxes']`` as a dict with keys
      ``'pred_logits'`` and ``'pred_boxes'``.
    - If ``aux_loss`` is truthy, the dict also gets ``'aux_outputs'``, set to
      ``_set_aux_loss(known_class, known_coord)``.

    Args:
        outputs_class: Iterable of per-layer class outputs, each sliceable
            as ``t[:, a:b, :]``.
        outputs_coord: Iterable of per-layer coordinate outputs, same
            convention as ``outputs_class``.
        dn_meta: Dict with a ``'pad_size'`` entry, or a falsy value. Mutated
            in place when the split happens.
        aux_loss: Whether to also build auxiliary outputs.
        _set_aux_loss: Callable ``(class_list, coord_list)`` that builds the
            auxiliary-output value.

    Returns:
        ``(outputs_class, outputs_coord)``. This is the untouched inputs when
        no split happens; otherwise, lists holding the remainder after the
        first ``pad_size`` entries of dimension 1.
    """
    if not dn_meta or not dn_meta['pad_size'] > 0:
        return outputs_class, outputs_coord

    pad = dn_meta['pad_size']
    known_part = slice(None, pad)
    matching_part = slice(pad, None)

    def take(layers, part):
        # Slice dimension 1 of every per-layer output with the given slice.
        return [layer[:, part, :] for layer in layers]

    known_class = take(outputs_class, known_part)
    known_coord = take(outputs_coord, known_part)
    matched_class = take(outputs_class, matching_part)
    matched_coord = take(outputs_coord, matching_part)

    known_out = dict(pred_logits=known_class[-1], pred_boxes=known_coord[-1])
    if aux_loss:
        known_out['aux_outputs'] = _set_aux_loss(known_class, known_coord)
    dn_meta['output_known_lbs_bboxes'] = known_out
    return matched_class, matched_coord