khsyee's picture
Change to using a wrapper, but it fails.
1c1cd5e
raw
history blame
1.29 kB
import torch
import torch.nn as nn
from segment_anything.modeling import ImageEncoderViT
from segment_anything.modeling.image_encoder import Block, window_partition, window_unpartition
class BlockWrapper(nn.Module):
    """Run a wrapped ``Block`` step by step through its own sub-modules.

    The wrapper owns no parameters of its own; it calls ``norm1``, ``attn``,
    ``norm2`` and ``mlp`` on the wrapped block and reads its ``window_size``.
    """

    def __init__(self, block: Block):
        """Store the block whose sub-modules ``forward`` will drive."""
        super().__init__()
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP, each followed by a residual addition.

        Args:
            x: Input tensor. When windowing is active, dimensions 1 and 2 of
                the normalized tensor are taken as height and width.

        Returns:
            The tensor after both residual additions.
        """
        inner = self.block
        win = inner.window_size
        normed = inner.norm1(x)

        if win > 0:
            # Remember the pre-partition spatial size so it can be restored
            # after attention has run on the individual windows.
            height, width = normed.shape[1], normed.shape[2]
            windows, pad_hw = window_partition(normed, win)
            attended = window_unpartition(
                inner.attn(windows), win, pad_hw, (height, width)
            )
        else:
            # No windowing: attention sees the whole normalized tensor.
            attended = inner.attn(normed)

        residual = x + attended
        return residual + inner.mlp(inner.norm2(residual))
class ImageEncoderViTWrapper(nn.Module):
    """Hold an ``ImageEncoderViT`` and optionally swap its blocks for wrappers."""

    def __init__(self, image_encoder: ImageEncoderViT):
        """Store the encoder; its blocks are left untouched until ``change_block``."""
        super().__init__()
        self.image_encoder = image_encoder

    def change_block(self):
        """Replace every entry of ``image_encoder.blocks`` with a ``BlockWrapper``.

        Blocks that are already ``BlockWrapper`` instances are kept as they
        are, so calling this method more than once is safe. Wrapping a wrapper
        would make ``BlockWrapper.forward`` fail, because it looks up ``norm1``,
        ``attn`` and the other sub-modules directly on the wrapped object.
        """
        self.image_encoder.blocks = nn.ModuleList(
            [
                block if isinstance(block, BlockWrapper) else BlockWrapper(block)
                for block in self.image_encoder.blocks
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped encoder so the wrapper itself is callable.

        Args:
            x: Input passed unchanged to ``image_encoder``.

        Returns:
            Whatever ``image_encoder`` returns for ``x`` (presumably the image
            embedding tensor -- confirm against the encoder implementation).
        """
        return self.image_encoder(x)