| | import pytest |
| | import torch |
| | from mmcv.cnn import ConvModule |
| | from mmcv.utils.parrots_wrapper import _BatchNorm |
| | from torch import nn |
| |
|
| | from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule, |
| | InterpConv, UNet, UpConvBlock) |
| |
|
| |
|
def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state.

    Args:
        modules (Iterable[nn.Module]): Modules to inspect, e.g. the result of
            ``model.modules()``. Anything that is not a ``_BatchNorm`` is
            ignored.
        train_state (bool): Expected value of each norm layer's ``training``
            flag.

    Returns:
        bool: ``True`` when no ``_BatchNorm`` layer disagrees with
        ``train_state``. This is vacuously ``True`` if there are none.
    """
    norm_layers = (layer for layer in modules if isinstance(layer, _BatchNorm))
    return all(layer.training == train_state for layer in norm_layers)
| |
|
| |
|
def test_unet_basic_conv_block():
    """Check BasicConvBlock's rejected options, forward shapes and layout."""
    # DCN and plugins are not supported: building the block must raise.
    rejected_kwargs = [
        dict(dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False)),
        dict(plugins=[
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]),
        dict(plugins=[
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]),
    ]
    for kwargs in rejected_kwargs:
        with pytest.raises(AssertionError):
            BasicConvBlock(64, 64, **kwargs)

    # The with_cp flag is kept on the block; the output shape is unchanged
    # either way.
    for use_cp in (True, False):
        block = BasicConvBlock(16, 16, with_cp=use_cp)
        assert bool(block.with_cp) == use_cp
        feats = torch.randn(1, 16, 64, 64, requires_grad=use_cp)
        assert block(feats).shape == torch.Size([1, 16, 64, 64])

    # stride=2 halves the spatial resolution.
    block = BasicConvBlock(16, 16, stride=2)
    feats = torch.randn(1, 16, 64, 64)
    assert block(feats).shape == torch.Size([1, 16, 32, 32])

    # With dilation=3, only the convs after the first are dilated, and the
    # padding of every conv equals its dilation.
    block = BasicConvBlock(16, 64, num_convs=3, dilation=3)
    expected_layout = [(16, 1), (64, 3), (64, 3)]  # (in_channels, dilation)
    for idx, (in_ch, dil) in enumerate(expected_layout):
        conv = block.convs[idx].conv
        assert conv.in_channels == in_ch
        assert conv.out_channels == 64
        assert conv.kernel_size == (3, 3)
        assert conv.dilation == (dil, dil)
        assert conv.padding == (dil, dil)
| |
|
| |
|
def test_deconv_module():
    """Check DeconvModule's argument validation and upsampled output shapes."""
    # These (kernel_size, scale_factor) pairs must be rejected. Presumably the
    # rule is kernel_size >= scale_factor with an even difference; confirm
    # against DeconvModule.
    for kernel, scale in ((1, 2), (3, 2), (5, 4)):
        with pytest.raises(AssertionError):
            DeconvModule(64, 32, kernel_size=kernel, scale_factor=scale)

    # Default settings double H and W, with or without checkpointing.
    for use_cp in (True, False):
        block = DeconvModule(64, 32, with_cp=use_cp)
        assert bool(block.with_cp) == use_cp
        feats = torch.randn(1, 64, 128, 128, requires_grad=use_cp)
        assert block(feats).shape == torch.Size([1, 32, 256, 256])

    # Every listed kernel_size for a given scale_factor upsamples a 64x64
    # input to (64 * scale_factor) per side.
    for scale, kernels in ((2, (2, 6)), (4, (4, 6))):
        feats = torch.randn(1, 64, 64, 64)
        out_side = 64 * scale
        expected = torch.Size([1, 32, out_side, out_side])
        for kernel in kernels:
            block = DeconvModule(
                64, 32, kernel_size=kernel, scale_factor=scale)
            assert block(feats).shape == expected
| |
|
| |
|
def test_interp_conv():
    """Check InterpConv's checkpointing, layer order and upsample config."""
    out_shape = torch.Size([1, 32, 256, 256])

    # The with_cp flag is kept on the block; the output is 2x upsampled
    # either way.
    for use_cp in (True, False):
        block = InterpConv(64, 32, with_cp=use_cp)
        assert bool(block.with_cp) == use_cp
        feats = torch.randn(1, 64, 128, 128, requires_grad=use_cp)
        assert block(feats).shape == out_shape

    # conv_first decides whether the ConvModule comes before or after the
    # nn.Upsample inside interp_upsample.
    layer_orders = {
        False: (nn.Upsample, ConvModule),
        True: (ConvModule, nn.Upsample),
    }
    for conv_first, (first_cls, second_cls) in layer_orders.items():
        block = InterpConv(64, 32, conv_first=conv_first)
        feats = torch.randn(1, 64, 128, 128)
        result = block(feats)
        assert isinstance(block.interp_upsample[0], first_cls)
        assert isinstance(block.interp_upsample[1], second_cls)
        assert result.shape == out_shape

    # The `upsampe_cfg` keyword is spelled exactly as InterpConv accepts it.
    # Its mode must end up on the nn.Upsample layer.
    upsample_settings = (
        ('bilinear',
         dict(scale_factor=2, mode='bilinear', align_corners=False)),
        ('nearest', dict(scale_factor=2, mode='nearest')),
    )
    for mode, cfg in upsample_settings:
        block = InterpConv(64, 32, conv_first=False, upsampe_cfg=cfg)
        feats = torch.randn(1, 64, 128, 128)
        result = block(feats)
        assert isinstance(block.interp_upsample[0], nn.Upsample)
        assert isinstance(block.interp_upsample[1], ConvModule)
        assert result.shape == out_shape
        assert block.interp_upsample[0].mode == mode
| |
|
| |
|
def test_up_conv_block():
    """Check UpConvBlock's rejected options, upsample variants and layout."""
    # DCN and plugins are not supported: building the block must raise.
    rejected_kwargs = [
        dict(dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False)),
        dict(plugins=[
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]),
        dict(plugins=[
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]),
    ]
    for kwargs in rejected_kwargs:
        with pytest.raises(AssertionError):
            UpConvBlock(BasicConvBlock, 64, 32, 32, **kwargs)

    # Each case is (extra kwargs, side length of x, whether inputs require
    # grad). skip_x is always (1, 32, 256, 256), and so is the output.
    forward_cases = [
        # Gradient checkpointing.
        (dict(with_cp=True), 128, True),
        # InterpConv upsampling with its default settings.
        (dict(upsample_cfg=dict(type='InterpConv')), 128, False),
        # With upsample_cfg=None, x is given at skip_x's resolution.
        (dict(upsample_cfg=None), 256, False),
        # InterpConv with an explicit bilinear upsample config.
        (dict(
            upsample_cfg=dict(
                type='InterpConv',
                upsampe_cfg=dict(
                    scale_factor=2, mode='bilinear', align_corners=False))),
         128, False),
        # DeconvModule upsampling.
        (dict(
            upsample_cfg=dict(
                type='DeconvModule', kernel_size=4, scale_factor=2)),
         128, False),
    ]
    for kwargs, x_side, needs_grad in forward_cases:
        block = UpConvBlock(BasicConvBlock, 64, 32, 32, **kwargs)
        skip_x = torch.randn(1, 32, 256, 256, requires_grad=needs_grad)
        x = torch.randn(1, 64, x_side, x_side, requires_grad=needs_grad)
        assert block(skip_x, x).shape == torch.Size([1, 32, 256, 256])

    # A dilated 3-conv block with bilinear InterpConv upsampling.
    block = UpConvBlock(
        conv_block=BasicConvBlock,
        in_channels=64,
        skip_channels=32,
        out_channels=32,
        num_convs=3,
        dilation=3,
        upsample_cfg=dict(
            type='InterpConv',
            upsampe_cfg=dict(
                scale_factor=2, mode='bilinear', align_corners=False)))
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    assert block(skip_x, x).shape == torch.Size([1, 32, 256, 256])

    # (in_channels, dilation) of each conv in conv_block. Only the convs after
    # the first are dilated, and padding equals dilation.
    conv_layout = [(64, 1), (32, 3), (32, 3)]
    for idx, (in_ch, dil) in enumerate(conv_layout):
        conv = block.conv_block.convs[idx].conv
        assert conv.in_channels == in_ch
        assert conv.out_channels == 32
        assert conv.kernel_size == (3, 3)
        assert conv.dilation == (dil, dil)
        assert conv.padding == (dil, dil)

    # The ConvModule at interp_upsample[1] is a 1x1 conv, 64 -> 32 channels.
    up_conv = block.upsample.interp_upsample[1].conv
    assert up_conv.in_channels == 64
    assert up_conv.out_channels == 32
    assert up_conv.kernel_size == (1, 1)
    assert up_conv.dilation == (1, 1)
    assert up_conv.padding == (0, 0)
| |
|
| |
|
def _unet_cfg(num_stages=5, **overrides):
    """Build UNet kwargs for an RGB, base-64 net with uniform stage settings.

    Encoder-side tuples (``strides``, ``enc_num_convs``, ``enc_dilations``)
    get ``num_stages`` entries. Decoder-side tuples (``dec_num_convs``,
    ``downsamples``, ``dec_dilations``) get ``num_stages - 1`` entries.
    Every stage uses stride 1, two convs, dilation 1 and downsampling.

    Args:
        num_stages (int): Number of encoder stages.
        **overrides: UNet keyword arguments that replace the defaults above,
            e.g. a deliberately mis-sized tuple.

    Returns:
        dict: A fresh kwargs dict, ready for ``UNet(**cfg)``.
    """
    cfg = dict(
        in_channels=3,
        base_channels=64,
        num_stages=num_stages,
        strides=(1, ) * num_stages,
        enc_num_convs=(2, ) * num_stages,
        dec_num_convs=(2, ) * (num_stages - 1),
        downsamples=(True, ) * (num_stages - 1),
        enc_dilations=(1, ) * num_stages,
        dec_dilations=(1, ) * (num_stages - 1))
    cfg.update(overrides)
    return cfg


def _check_unet_outputs(unet, spatial_sizes):
    """Run ``unet`` on a (2, 3, 128, 128) batch and check every output shape.

    Args:
        unet (UNet): A 5-stage, base-64 UNet.
        spatial_sizes (Sequence[int]): Expected height (== width) of each
            output, in the order UNet returns them. That order runs from the
            1024-channel output down to the 64-channel output.
    """
    channels = (1024, 512, 256, 128, 64)
    x_outs = unet(torch.randn(2, 3, 128, 128))
    assert len(x_outs) == len(channels)
    for idx, (ch, side) in enumerate(zip(channels, spatial_sizes)):
        assert x_outs[idx].shape == torch.Size([2, ch, side, side])


def test_unet():
    """Check UNet config validation, norm_eval handling and output shapes."""
    # dcn and plugins are not supported: construction must raise.
    rejected_kwargs = [
        dict(dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False)),
        dict(plugins=[
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]),
        dict(plugins=[
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]),
    ]
    for kwargs in rejected_kwargs:
        with pytest.raises(AssertionError):
            UNet(3, 64, 5, **kwargs)

    # Each (config, input side) pair must raise AssertionError, either when
    # the net is built or when it runs forward.
    invalid_cases = [
        # Valid configs fed a 65x65 input. Presumably rejected because 65 is
        # not a multiple of the total downsample rate.
        (_unet_cfg(num_stages=4), 65),
        (_unet_cfg(), 65),
        (_unet_cfg(downsamples=(True, True, True, False)), 65),
        (_unet_cfg(
            strides=(1, 2, 2, 2, 1),
            downsamples=(True, True, True, False)), 65),
        (_unet_cfg(num_stages=6), 65),
        # One per-stage tuple has the wrong length for num_stages=5.
        (_unet_cfg(strides=(1, 1, 1, 1)), 64),
        (_unet_cfg(enc_num_convs=(2, 2, 2, 2)), 64),
        (_unet_cfg(dec_num_convs=(2, 2, 2, 2, 2)), 64),
        (_unet_cfg(downsamples=(True, True, True)), 64),
        (_unet_cfg(enc_dilations=(1, 1, 1, 1)), 64),
        (_unet_cfg(dec_dilations=(1, 1, 1, 1, 1)), 64),
    ]
    for cfg, side in invalid_cases:
        x = torch.randn(2, 3, side, side)
        # Only the code under test runs inside pytest.raises.
        with pytest.raises(AssertionError):
            UNet(**cfg)(x)

    # norm_eval=True must leave BN layers in eval mode after train().
    # norm_eval=False lets them follow train().
    for norm_eval in (True, False):
        unet = UNet(**_unet_cfg(norm_eval=norm_eval))
        unet.train()
        assert check_norm_state(unet.modules(), not norm_eval)

    # (strides, downsamples, expected output sides) for a 128x128 input.
    # Each distinct configuration is listed once.
    forward_cases = [
        ((1, 1, 1, 1, 1), (True, True, True, True), (8, 16, 32, 64, 128)),
        ((1, 1, 1, 1, 1), (True, True, True, False), (16, 16, 32, 64, 128)),
        ((1, 2, 2, 2, 1), (True, True, True, False), (16, 16, 32, 64, 128)),
        ((1, 1, 1, 1, 1), (True, True, False, False), (32, 32, 32, 64, 128)),
        ((1, 2, 2, 1, 1), (True, True, False, False), (32, 32, 32, 64, 128)),
        ((1, 1, 1, 1, 1), (True, False, False, False), (64, 64, 64, 64, 128)),
        ((1, 1, 1, 1, 1), (False, False, False, False), (128, ) * 5),
        ((1, 2, 2, 1, 1), (True, True, True, True), (8, 16, 32, 64, 128)),
        ((1, 2, 2, 1, 1), (True, True, True, False), (16, 16, 32, 64, 128)),
    ]
    for strides, downsamples, sides in forward_cases:
        unet = UNet(**_unet_cfg(strides=strides, downsamples=downsamples))
        _check_unet_outputs(unet, sides)

    # Explicit weight initialisation must not affect the forward shapes.
    unet = UNet(**_unet_cfg(
        strides=(1, 2, 2, 1, 1), downsamples=(True, True, False, False)))
    unet.init_weights(pretrained=None)
    _check_unet_outputs(unet, (32, 32, 32, 64, 128))
| |
|