from rscd.models.backbones.vmamba import VSSM, LayerNorm2d
import torch
import torch.nn as nn


class Backbone_VSSM(VSSM):
    """VSSM backbone exposing intermediate stage features for dense prediction.

    Drops the classification head, registers a normalization layer for each
    requested stage, and returns the normalized outputs of the stages listed
    in ``out_indices`` (channel-first ``(B, C, H, W)`` tensors).
    """

    def __init__(self, out_indices=(0, 1, 2, 3), pretrained=None, norm_layer='ln2d', **kwargs):
        """
        Args:
            out_indices: stage indices whose features are returned by forward().
            pretrained: optional checkpoint path loaded non-strictly.
            norm_layer: one of 'ln', 'ln2d', 'bn' — also forwarded to VSSM.
        """
        kwargs.update(norm_layer=norm_layer)
        super().__init__(**kwargs)
        # 'bn' and 'ln2d' operate on (B, C, H, W); plain 'ln' on (B, H, W, C).
        self.channel_first = (norm_layer.lower() in ["bn", "ln2d"])
        _NORMLAYERS = dict(
            ln=nn.LayerNorm,
            ln2d=LayerNorm2d,
            bn=nn.BatchNorm2d,
        )
        # Distinct name: don't shadow the string argument with the class.
        norm_cls = _NORMLAYERS.get(norm_layer.lower(), None)

        self.out_indices = out_indices
        for i in out_indices:
            # One norm per tapped stage, registered as a named submodule.
            layer = norm_cls(self.dims[i])
            self.add_module(f'outnorm{i}', layer)

        del self.classifier  # backbone only: the classification head is unused
        self.load_pretrained(pretrained)

    def load_pretrained(self, ckpt=None, key="model"):
        """Best-effort, non-strict load of the ``key`` sub-dict of a checkpoint.

        Errors are reported but not raised, so a missing/corrupt checkpoint
        degrades to random initialization instead of crashing.
        """
        if ckpt is None:
            return
        try:
            # Pass the path directly: torch.load opens and closes the file
            # itself (the original open(ckpt, "rb") handle was never closed).
            _ckpt = torch.load(ckpt, map_location=torch.device("cpu"))
            print(f"Successfully load ckpt {ckpt}")
            incompatible_keys = self.load_state_dict(_ckpt[key], strict=False)
            print(incompatible_keys)
        except Exception as e:
            print(f"Failed loading checkpoint from {ckpt}: {e}")

    def forward(self, x):
        """Return the list of normalized stage features (or the final map if
        ``out_indices`` is empty)."""

        def layer_forward(layer, feat):
            # Pre-downsample output is the stage feature; downsampled output
            # feeds the next stage.
            out = layer.blocks(feat)
            nxt = layer.downsample(out)
            return out, nxt

        x = self.patch_embed(x)
        outs = []
        for i, layer in enumerate(self.layers):
            o, x = layer_forward(layer, x)  # o: (B, H, W, C) unless channel_first
            if i in self.out_indices:
                norm = getattr(self, f'outnorm{i}')
                out = norm(o)
                if not self.channel_first:
                    # Convert to channel-first for downstream conv decoders.
                    out = out.permute(0, 3, 1, 2).contiguous()
                outs.append(out)

        if len(self.out_indices) == 0:
            return x
        return outs


class CMBackbone(nn.Module):
    """Siamese change-detection backbone: one shared VSSM encoder applied to
    both the pre-change and post-change images."""

    def __init__(self, pretrained, **kwargs):
        super(CMBackbone, self).__init__()
        self.encoder = Backbone_VSSM(out_indices=(0, 1, 2, 3), pretrained=pretrained, **kwargs)

    def forward(self, pre_data, post_data):
        """Encode both inputs with the shared encoder.

        Returns:
            [pre_features, post_features, (H, W)] — the multi-scale features
            of each image plus the input spatial size (useful for the decoder
            to upsample back to full resolution).
        """
        pre_features = self.encoder(pre_data)
        post_features = self.encoder(post_data)
        return [pre_features, post_features, pre_data.size()[-2:]]