from typing import Optional

import torch
import torch.nn as nn

from .sam2_predictor import SAM2VideoPredictor
from .sam2_implementation.modeling.backbones.hieradet import Hiera
from .sam2_implementation.modeling.backbones.image_encoder import FpnNeck, ImageEncoder
from .sam2_implementation.modeling.position_encoding import PositionEmbeddingSine
from .sam2_implementation.modeling.memory_encoder import (
    CXBlock,
    Fuser,
    MaskDownSampler,
    MemoryEncoder,
)
from .sam2_implementation.modeling.memory_attention import MemoryAttention, MemoryAttentionLayer
from .sam2_implementation.modeling.sam.transformer import RoPEAttention

def load_checkpoint_with_prefix(filename, prefix=None, map_location='cpu', logger='current'):
    """Load a (possibly partial) pretrained checkpoint, optionally stripping a prefix.

    Args:
        filename (str): Accepts a local filepath, URL, ``torchvision://xxx``,
            or ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
            details.
        prefix (str, optional): The prefix of the sub-module to extract. If
            None or empty, the full state dict is returned.
        map_location (str | None): Same as :func:`torch.load`.
            Defaults to 'cpu'.
        logger: logger (unused here).

    Returns:
        dict or OrderedDict: The loaded state dict.
    """
    checkpoint = torch.load(filename, map_location=map_location)

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint
    if not prefix:
        return state_dict
    if not prefix.endswith('.'):
        prefix += '.'
    prefix_len = len(prefix)

    state_dict = {
        k[prefix_len:]: v
        for k, v in state_dict.items() if k.startswith(prefix)
    }

    assert state_dict, f'{prefix} is not in the pretrained model'
    return state_dict
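
# Example (a minimal sketch; the checkpoint path and prefix below are
# hypothetical): pull only the image-encoder weights out of a full SAM2
# checkpoint so they can be loaded into a standalone encoder.
#
#   enc_sd = load_checkpoint_with_prefix('checkpoints/sam2.pt', prefix='image_encoder')
#   # keys such as 'image_encoder.trunk.blocks.0...' come back as 'trunk.blocks.0...'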

def load_state_dict_to_model(model, state_dict, logger='current'):
    # strict=False so that missing/unexpected keys are returned rather than
    # raising inside load_state_dict (with the default strict=True the two
    # checks below would be unreachable); we report them explicitly instead.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys:
        print(missing_keys)
        raise RuntimeError(f'Missing keys when loading state dict: {missing_keys}')
    if unexpected_keys:
        print(unexpected_keys)
        raise RuntimeError(f'Unexpected keys when loading state dict: {unexpected_keys}')
    print("Loaded checkpoint successfully")

class SAM2(nn.Module):
    def __init__(
            self,
            ckpt_path: Optional[str] = None,
    ):
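        """Build the SAM2 video predictor (Hiera-L configuration) and, if
        ``ckpt_path`` is given, load its weights."""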
        super().__init__()

        image_encoder = self.build_image_encoder()
        memory_attention = self.build_memory_attention()
        memory_encoder = self.build_memory_encoder()
        sam2_model = SAM2VideoPredictor(
            image_encoder=image_encoder,
            memory_attention=memory_attention,
            memory_encoder=memory_encoder,
            num_maskmem=7,
            image_size=1024,
            # apply a scaled sigmoid on mask logits for the memory encoder, and directly feed the input mask as the output mask
            sigmoid_scale_for_mem_enc=20.0,
            sigmoid_bias_for_mem_enc=-10.0,
            use_mask_input_as_output_without_sam=True,
            # memory
            directly_add_no_mem_embed=True,
            # use high-resolution feature maps in the SAM mask decoder
            use_high_res_features_in_sam=True,
            # output 3 masks on the first click on initial conditioning frames
            multimask_output_in_sam=True,
            # SAM heads
            iou_prediction_use_sigmoid=True,
            # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
            use_obj_ptrs_in_encoder=True,
            add_tpos_enc_to_obj_ptrs=False,
            only_obj_ptrs_in_the_past_for_eval=True,
            # object occlusion prediction
            pred_obj_scores=True,
            pred_obj_scores_mlp=True,
            fixed_no_obj_ptr=True,
            # multimask tracking settings
            multimask_output_for_tracking=True,
            use_multimask_token_for_obj_ptr=True,
            multimask_min_pt_num=0,
            multimask_max_pt_num=1,
            use_mlp_for_obj_ptr_proj=True,
            # compilation flag
            compile_image_encoder=False,
            sam_mask_decoder_extra_args={
                'dynamic_multimask_via_stability': True,
                'dynamic_multimask_stability_delta': 0.05,
                'dynamic_multimask_stability_thresh': 0.98,
            },
        )
        if ckpt_path is not None:
            state_dict = load_checkpoint_with_prefix(ckpt_path)
            load_state_dict_to_model(sam2_model, state_dict)

        self.sam2_model = sam2_model

        self.hidden_dim = self.sam2_model.hidden_dim

        self.img_mean = (0.485, 0.456, 0.406)
        self.img_std = (0.229, 0.224, 0.225)

    def build_image_encoder(self):
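        """Build the image encoder: a Hiera-L trunk plus an FPN neck. The
        hyperparameters below mirror the ``sam2_hiera_l`` configuration from
        the SAM2 release."""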
        def build_trunk():
            embed_dim = 144
            num_heads = 2
            stages = [2, 6, 36, 4]
            global_att_blocks = [23, 33, 43]
            window_pos_embed_bkg_spatial_size = [7, 7]
            window_spec = [8, 4, 16, 8]
            ret = Hiera(
                embed_dim=embed_dim,
                num_heads=num_heads,
                stages=stages,
                global_att_blocks=global_att_blocks,
                window_pos_embed_bkg_spatial_size=window_pos_embed_bkg_spatial_size,
                window_spec=window_spec,
            )
            return ret
        def build_neck():
            def build_position_encoding():
                num_pos_feats = 256
                normalize = True
                scale = None
                temperature = 10000
                ret = PositionEmbeddingSine(
                    num_pos_feats=num_pos_feats,
                    normalize=normalize,
                    scale=scale,
                    temperature=temperature,
                )
                return ret
            d_model = 256
            backbone_channel_list = [1152, 576, 288, 144]
            fpn_top_down_levels = [2, 3]  # output level 0 and 1 directly use the backbone features
            fpn_interp_model = 'nearest'
            position_encoding = build_position_encoding()
            ret = FpnNeck(
                d_model=d_model,
                position_encoding=position_encoding,
                backbone_channel_list=backbone_channel_list,
                fpn_top_down_levels=fpn_top_down_levels,
                fpn_interp_model=fpn_interp_model,
            )
            return ret
        scalp = 1
        trunk = build_trunk()
        neck = build_neck()
        ret = ImageEncoder(scalp=scalp, trunk=trunk, neck=neck)
        return ret

    def build_memory_attention(self):
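        """Build the memory attention module: 4 layers, each with RoPE
        self-attention and RoPE cross-attention over memory (d_model=256)."""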
        def build_layer():
            def build_self_attention():
                rope_theta = 10000.0
                feat_sizes = [32, 32]
                embedding_dim = 256
                num_heads = 1
                downsample_rate = 1
                dropout = 0.1
                ret = RoPEAttention(
                    rope_theta=rope_theta,
                    feat_sizes=feat_sizes,
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    downsample_rate=downsample_rate,
                    dropout=dropout
                )
                return ret
            def build_cross_attention():
                rope_theta = 10000.0
                feat_sizes = [32, 32]
                rope_k_repeat = True
                embedding_dim = 256
                num_heads = 1
                downsample_rate = 1
                dropout = 0.1
                kv_in_dim = 64
                ret = RoPEAttention(
                    rope_theta=rope_theta,
                    feat_sizes=feat_sizes,
                    rope_k_repeat=rope_k_repeat,
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    downsample_rate=downsample_rate,
                    dropout=dropout,
                    kv_in_dim=kv_in_dim
                )
                return ret
            activation = 'relu'
            dim_feedforward = 2048
            dropout = 0.1
            pos_enc_at_attn = False
            d_model = 256
            pos_enc_at_cross_attn_keys = True
            pos_enc_at_cross_attn_queries = False
            self_attention = build_self_attention()
            cross_attention = build_cross_attention()
            ret = MemoryAttentionLayer(
                activation=activation,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                pos_enc_at_attn=pos_enc_at_attn,
                d_model=d_model,
                pos_enc_at_cross_attn_queries=pos_enc_at_cross_attn_queries,
                pos_enc_at_cross_attn_keys=pos_enc_at_cross_attn_keys,
                self_attention=self_attention,
                cross_attention=cross_attention,
            )
            return ret
        d_model = 256
        pos_enc_at_input = True
        num_layers = 4
        layer = build_layer()
        ret = MemoryAttention(
            d_model=d_model,
            pos_enc_at_input=pos_enc_at_input,
            num_layers=num_layers,
            layer=layer,
        )
        return ret

    def build_memory_encoder(self):
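        """Build the memory encoder: a mask downsampler plus a 2-layer
        ConvNeXt-style (CXBlock) fuser, projecting memory features down to
        64 channels."""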
        def build_position_encoding():
            num_pos_feats = 64
            normalize = True
            scale = None
            temperature = 10000
            ret = PositionEmbeddingSine(
                num_pos_feats=num_pos_feats,
                normalize=normalize,
                scale=scale,
                temperature=temperature,
            )
            return ret

        def build_mask_downsampler():
            kernel_size = 3
            stride = 2
            padding = 1
            ret = MaskDownSampler(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
            return ret

        def build_fuser():
            def build_layer():
                dim = 256
                kernel_size = 7
                padding = 3
                layer_scale_init_value = 1e-6
                use_dwconv = True  # depth-wise convs
                ret = CXBlock(
                    dim=dim, kernel_size=kernel_size,
                    padding=padding, layer_scale_init_value=layer_scale_init_value,
                    use_dwconv=use_dwconv,
                )
                return ret

            num_layers = 2
            layer = build_layer()
            ret = Fuser(
                layer=layer,
                num_layers=num_layers
            )
            return ret

        out_dim = 64
        position_encoding = build_position_encoding()
        mask_downsampler = build_mask_downsampler()
        fuser = build_fuser()
        ret = MemoryEncoder(
            out_dim=out_dim,
            position_encoding=position_encoding,
            mask_downsampler=mask_downsampler,
            fuser=fuser,
        )
        return ret

    def inject_language_embd(self, inference_state, language_embd):
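        """Add a language embedding for every (frame, object) pair and return
        the per-frame mask logits, concatenated over frames.

        ``language_embd`` is indexed as ``language_embd[frame_idx][obj_idx]``.
        """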
        num_frame = len(language_embd)
        num_obj = len(language_embd[0])
        mask_out = []
        for frame_idx in range(num_frame):
            frame_mask_out = []
            for obj_idx in range(num_obj):
                # add two leading (batch, sequence) dims to the embedding
                _language_embd = language_embd[frame_idx][obj_idx][None][None]
                # object ids are offset by 100, presumably to keep them
                # distinct from ids used for other prompt types
                _, _, out_mask_logits = self.sam2_model.add_language_embd(
                    inference_state, frame_idx, obj_idx + 100, _language_embd)
                frame_mask_out.append(out_mask_logits)
            frame_mask_out = torch.cat(frame_mask_out, dim=1)
            mask_out.append(frame_mask_out)
        mask_out = torch.cat(mask_out, dim=0)
        return mask_out
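
    # Usage (a minimal sketch; the shapes and the 256-dim embedding are
    # assumptions, taken from hidden_dim=256 of the configuration above):
    #
    #   state = model.get_sam2_embeddings(frames)            # frames: (T, 3, 1024, 1024)
    #   embeds = [[torch.randn(256) for _ in range(K)] for _ in range(T)]
    #   logits = model.inject_language_embd(state, embeds)   # (T, K, H, W) mask logits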


    def language_embd_inference(self, inference_state, language_embd):
        """Inject language embeddings for every (frame, object) pair, then
        propagate through the video and return the concatenated mask logits.
        """
        num_frame = len(language_embd)
        num_obj = len(language_embd[0])
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            # First pass: add the language embeddings for each object on each
            # frame. The per-call mask logits are not collected here; the
            # final masks come from video propagation below.
            for frame_idx in range(num_frame):
                for obj_idx in range(num_obj):
                    # add two leading (batch, sequence) dims to the embedding
                    _language_embd = language_embd[frame_idx][obj_idx][None][None]
                    self.sam2_model.add_language_embd(
                        inference_state,
                        frame_idx,
                        obj_idx + 100,
                        _language_embd,
                        inference=True,
                    )

            # Second pass: propagate through the video and collect mask logits.
            mask_out = []
            for out_frame_idx, out_obj_ids, out_mask_logits in self.sam2_model.propagate_in_video(inference_state):
                mask_out.append(out_mask_logits)
            mask_out = torch.cat(mask_out, dim=0)
        return mask_out

    def get_sam2_embeddings(self, images):
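        """Initialize and return a SAM2 inference state over the given frames
        (a thin wrapper around ``SAM2VideoPredictor.init_state``)."""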
        return self.sam2_model.init_state(images)

    def forward(self, batch):
        raise NotImplementedError

    def preprocess_image(self, image: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
        """Scale pixel values to [0, 1] and normalize with the ImageNet
        mean/std, returning a tensor of the requested ``dtype``."""
        # cast first so the requested dtype is honored even for uint8 inputs
        image = image.to(dtype) / 255.

        img_mean = torch.tensor(self.img_mean, dtype=dtype, device=image.device)[:, None, None]
        img_std = torch.tensor(self.img_std, dtype=dtype, device=image.device)[:, None, None]
        image -= img_mean
        image /= img_std

        return image
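
# End-to-end usage (a minimal sketch; the checkpoint path, frame shapes, and
# the CUDA device assumed by the bfloat16 autocast in
# `language_embd_inference` are all assumptions, not part of this module):
#
#   model = SAM2(ckpt_path='checkpoints/sam2_checkpoint.pt').cuda().eval()
#   frames = torch.randint(0, 256, (8, 3, 1024, 1024), device='cuda')
#   frames = model.preprocess_image(frames)   # (T, 3, H, W), normalized
#   state = model.get_sam2_embeddings(frames)
#   # ...then inject language embeddings and propagate, as in
#   # `language_embd_inference`.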