| from rscd.models.backbones.vmamba import VSSM, LayerNorm2d |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class Backbone_VSSM(VSSM):
    """VSSM used as a feature extractor that returns per-stage outputs.

    On top of the parent ``VSSM`` this class registers one extra
    normalization layer per requested stage (as ``outnorm{i}``), deletes the
    inherited ``classifier`` attribute and optionally loads a checkpoint.
    """

    def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer='ln2d', **kwargs):
        """Build the backbone.

        Args:
            out_indices: Indices into ``self.layers`` whose block outputs are
                normalized and returned by ``forward``.
            pretrained: Checkpoint path handed to ``load_pretrained``;
                ``None`` skips loading.
            norm_layer: Name of the normalization to use, one of ``"ln"``,
                ``"ln2d"`` or ``"bn"`` (case-insensitive). It is also
                forwarded to ``VSSM.__init__``.
            **kwargs: Remaining keyword arguments forwarded to
                ``VSSM.__init__``.

        Raises:
            ValueError: If ``norm_layer`` is not a known name while
                ``out_indices`` requests at least one output norm.
        """
        kwargs.update(norm_layer=norm_layer)
        super().__init__(**kwargs)

        norm_key = norm_layer.lower()
        # "bn" and "ln2d" are treated as channel-first; for any other norm the
        # outputs are permuted in ``forward``.
        self.channel_first = norm_key in ("bn", "ln2d")

        norm_classes = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )
        norm_cls = norm_classes.get(norm_key)
        if norm_cls is None and len(out_indices) > 0:
            # Fail with a readable message instead of "'NoneType' object is
            # not callable" when the first output norm is constructed.
            raise ValueError(
                f"Unsupported norm_layer {norm_layer!r}; expected one of {sorted(norm_classes)}"
            )

        self.out_indices = out_indices
        for i in out_indices:
            # One norm per requested stage, sized by that stage's width.
            self.add_module(f'outnorm{i}', norm_cls(self.dims[i]))

        # The classification head is not used by this backbone.
        del self.classifier
        self.load_pretrained(pretrained)

    def load_pretrained(self, ckpt=None, key="model"):
        """Load weights from a checkpoint file on a best-effort basis.

        Any failure (missing file, missing ``key``, incompatible state dict
        type, ...) is caught and reported via ``print``; the model then keeps
        its current weights.

        Args:
            ckpt: Path of the checkpoint file; ``None`` makes this a no-op.
            key: Entry of the loaded checkpoint that holds the state dict.
        """
        if ckpt is None:
            return

        try:
            # NOTE: torch.load unpickles the file, so only checkpoints from
            # trusted sources should be passed here.
            # The context manager closes the handle even if loading fails.
            with open(ckpt, "rb") as ckpt_file:
                _ckpt = torch.load(ckpt_file, map_location=torch.device("cpu"))
            print(f"Successfully load ckpt {ckpt}")
            # strict=False: mismatching keys are reported, not treated as errors.
            incompatibleKeys = self.load_state_dict(_ckpt[key], strict=False)
            print(incompatibleKeys)
        except Exception as e:
            print(f"Failed loading checkpoint form {ckpt}: {e}")

    def forward(self, x):
        """Run the backbone and collect the normalized per-stage outputs.

        Args:
            x: Input passed to ``self.patch_embed``.

        Returns:
            A list with one entry per stage index in ``self.out_indices``
            (in stage order). If ``self.out_indices`` is empty, the output of
            the last stage's ``downsample`` is returned instead.
        """
        def layer_forward(l, x):
            # Keep the pre-downsample activation as the stage output and
            # feed the downsampled tensor to the next stage.
            x = l.blocks(x)
            y = l.downsample(x)
            return x, y

        x = self.patch_embed(x)
        outs = []
        for i, layer in enumerate(self.layers):
            o, x = layer_forward(layer, x)
            if i in self.out_indices:
                out = getattr(self, f'outnorm{i}')(o)
                if not self.channel_first:
                    # Presumably converts (B, H, W, C) to (B, C, H, W) --
                    # TODO confirm against the VSSM layer layout.
                    out = out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        if len(self.out_indices) == 0:
            return x

        return outs
| |
class CMBackbone(nn.Module):
    """Applies one shared ``Backbone_VSSM`` encoder to a pair of inputs."""

    def __init__(self, pretrained, **kwargs):
        """Create the shared encoder.

        Args:
            pretrained: Forwarded to ``Backbone_VSSM`` as ``pretrained``.
            **kwargs: Additional keyword arguments forwarded to
                ``Backbone_VSSM``.
        """
        super().__init__()
        self.encoder = Backbone_VSSM(
            pretrained=pretrained,
            out_indices=(0, 1, 2, 3),
            **kwargs,
        )

    def forward(self, pre_data, post_data):
        """Encode both inputs with the same encoder.

        Args:
            pre_data: First input; its last two dimensions are also returned.
            post_data: Second input.

        Returns:
            A three-element list: the encoder output for ``pre_data``, the
            encoder output for ``post_data``, and the last two dimensions of
            ``pre_data``.
        """
        feats_before = self.encoder(pre_data)
        feats_after = self.encoder(post_data)
        spatial_size = pre_data.shape[-2:]
        return [feats_before, feats_after, spatial_size]