|
|
from rscd.models.backbones.vmamba import VSSM, LayerNorm2d |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class Backbone_VSSM(VSSM):
    """VSSM backbone adapted for dense prediction: returns multi-scale feature
    maps from the requested stages instead of classification logits.

    Args:
        out_indices: indices of the stages whose (pre-downsample) outputs are
            returned; each gets its own normalization layer ``outnorm{i}``.
        pretrained: optional path to a checkpoint loaded via
            :meth:`load_pretrained` (best-effort, non-strict).
        norm_layer: one of ``'ln'``, ``'ln2d'``, ``'bn'`` — selects the
            normalization applied to each returned feature map. Also forwarded
            to the parent ``VSSM``.
        **kwargs: forwarded to ``VSSM.__init__``.
    """

    def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer='ln2d', **kwargs):
        # Keep the parent VSSM consistent with the norm choice used here.
        kwargs.update(norm_layer=norm_layer)
        super().__init__(**kwargs)
        # 'bn' and 'ln2d' operate on channels-first (B, C, H, W) tensors;
        # plain 'ln' expects channels-last, so forward() permutes in that case.
        self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])

        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )
        # Fail fast with a clear message instead of a later
        # "'NoneType' object is not callable" when the key is unknown.
        norm_cls = _NORMLAYERS.get(norm_layer.lower())
        if norm_cls is None:
            raise ValueError(
                f"Unsupported norm_layer '{norm_layer}'; expected one of {sorted(_NORMLAYERS)}"
            )

        self.out_indices = out_indices
        # One output-normalization module per requested stage, registered as
        # a named submodule so checkpoints can address it.
        for i in out_indices:
            layer = norm_cls(self.dims[i])
            layer_name = f'outnorm{i}'
            self.add_module(layer_name, layer)

        # The classification head from VSSM is unused for feature extraction.
        del self.classifier
        self.load_pretrained(pretrained)

    def load_pretrained(self, ckpt=None, key="model"):
        """Best-effort, non-strict load of pretrained weights.

        Args:
            ckpt: path to a ``torch.save`` checkpoint; ``None`` is a no-op.
            key: dict key under which the state dict is stored in the checkpoint.

        Any failure is caught and printed rather than raised, so an
        incompatible or missing checkpoint does not abort model construction.
        """
        if ckpt is None:
            return

        try:
            # Context manager ensures the file handle is closed even if
            # torch.load raises (the original leaked the open file object).
            with open(ckpt, "rb") as f:
                _ckpt = torch.load(f, map_location=torch.device("cpu"))
            print(f"Successfully load ckpt {ckpt}")
            # strict=False: backbone has no classifier and adds outnorm layers,
            # so some keys are expected to mismatch; report them.
            incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False)
            print(incompatibleKeys)
        except Exception as e:
            print(f"Failed loading checkpoint from {ckpt}: {e}")

    def forward(self, x):
        """Return the list of normalized feature maps for ``self.out_indices``.

        Each returned tensor is channels-first (B, C, H, W); when the norm is
        channels-last ('ln') the output is permuted accordingly. If
        ``out_indices`` is empty, the final (downsampled) tensor is returned.
        """
        def layer_forward(l, x):
            # Stage output before downsampling (x) is what gets exposed;
            # the downsampled tensor (y) feeds the next stage.
            x = l.blocks(x)
            y = l.downsample(x)
            return x, y

        x = self.patch_embed(x)
        outs = []
        for i, layer in enumerate(self.layers):
            o, x = layer_forward(layer, x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'outnorm{i}')
                out = norm_layer(o)
                if not self.channel_first:
                    # (B, H, W, C) -> (B, C, H, W) for downstream consumers.
                    out = out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        if len(self.out_indices) == 0:
            return x

        return outs
|
|
|
|
|
class CMBackbone(nn.Module):
    """Siamese change-detection backbone.

    A single shared ``Backbone_VSSM`` encoder is applied to both the
    pre-change and post-change images; the two multi-scale feature lists are
    returned together with the input spatial size for later upsampling.
    """

    def __init__(self, pretrained, **kwargs):
        super().__init__()
        # Shared-weight encoder: both temporal inputs go through the same module.
        self.encoder = Backbone_VSSM(out_indices=(0, 1, 2, 3), pretrained=pretrained, **kwargs)

    def forward(self, pre_data, post_data):
        """Encode both temporal images with the shared encoder.

        Returns:
            list: ``[pre_features, post_features, (H, W)]`` where the last
            element is the spatial size of the pre-change input.
        """
        feats_before = self.encoder(pre_data)
        feats_after = self.encoder(post_data)
        return [feats_before, feats_after, pre_data.size()[-2:]]