from rscd.models.backbones.vmamba import VSSM, LayerNorm2d
import torch
import torch.nn as nn


class Backbone_VSSM(VSSM):
    def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer='ln2d', **kwargs):
        kwargs.update(norm_layer=norm_layer)
        super().__init__(**kwargs)
        # 'bn' and 'ln2d' operate on channel-first (B, C, H, W) tensors;
        # plain 'ln' leaves the stage outputs channel-last (B, H, W, C).
        self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])

        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )
        # Resolve the norm class under a new name to avoid shadowing the
        # string argument.
        norm_cls = _NORMLAYERS[norm_layer.lower()]

        # Attach one output-normalization layer per requested feature stage.
        self.out_indices = out_indices
        for i in out_indices:
            layer = norm_cls(self.dims[i])
            layer_name = f'outnorm{i}'
            self.add_module(layer_name, layer)

        # The classification head is unused for dense feature extraction.
        del self.classifier
        self.load_pretrained(pretrained)

    def load_pretrained(self, ckpt=None, key="model"):
        if ckpt is None:
            return
        try:
            _ckpt = torch.load(ckpt, map_location=torch.device("cpu"))
            print(f"Successfully loaded checkpoint {ckpt}")
            incompatible_keys = self.load_state_dict(_ckpt[key], strict=False)
            print(incompatible_keys)
        except Exception as e:
            print(f"Failed loading checkpoint from {ckpt}: {e}")

    def forward(self, x):
        def layer_forward(l, x):
            x = l.blocks(x)
            y = l.downsample(x)
            return x, y

        x = self.patch_embed(x)
        outs = []
        for i, layer in enumerate(self.layers):
            o, x = layer_forward(layer, x)  # o: stage-i features before downsampling
            if i in self.out_indices:
                norm_layer = getattr(self, f'outnorm{i}')
                out = norm_layer(o)
                if not self.channel_first:
                    # (B, H, W, C) -> (B, C, H, W)
                    out = out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        if len(self.out_indices) == 0:
            return x
        return outs

class CMBackbone(nn.Module):
    def __init__(self, pretrained, **kwargs):
        super(CMBackbone, self).__init__()
        self.encoder = Backbone_VSSM(out_indices=(0, 1, 2, 3), pretrained=pretrained, **kwargs)

    def forward(self, pre_data, post_data):
        # Siamese encoding: the same weight-shared encoder processes both the
        # pre-change and post-change images.
        pre_features = self.encoder(pre_data)
        post_features = self.encoder(post_data)
        # The input spatial size (H, W) is returned so a decoder can upsample
        # its prediction back to full resolution.
        return [pre_features, post_features, pre_data.size()[-2:]]
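
# A minimal smoke-test sketch. It assumes VSSM's default constructor
# arguments are usable as-is and that 256x256 RGB inputs suit the patch
# embedding; both are assumptions, not guarantees from this file.
if __name__ == "__main__":
    model = CMBackbone(pretrained=None)
    pre = torch.randn(1, 3, 256, 256)   # pre-change image batch
    post = torch.randn(1, 3, 256, 256)  # post-change image batch
    pre_feats, post_feats, hw = model(pre, post)
    for f in pre_feats:
        print(tuple(f.shape))  # one (B, C, H, W) map per stage in out_indices
    print(hw)  # torch.Size([256, 256]); consumed by the decoder for upsampling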