File size: 1,287 Bytes
1c1cd5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn

from segment_anything.modeling import ImageEncoderViT
from segment_anything.modeling.image_encoder import Block, window_partition, window_unpartition


class BlockWrapper(nn.Module):
    """Re-run a segment-anything ``Block`` through its own submodules.

    The wrapped block's ``norm1``, ``attn``, ``norm2`` and ``mlp`` are applied
    as a pre-norm residual stack:

        x = x + attn(norm1(x))
        x = x + mlp(norm2(x))

    When ``block.window_size > 0`` the attention runs on windows.
    ``window_partition`` is applied before ``attn`` and ``window_unpartition``
    after it. The whole windowed path lives in a single branch, so the
    partition's padding info is only used where it is produced.

    NOTE(review): this presumably mirrors ``Block.forward`` but restructured,
    e.g. for tracing/export. Confirm against the upstream implementation.
    """

    def __init__(self, block: Block):
        super().__init__()
        # Registered as a submodule, so its parameters appear under "block.".
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped block to ``x`` and return the result.

        Assumes ``x`` is channels-last with spatial dims at indices 1 and 2
        (e.g. ``(B, H, W, C)``). TODO: confirm.
        """
        blk = self.block
        residual = x
        normed = blk.norm1(x)

        win = blk.window_size
        if win <= 0:
            # Global attention over the full feature map.
            attended = blk.attn(normed)
        else:
            # Remember the unpadded spatial size so unpartitioning can crop back.
            orig_hw = (normed.shape[1], normed.shape[2])
            windows, pad_hw = window_partition(normed, win)
            attended = window_unpartition(blk.attn(windows), win, pad_hw, orig_hw)

        out = residual + attended
        return out + blk.mlp(blk.norm2(out))


class ImageEncoderViTWrapper(nn.Module):
    """Hold an ``ImageEncoderViT`` and let its blocks be swapped for wrappers.

    The wrapper is itself callable: ``forward`` delegates to the wrapped
    encoder. ``change_block`` replaces every entry of
    ``image_encoder.blocks`` in place with a ``BlockWrapper`` around it.
    """

    def __init__(self, image_encoder: ImageEncoderViT):
        super().__init__()
        self.image_encoder = image_encoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped image encoder on ``x`` and return its output.

        Without this method, calling the wrapper would raise
        ``NotImplementedError`` from ``nn.Module``.
        """
        return self.image_encoder(x)

    def change_block(self):
        """Wrap each block of the encoder in a ``BlockWrapper``, in place.

        Blocks that are already ``BlockWrapper`` instances are kept as-is,
        so calling this more than once is safe. Double wrapping would break
        ``BlockWrapper.forward``, which looks up ``norm1``/``attn``/... on
        its inner block, and a ``BlockWrapper`` has no such attributes.
        """
        self.image_encoder.blocks = nn.ModuleList(
            blk if isinstance(blk, BlockWrapper) else BlockWrapper(blk)
            for blk in self.image_encoder.blocks
        )