|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from segment_anything.modeling import ImageEncoderViT |
|
|
from segment_anything.modeling.image_encoder import Block, window_partition, window_unpartition |
|
|
|
|
|
|
|
|
class BlockWrapper(nn.Module):
    """Run a wrapped transformer block's sub-modules in an explicit forward pass.

    The wrapper owns no parameters of its own; it calls ``norm1``, ``attn``,
    ``norm2`` and ``mlp`` on the wrapped block and reads its ``window_size``.
    """

    def __init__(self, block: Block):
        """Store ``block`` so its sub-modules can be driven from ``forward``."""
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP, each followed by a residual addition.

        Args:
            x: Input tensor. Dimensions 1 and 2 are read as the spatial size
                when windowing is active (presumably a ``(B, H, W, C)``
                layout -- confirm against the caller).

        Returns:
            The tensor after the attention residual and the MLP residual.
        """
        inner = self.block
        residual = x
        normed = inner.norm1(x)

        win = inner.window_size
        if win > 0:
            # Windowed attention: partition, attend per window, then restore
            # the pre-partition spatial size recorded here.
            height, width = normed.shape[1], normed.shape[2]
            windows, padded_hw = window_partition(normed, win)
            attended = window_unpartition(
                inner.attn(windows), win, padded_hw, (height, width)
            )
        else:
            # Non-positive window size: attend over the tensor as-is.
            attended = inner.attn(normed)

        hidden = residual + attended
        return hidden + inner.mlp(inner.norm2(hidden))
|
|
|
|
|
|
|
|
class ImageEncoderViTWrapper(nn.Module):
    """Hold an image encoder and allow its blocks to be swapped for wrappers.

    The encoder is stored as ``image_encoder``; ``change_block`` replaces its
    ``blocks`` container in place.
    """

    def __init__(self, image_encoder: ImageEncoderViT):
        """Store ``image_encoder`` as a sub-module of this wrapper."""
        super().__init__()
        self.image_encoder = image_encoder

    def change_block(self):
        """Replace every entry of ``image_encoder.blocks`` with a ``BlockWrapper``.

        The encoder is mutated in place: ``image_encoder.blocks`` is rebound to
        a new ``nn.ModuleList``. Entries that are already ``BlockWrapper``
        instances are kept unchanged, so calling this method more than once
        does not nest wrappers (a nested wrapper would fail in ``forward``
        because ``BlockWrapper`` itself exposes no ``norm1``/``attn``/
        ``window_size`` attributes).
        """
        self.image_encoder.blocks = nn.ModuleList(
            [
                block if isinstance(block, BlockWrapper) else BlockWrapper(block)
                for block in self.image_encoder.blocks
            ]
        )
|
|
|