File size: 1,287 Bytes
1c1cd5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import torch
import torch.nn as nn
from segment_anything.modeling import ImageEncoderViT
from segment_anything.modeling.image_encoder import Block, window_partition, window_unpartition
class BlockWrapper(nn.Module):
    """Wrap a SAM ViT ``Block`` and re-implement its forward pass explicitly.

    The computation mirrors the standard pre-norm transformer block used by
    segment-anything: ``x + attn(norm1(x))`` followed by
    ``x + mlp(norm2(x))``, with optional windowed attention when the wrapped
    block has ``window_size > 0``. Presumably this wrapper exists so the
    forward graph is spelled out at this level (e.g. for export or hooking) —
    TODO confirm with callers.
    """

    def __init__(self, block: Block):
        super().__init__()
        # Keep a reference to the original block; all parameters stay inside it.
        self.block = block

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        normed = self.block.norm1(x)
        if self.block.window_size > 0:
            # Remember the pre-partition spatial size so it can be restored.
            height, width = normed.shape[1], normed.shape[2]
            windows, padded_hw = window_partition(normed, self.block.window_size)
            attended = self.block.attn(windows)
            # Undo the windowing, cropping any padding back to (height, width).
            normed = window_unpartition(
                attended, self.block.window_size, padded_hw, (height, width)
            )
        else:
            # Global attention over the full token grid.
            normed = self.block.attn(normed)
        x = residual + normed
        return x + self.block.mlp(self.block.norm2(x))
class ImageEncoderViTWrapper(nn.Module):
    """Hold a segment-anything ``ImageEncoderViT`` and allow its transformer
    blocks to be swapped for ``BlockWrapper`` instances in place.

    Note that this class defines no ``forward`` of its own; callers presumably
    run ``self.image_encoder`` directly after ``change_block`` — TODO confirm.
    """

    def __init__(self, image_encoder: ImageEncoderViT):
        super().__init__()
        self.image_encoder = image_encoder

    def change_block(self):
        """Replace every encoder block with a ``BlockWrapper`` around it.

        Mutates ``self.image_encoder.blocks``; the wrapped blocks keep all of
        the original parameters, so checkpoints remain loadable through the
        inner ``block`` attribute.
        """
        self.image_encoder.blocks = nn.ModuleList(
            BlockWrapper(blk) for blk in self.image_encoder.blocks
        )
|