import torch
import torch.nn as nn
from segment_anything.modeling import ImageEncoderViT
from segment_anything.modeling.image_encoder import Block, window_partition, window_unpartition


class BlockWrapper(nn.Module):
    """Wrap a SAM ``Block`` and run its sub-modules with a reorganised forward pass.

    The wrapped block's sub-modules (``norm1``, ``attn``, ``norm2``, ``mlp``)
    are reused unchanged; only the control flow differs. Window partitioning,
    attention and window unpartitioning all happen inside one branch, so
    ``H``, ``W`` and ``pad_hw`` are only used on the path that assigns them.
    NOTE(review): presumably done to make the block tracing/export friendly;
    confirm with the caller.
    """

    def __init__(self, block: Block):
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x + attn(norm1(x))`` and then ``x + mlp(norm2(x))``.

        Args:
            x: Input tensor. Dims 1 and 2 are read as the spatial ``H`` and
                ``W`` (presumably ``(B, H, W, C)``; TODO confirm).

        Returns:
            The tensor after the block's residual attention and MLP stages.
        """
        shortcut = x
        x = self.block.norm1(x)
        if self.block.window_size > 0:
            # Windowed path: partition into windows of ``window_size``, attend,
            # then unpartition back to the original (H, W). ``pad_hw`` is
            # whatever ``window_partition`` returns and is handed back unchanged.
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.block.window_size)
            x = self.block.attn(x)
            x = window_unpartition(x, self.block.window_size, pad_hw, (H, W))
        else:
            # No windowing: attention runs on the unpartitioned tensor.
            x = self.block.attn(x)
        x = shortcut + x
        x = x + self.block.mlp(self.block.norm2(x))
        return x


class ImageEncoderViTWrapper(nn.Module):
    """Hold a SAM ``ImageEncoderViT`` and optionally swap its blocks for ``BlockWrapper``s."""

    def __init__(self, image_encoder: ImageEncoderViT):
        super().__init__()
        self.image_encoder = image_encoder

    def forward(self, *args, **kwargs):
        """Delegate the call to the wrapped image encoder.

        All positional and keyword arguments are forwarded unchanged, and the
        encoder's output is returned as-is.
        """
        return self.image_encoder(*args, **kwargs)

    def change_block(self) -> None:
        """Replace every block of the wrapped encoder with a ``BlockWrapper``.

        This mutates ``self.image_encoder.blocks`` in place, so the encoder
        object passed to ``__init__`` is modified as well. Blocks that are
        already ``BlockWrapper`` instances are kept as-is. Calling this more
        than once therefore never nests wrappers, which would fail in
        ``forward`` because a ``BlockWrapper`` has no ``norm1``.
        """
        self.image_encoder.blocks = nn.ModuleList(
            [
                block if isinstance(block, BlockWrapper) else BlockWrapper(block)
                for block in self.image_encoder.blocks
            ]
        )