File size: 2,001 Bytes
ed861ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch


def prepare_for_mask(kpt_mask, num_group=50, group_bbox_kpt=69, num_heads=8, device=None):
    """Build a group-block-diagonal boolean self-attention mask.

    The target sequence has ``num_group * group_bbox_kpt`` queries, split into
    ``num_group`` contiguous groups of ``group_bbox_kpt`` queries.

    - An entry ``(i, j)`` of the mask is ``True`` when queries ``i`` and ``j``
      belong to different groups.
    - Each of the ``num_group`` diagonal tiles of size
      ``length = kpt_mask.shape[1]`` (starting at ``g * length``) is then
      overwritten. There, an entry is ``True`` exactly when the corresponding
      ``kpt_mask`` values differ.
    - With ``length == group_bbox_kpt`` (the layout the defaults were written
      for) those tiles coincide with the groups.

    NOTE(review): ``True`` presumably means "may not attend" (PyTorch's
    boolean ``attn_mask`` convention). Confirm against the attention module
    that consumes the mask.

    Args:
        kpt_mask: tensor of shape ``(bs, length)``. Its entries are only
            compared for equality.
        num_group: number of query groups. Defaults to the previously
            hard-coded value 50.
        group_bbox_kpt: queries per group. Defaults to the previously
            hard-coded value 69.
        num_heads: number of attention heads the mask is replicated for.
            Defaults to the previously hard-coded value 8.
        device: device of the returned mask. Defaults to ``kpt_mask.device``;
            the previous implementation always used ``'cuda'``.

    Returns:
        Tuple ``(input_query_label, input_query_bbox, attn_mask, attn_mask2,
        dn_meta)``.

        - Every element except ``attn_mask2`` is ``None``.
        - ``attn_mask2`` is a bool tensor of shape ``(bs * num_heads, T, T)``
          with ``T = num_group * group_bbox_kpt``.
        - Rows are batch-major: row ``b * num_heads + h`` is sample ``b``.

    Raises:
        ValueError: if ``num_group * length > T``, i.e. the per-tile
            ``kpt_mask`` blocks do not fit. The previous implementation failed
            with an indexing shape-mismatch error in this case.
    """
    if device is None:
        device = kpt_mask.device

    bs, length = kpt_mask.shape
    tgt_size = num_group * group_bbox_kpt
    if num_group * length > tgt_size:
        raise ValueError(
            f"kpt_mask length {length} is larger than group size {group_bbox_kpt}: "
            f"{num_group} tiles of {length} do not fit in {tgt_size} queries"
        )

    # Cross-group blocking in one vectorised comparison. This replaces the
    # former per-row Python loop over all tgt_size rows.
    group_id = torch.arange(tgt_size, device=device) // group_bbox_kpt
    cross_group = group_id[:, None] != group_id[None, :]
    # repeat (not expand) so each sample owns storage for the tile writes below.
    mask = cross_group.unsqueeze(0).repeat(bs, 1, 1)  # (bs, T, T)

    # Inside each diagonal tile, block pairs whose kpt_mask values differ.
    # Everything else in the tile is unblocked.
    differ = (kpt_mask[:, :, None] != kpt_mask[:, None, :]).to(device)  # (bs, L, L)
    for g in range(num_group):
        start = g * length
        end = start + length
        mask[:, start:end, start:end] = differ

    # Replicate per head. Batch-major flattening matches the former
    # (bs, num_heads, T, T).flatten(0, 1) layout.
    attn_mask2 = (
        mask.unsqueeze(1)
        .expand(bs, num_heads, tgt_size, tgt_size)
        .reshape(bs * num_heads, tgt_size, tgt_size)
    )

    # The remaining outputs are placeholders. NOTE(review): presumably
    # denoising queries are disabled here; confirm with the caller.
    input_query_label = None
    input_query_bbox = None
    attn_mask = None
    dn_meta = None

    return input_query_label, input_query_bbox, attn_mask, attn_mask2, dn_meta


def post_process(outputs_class, outputs_coord, dn_meta, aux_loss, _set_aux_loss):
    """Split the leading ``pad_size`` queries off every per-layer prediction.

    When ``dn_meta`` is truthy and ``dn_meta['pad_size'] > 0``:

    - Each element of ``outputs_class`` / ``outputs_coord`` is sliced along
      dim 1.
    - The first ``pad_size`` queries (presumably the denoising queries, given
      the ``dn_meta`` name — confirm) are collected into a dict.
    - That dict is stored in ``dn_meta['output_known_lbs_bboxes']``, so
      ``dn_meta`` is mutated in place.
    - The remaining queries are returned.

    Otherwise the inputs are returned unchanged.

    Args:
        outputs_class: iterable of per-layer class predictions, each indexable
            as ``t[:, queries, :]``.
        outputs_coord: iterable of per-layer box predictions, same indexing.
        dn_meta: dict-like holding ``'pad_size'``, or a falsy value.
        aux_loss: when truthy, also store ``_set_aux_loss(...)`` applied to the
            split-off per-layer lists under ``'aux_outputs'``.
        _set_aux_loss: callable ``(class_list, coord_list) -> Any``.

    Returns:
        ``(outputs_class, outputs_coord)``. These are lists of the
        remaining-query slices when a split happened, otherwise the original
        arguments.
    """
    if not dn_meta:
        return outputs_class, outputs_coord
    pad = dn_meta['pad_size']
    if not (pad > 0):
        return outputs_class, outputs_coord

    head = (slice(None), slice(None, pad), slice(None))  # == [:, :pad, :]
    tail = (slice(None), slice(pad, None), slice(None))  # == [:, pad:, :]

    def _take(layers, window):
        # One list entry per layer, each cut to the requested query window.
        return [layer[window] for layer in layers]

    known_class = _take(outputs_class, head)
    known_coord = _take(outputs_coord, head)
    remaining_class = _take(outputs_class, tail)
    remaining_coord = _take(outputs_coord, tail)

    # The last layer's split-off predictions form the top-level entries.
    known = {'pred_logits': known_class[-1], 'pred_boxes': known_coord[-1]}
    if aux_loss:
        known['aux_outputs'] = _set_aux_loss(known_class, known_coord)
    dn_meta['output_known_lbs_bboxes'] = known
    return remaining_class, remaining_coord