import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

def generate_position_tensor(h, w, box_xy, raw_h, raw_w):
    """Build an (h, w, 2) grid of (x, y) coordinates for an h x w feature map
    cropped by box_xy = (x1, y1, x2, y2), normalized by the raw image size."""
    box_x = box_xy[2] - box_xy[0]  # box width
    box_y = box_xy[3] - box_xy[1]  # box height
    # x coordinates: [0, 1) across the feature width, mapped into the box,
    # then normalized by the raw image width
    x_tensor = torch.arange(w, dtype=torch.float32) / w
    x_tensor = x_tensor.repeat(h, 1)
    x_tensor = (x_tensor * box_x + box_xy[0]) / (raw_w - 1)

    # y coordinates: same construction along the feature height
    y_tensor = torch.arange(h, dtype=torch.float32) / h
    y_tensor = y_tensor.unsqueeze(1).repeat(1, w)
    y_tensor = (y_tensor * box_y + box_xy[1]) / (raw_h - 1)

    # stack into an (h, w, 2) position tensor
    tensor_3d = torch.stack([x_tensor, y_tensor], dim=2)
    return tensor_3d
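
# Illustrative example (made-up sizes, not from this file): for a 24x24 feature
# grid cropped by box (10, 20, 200, 180) from a 336x336 image,
#   generate_position_tensor(24, 24, (10, 20, 200, 180), 336, 336)
# returns a tensor of shape [24, 24, 2] holding normalized (x, y) coordinates.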

class MaskExtractor(nn.Module):
    def __init__(self, config, mm_hidden_size, depth=2):
        super(MaskExtractor, self).__init__()
        self.mask_pooling = MaskPooling()
        # MLP that projects pooled mask features (mm_hidden_size) into the
        # language-model hidden size
        modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
        for _ in range(1, depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        self.feat_linear = nn.Sequential(*modules)
        # embeds the 2-D (x, y) position grid into the visual feature dimension
        self.pos_linear = nn.Linear(2, mm_hidden_size)

    def forward(self, feats, masks, box_params, mask_num):
        query_feats = []

        if masks is None:  # inference without region masks
            return None
            # masks = torch.zeros((1, 1, 336, 336)).to(feats.device).float()

        num_imgs = len(masks)
        image_idx = 0
        for idx in range(num_imgs):
            if masks[idx] is None:
                continue
            for mask_idx in range(len(masks[idx])):
                # [h, w] -> [1, 1, h, w]
                mask = masks[idx][mask_idx].unsqueeze(0).unsqueeze(0).float()
                box_param = box_params[idx][mask_idx]
                box_xy, raw_h, raw_w = box_param
                if len(mask[0]) == 0:
                    print('mask error')
                    mask = torch.zeros((1, 1, 336, 336)).to(feats.device).float()

                # add a position embedding derived from the crop box to the
                # [1, h, w, c] feature map
                feat = feats[image_idx].unsqueeze(0)
                pos_emb = generate_position_tensor(feat.shape[1], feat.shape[2], box_xy, raw_h, raw_w).unsqueeze(0).to(feat)
                pos_emb = self.pos_linear(pos_emb)
                feat = feat + pos_emb

                image_idx += 1

                # [1, h, w, c] -> [1, c, h, w]
                feat = feat.permute(0, 3, 1, 2)

                feat = feat.to(mask.dtype)

                # pool the masked region into at most mask_num tokens
                mask_feat_raw = self.mask_pooling(feat, mask, mask_token_num=mask_num)  # [n, mm_hidden_size]

                query_feats.append(mask_feat_raw)
        if len(query_feats) == 0:
            return None
        mask_feats = torch.cat(query_feats, dim=0)
        mask_feats = mask_feats.to(feats[0].dtype)
        # project pooled region tokens into the language-model hidden size
        mask_feats_linear = self.feat_linear(mask_feats)
        return mask_feats_linear

def kmeans_fast(tokens, num_clusters=10, num_iterations=5):
    """Cluster [n, d] tokens into num_clusters centroids with a few rounds of
    Lloyd's k-means; returns the [num_clusters, d] centroids."""
    n, d = tokens.shape
    # initialize centroids from a random subset of the tokens
    centroids = tokens[torch.randperm(n)[:num_clusters]]

    for _ in range(num_iterations):
        tokens_expand = tokens.unsqueeze(1)        # [n, 1, d]
        centroids_expand = centroids.unsqueeze(0)  # [1, num_clusters, d]

        # squared Euclidean distance of every token to every centroid
        distances = torch.sum((tokens_expand - centroids_expand) ** 2, dim=2)  # [n, num_clusters]

        # assign each token to its nearest centroid
        labels = torch.argmin(distances, dim=1)  # [n]

        # recompute centroids; keep the old one for any empty cluster
        new_centroids = torch.stack([
            tokens[labels == i].mean(dim=0) if tokens[labels == i].size(0) > 0 else centroids[i]
            for i in range(num_clusters)
        ])

        if torch.allclose(centroids, new_centroids, atol=1e-6):
            break

        centroids = new_centroids

    return centroids
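
# Illustrative use (made-up sizes): compress 500 visual tokens of width 256
# into 10 representative tokens, e.g.
#   tokens = torch.randn(500, 256)
#   centroids = kmeans_fast(tokens, num_clusters=10)  # -> [10, 256]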

class MaskPooling(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, mask, mask_token_num=10):

        if not x.shape[-2:] == mask.shape[-2:]:
            # resize the feature map to the mask resolution
            x = F.interpolate(x, size=mask.shape[-2:], mode='bilinear', align_corners=False)
            # mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
        if not x.device == mask.device:
            mask = mask.to(x.device)
        # x: [b, c, h, w], mask: [b, q, h, w]
        mask = (mask > 0).to(mask.dtype)
        mask = mask.permute(1, 0, 2, 3)
        denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8  # normalization term (not used below)

        # keep only the spatial positions covered by the mask
        mask_emb = x * mask
        mask = torch.any(mask_emb != 0, dim=(0, 1))
        mask_emb = mask_emb[:, :, mask]
        mask_embedding = mask_emb[0].permute(1, 0)  # [num_positions, c]

        # compress the kept positions into at most mask_token_num tokens
        if len(mask_embedding) > mask_token_num:  # FIXME
            mask_embedding = kmeans_fast(mask_embedding, mask_token_num)
        return mask_embedding
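
# Illustrative use (made-up sizes): with features of shape [1, 256, 16, 16] and
# a binary mask of shape [1, 1, 64, 64], the features are upsampled to 64x64,
# the masked positions are gathered, and at most mask_token_num tokens come
# back, e.g.
#   pool = MaskPooling()
#   out = pool(torch.randn(1, 256, 16, 16), torch.ones(1, 1, 64, 64), mask_token_num=10)
#   out.shape  # torch.Size([10, 256])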


def build_region_encoder(config, mm_hidden_size):

    return MaskExtractor(config, mm_hidden_size)
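

# Minimal smoke test with made-up shapes and a stand-in config (only the
# hidden_size attribute read by MaskExtractor is provided; a real config
# carries more fields). Run this file directly to check end-to-end shapes.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=512)
    mm_hidden_size = 256
    encoder = build_region_encoder(config, mm_hidden_size)

    feats = torch.randn(1, 16, 16, mm_hidden_size)  # one [h, w, c] feature grid
    masks = [torch.ones(1, 64, 64)]                 # one binary region mask for that image
    box_params = [[((0, 0, 64, 64), 64, 64)]]       # (box_xy, raw_h, raw_w) per mask

    out = encoder(feats, masks, box_params, mask_num=10)
    print(out.shape)  # expected: torch.Size([10, 512])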