| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| from .revcol_module import ConvNextBlock, LayerNorm, UpSampleConvnext |
| from mmdet.utils import get_root_logger |
| from ..builder import BACKBONES |
| from .revcol_function import ReverseFunction |
| from mmcv.cnn import constant_init, trunc_normal_init |
| from mmcv.runner import BaseModule, _load_checkpoint |
| from torch.utils.checkpoint import checkpoint |
|
|
| class Fusion(nn.Module): |
    def __init__(self, level, channels, first_col) -> None:
        super().__init__()

        self.level = level
        self.first_col = first_col
        # Levels 1-3 downsample the features coming from the level below;
        # level 0 takes the stem output unchanged.
        if level in [1, 2, 3]:
            self.down = nn.Sequential(
                nn.Conv2d(channels[level - 1], channels[level], kernel_size=2, stride=2),
                LayerNorm(channels[level], eps=1e-6, data_format="channels_first"),
            )
        else:
            self.down = nn.Identity()
        # Columns after the first also fuse an upsampled feature from the
        # level above; level 3 has no level above it.
        if not first_col:
            self.up = UpSampleConvnext(1, channels[level + 1], channels[level]) if level in [0, 1, 2] else nn.Identity()

    def forward(self, *args):
        c_down, c_up = args

        # The first column has no previous column to fuse from above.
        if self.first_col:
            x = self.down(c_down)
            return x

        if self.level == 3:
            x = self.down(c_down)
        else:
            x = self.up(c_up) + self.down(c_down)
        return x
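
    # Shape sketch (illustrative; assumes UpSampleConvnext(1, ...) performs a
    # 2x upsample). With channels=[32, 64, 96, 128], level=1, first_col=False:
    #
    #   fusion = Fusion(level=1, channels=[32, 64, 96, 128], first_col=False)
    #   out = fusion(torch.randn(2, 32, 56, 56),   # c_down from level 0
    #                torch.randn(2, 96, 14, 14))   # c_up from level 2
    #   # both paths map to (2, 64, 28, 28) before the addition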
|
|
class Level(nn.Module):
    def __init__(self, level, channels, layers, kernel_size, first_col, dp_rate=None) -> None:
        super().__init__()
        # dp_rate is a flat list with one stochastic-depth rate per block
        # across all levels; default to zeros when none is given.
        if dp_rate is None:
            dp_rate = [0.0] * sum(layers)
        countlayer = sum(layers[:level])
        expansion = 4
        self.fusion = Fusion(level, channels, first_col)
        modules = [
            ConvNextBlock(channels[level], expansion * channels[level], channels[level],
                          kernel_size=kernel_size, layer_scale_init_value=1e-6,
                          drop_path=dp_rate[countlayer + i])
            for i in range(layers[level])
        ]
        self.blocks = nn.Sequential(*modules)

    def forward(self, *args):
        x = self.fusion(*args)
        x = self.blocks(x)
        return x
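
    # Usage sketch (hypothetical values): a level-2 stage in the first column
    # downsamples from level 1 and then applies layers[2] ConvNext blocks.
    #
    #   lvl = Level(2, [32, 64, 96, 128], [2, 3, 6, 3], kernel_size=3,
    #               first_col=True)
    #   y = lvl(torch.randn(2, 64, 28, 28), None)   # -> (2, 96, 14, 14)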
|
|
| class SubNet(nn.Module): |
| def __init__(self, channels, layers, kernel_size, first_col, dp_rates, save_memory) -> None: |
| super().__init__() |
        shortcut_scale_init_value = 0.5
        self.save_memory = save_memory
        # Learnable per-level scales on the reversible shortcut; forward()
        # clamps them away from zero via _clamp_abs.
        self.alpha0 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[0], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha1 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[1], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha2 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[2], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None
        self.alpha3 = nn.Parameter(shortcut_scale_init_value * torch.ones((1, channels[3], 1, 1)),
                                   requires_grad=True) if shortcut_scale_init_value > 0 else None

        self.level0 = Level(0, channels, layers, kernel_size, first_col, dp_rates)
        self.level1 = Level(1, channels, layers, kernel_size, first_col, dp_rates)
        self.level2 = Level(2, channels, layers, kernel_size, first_col, dp_rates)
        self.level3 = Level(3, channels, layers, kernel_size, first_col, dp_rates)
|
|
    def _forward_nonreverse(self, *args):
        x, c0, c1, c2, c3 = args

        # Column update rule: c_i <- alpha_i * c_i + level_i(from_below, from_above).
        c0 = self.alpha0 * c0 + self.level0(x, c1)
        c1 = self.alpha1 * c1 + self.level1(c0, c2)
        c2 = self.alpha2 * c2 + self.level2(c1, c3)
        c3 = self.alpha3 * c3 + self.level3(c2, None)

        return c0, c1, c2, c3
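
    # Reversibility note (a sketch of the idea; ReverseFunction's internals
    # live in revcol_function): since c_i' = alpha_i * c_i + F_i(...), the
    # previous state is recoverable as c_i = (c_i' - F_i(...)) / alpha_i,
    # which lets activations be recomputed instead of stored.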
|
|
    def _forward_reverse(self, *args):
        # Memory-saving path: ReverseFunction avoids caching intermediate
        # activations and reconstructs them during the backward pass.
        local_funs = [self.level0, self.level1, self.level2, self.level3]
        alpha = [self.alpha0, self.alpha1, self.alpha2, self.alpha3]
        _, c0, c1, c2, c3 = ReverseFunction.apply(
            local_funs, alpha, *args)

        return c0, c1, c2, c3
|
|
    def forward(self, *args):
        # Keep |alpha| >= 1e-3 so the reversible computation stays
        # numerically stable.
        self._clamp_abs(self.alpha0.data, 1e-3)
        self._clamp_abs(self.alpha1.data, 1e-3)
        self._clamp_abs(self.alpha2.data, 1e-3)
        self._clamp_abs(self.alpha3.data, 1e-3)

        if self.save_memory:
            return self._forward_reverse(*args)
        else:
            return self._forward_nonreverse(*args)
|
|
    def _clamp_abs(self, data, value):
        # In-place: enforce |data| >= value while preserving each element's sign.
        with torch.no_grad():
            sign = data.sign()
            data.abs_().clamp_(value)
            data *= sign
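
    # Behaviour sketch: applied to torch.tensor([0.5, -1e-5, -2.0]) with
    # value=1e-3, this leaves 0.5 and -2.0 unchanged and pushes -1e-5 out to
    # -1e-3, so no alpha can collapse to zero.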
|
|
| @BACKBONES.register_module() |
| class RevCol(BaseModule): |
    def __init__(self, channels=[32, 64, 96, 128], layers=[2, 3, 6, 3], num_subnet=5,
                 kernel_size=3, num_classes=1000, drop_path=0.0, save_memory=True,
                 single_head=True, out_indices=[0, 1, 2, 3], init_cfg=None) -> None:
        super().__init__(init_cfg)
        self.num_subnet = num_subnet
        self.single_head = single_head
        self.out_indices = out_indices
        self.init_cfg = init_cfg
|
|
        # Patchify stem: non-overlapping 4x4 convolution, as in ConvNeXt.
        self.stem = nn.Sequential(
            nn.Conv2d(3, channels[0], kernel_size=4, stride=4),
            LayerNorm(channels[0], eps=1e-6, data_format="channels_first")
        )

        # One stochastic-depth rate per block, increasing linearly from 0 to
        # drop_path; the same schedule is shared by every sub-network.
        dp_rate = [x.item() for x in torch.linspace(0, drop_path, sum(layers))]
        for i in range(num_subnet):
            first_col = (i == 0)
            self.add_module(f'subnet{i}', SubNet(
                channels, layers, kernel_size, first_col, dp_rates=dp_rate, save_memory=save_memory))
|
|
    def init_weights(self):
        logger = get_root_logger()
        if self.init_cfg is None:
            logger.warning(f'No pre-trained weights for '
                           f'{self.__class__.__name__}, '
                           f'training starts from scratch')
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, 1.0)
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specifying `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            ckpt = _load_checkpoint(
                self.init_cfg.checkpoint, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                state_dict = ckpt['model']
            else:
                state_dict = ckpt

            # Strip the 'module.' prefix left by DataParallel/DDP checkpoints.
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in state_dict.items()}

            self.load_state_dict(state_dict, strict=False)
| |
|
|
    def forward(self, x):
        x = self.stem(x)
        # Column features start at zero and are refined by each sub-network.
        c0, c1, c2, c3 = 0, 0, 0, 0
        for i in range(self.num_subnet):
            c0, c1, c2, c3 = getattr(self, f'subnet{i}')(x, c0, c1, c2, c3)
        return c0, c1, c2, c3
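
    # End-to-end sketch (illustrative, outside an mmdet pipeline; shapes
    # assume the default channels and the 4x/8x/16x/32x pyramid built above):
    #
    #   model = RevCol(save_memory=False)
    #   c0, c1, c2, c3 = model(torch.randn(2, 3, 224, 224))
    #   # c0: (2, 32, 56, 56)   c1: (2, 64, 28, 28)
    #   # c2: (2, 96, 14, 14)   c3: (2, 128, 7, 7)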
| |
    def cal_dp_rate(self, depth, num_subnet, drop_path):
        # Alternative 2-D stochastic-depth schedule (currently unused above):
        # the rate grows with the distance of (block, subnet) from the origin
        # and is rescaled so the last block of the last sub-network gets
        # exactly drop_path.
        dp = np.zeros((depth, num_subnet))
        dp[:, 0] = np.linspace(0, depth - 1, depth)
        dp[0, :] = np.linspace(0, num_subnet - 1, num_subnet)
        for i in range(1, depth):
            for j in range(1, num_subnet):
                dp[i][j] = min(dp[i][j - 1], dp[i - 1][j]) + 1
        dp_matrix = dp * drop_path / dp[-1][-1]
        return dp_matrix
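
    # Worked example: cal_dp_rate(depth=3, num_subnet=3, drop_path=0.2)
    # first builds
    #   [[0, 1, 2],
    #    [1, 2, 3],
    #    [2, 3, 4]]
    # and then scales it by 0.2 / 4 = 0.05, so dp_matrix[-1][-1] == 0.2.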
|
|