Spaces:
Runtime error
Runtime error
| # Copyright (c) Open-CD. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| from mmseg.models.backbones import ResNet | |
| from opencd.registry import MODELS | |
class IA_ResNet(ResNet):
    """Interaction ResNet backbone.

    Two inputs are pushed through the same stem and the same residual
    stages. After every stage an interaction layer is applied to the pair of
    feature maps, and for each stage listed in ``out_indices`` the two
    (post-interaction) feature maps are concatenated along the channel axis.

    Args:
        interaction_cfg (Sequence[dict]): Interaction strategies for the
            stages. The length should be the same as `num_stages`. A
            ``None`` entry is replaced by ``dict(type='TwoIdentity')``.
            The details can be found in
            `opencd/models/utils/interaction_layer.py`.
            Default: (None, None, None, None).
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Number of stem channels. Default: 64.
        base_channels (int): Number of base channels of res layer.
            Default: 64.
        num_stages (int): Resnet stages, normally 4. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: (1, 2, 2, 2).
        dilations (Sequence[int]): Dilation of each stage.
            Default: (1, 1, 1, 1).
        out_indices (Sequence[int]): Output from which stages.
            Default: (0, 1, 2, 3).
        style (str): `pytorch` or `caffe`. If set to "pytorch", the
            stride-two layer is the 3x3 conv layer, otherwise the stride-two
            layer is the first 1x1 conv layer. Default: 'pytorch'.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): Dictionary to construct and config conv
            layer. When conv_cfg is None, cfg will be set to
            dict(type='Conv2d'). Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        dcn (dict | None): Dictionary to construct and config DCN conv
            layer. When dcn is not None, conv_cfg must be None.
            Default: None.
        stage_with_dcn (Sequence[bool]): Whether to set DCN conv for each
            stage. The length of stage_with_dcn is equal to num_stages.
            Default: (False, False, False, False).
        plugins (list[dict]): List of plugins for stages, each dict
            contains:

            - cfg (dict, required): Cfg dict to build plugin.
            - position (str, required): Position inside block to insert
              plugin, options: 'after_conv1', 'after_conv2', 'after_conv3'.
            - stages (tuple[bool], optional): Stages to apply plugin, length
              should be same as 'num_stages'.

            Default: None.
        multi_grid (Sequence[int]|None): Multi grid dilation rates of last
            stage. Default: None.
        contract_dilation (bool): Whether contract first dilation of each
            layer. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed.
            Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm
            layer in resblocks to let them behave as identity.
            Default: True.
        pretrained (str, optional): model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> from opencd.models import IA_ResNet
        >>> import torch
        >>> self = IA_ResNet(depth=18)
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 32, 32)
        >>> level_outputs = self.forward(inputs, inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 128, 8, 8)
        (1, 256, 4, 4)
        (1, 512, 2, 2)
        (1, 1024, 1, 1)
    """

    def __init__(self,
                 interaction_cfg=(None, None, None, None),
                 **kwargs):
        """Build the ResNet, then one interaction layer per stage.

        Args:
            interaction_cfg (Sequence[dict | None]): One config per stage;
                ``None`` selects ``dict(type='TwoIdentity')``.
            **kwargs: Forwarded unchanged to ``ResNet.__init__``.
        """
        # `num_stages` is set by the parent constructor, so the length
        # check has to come after it.
        super().__init__(**kwargs)
        assert self.num_stages == len(interaction_cfg), \
            'The length of the `interaction_cfg` should be same as the `num_stages`.'

        # Cross-correlation (interaction) layers, one per stage, built in
        # stage order and registered as sub-modules.
        self.ccs = nn.ModuleList([
            MODELS.build(
                dict(type='TwoIdentity') if stage_cfg is None else stage_cfg)
            for stage_cfg in interaction_cfg
        ])

    def forward(self, x1, x2):
        """Forward function.

        Args:
            x1: Input of the first branch.
            x2: Input of the second branch.

        Returns:
            tuple: One tensor per stage index in ``out_indices``; each is the
            channel-wise (dim=1) concatenation of the two branch features
            after that stage's interaction layer.
        """

        def _apply_stem(feat):
            # Deep stem is a single module; the plain stem is
            # conv -> norm -> relu. Both are followed by max pooling.
            if not self.deep_stem:
                feat = self.relu(self.norm1(self.conv1(feat)))
            else:
                feat = self.stem(feat)
            return self.maxpool(feat)

        # Branch 1 is always processed before branch 2.
        feat1 = _apply_stem(x1)
        feat2 = _apply_stem(x2)

        collected = []
        for stage_idx, stage_name in enumerate(self.res_layers):
            stage = getattr(self, stage_name)
            feat1 = stage(feat1)
            feat2 = stage(feat2)
            # Let the two branches exchange information after this stage.
            feat1, feat2 = self.ccs[stage_idx](feat1, feat2)
            if stage_idx in self.out_indices:
                collected.append(torch.cat([feat1, feat2], dim=1))
        return tuple(collected)
class IA_ResNetV1c(IA_ResNet):
    """Interaction ResNet, V1c flavour.

    Compared with default ResNet(ResNetV1b), ResNetV1c replaces the 7x7 conv
    in the input stem with three 3x3 convs. For more details please refer to
    `Bag of Tricks for Image Classification with Convolutional Neural
    Networks <https://arxiv.org/abs/1812.01187>`_.

    This class pins ``deep_stem=True`` and ``avg_down=False``; every other
    keyword argument is forwarded to :class:`IA_ResNet`.
    """

    def __init__(self, **kwargs):
        """Construct an ``IA_ResNet`` with the V1c stem settings."""
        super().__init__(deep_stem=True, avg_down=False, **kwargs)
class IA_ResNetV1d(IA_ResNet):
    """Interaction ResNet, V1d flavour.

    Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv
    in the input stem with three 3x3 convs. And in the downsampling block, a
    2x2 avg_pool with stride 2 is added before conv, whose stride is changed
    to 1.

    This class pins ``deep_stem=True`` and ``avg_down=True``; every other
    keyword argument is forwarded to :class:`IA_ResNet`.
    """

    def __init__(self, **kwargs):
        """Construct an ``IA_ResNet`` with the V1d stem/downsample settings."""
        super().__init__(deep_stem=True, avg_down=True, **kwargs)