| | import torch.nn as nn
|
| |
|
| | from .common import ResBlock, default_conv
|
| |
|
def encoder(in_channels, n_feats):
    """Build the RGB / IR feature encoder.

    Three 5x5 convolutions (padding 2) widen the channels
    ``in_channels -> n_feats -> 2*n_feats -> 3*n_feats``. The first conv
    uses stride 1; the second and third use stride 2, so an ``H x W`` input
    comes out at ``ceil(H/4) x ceil(W/4)``.

    NOTE(review): no activation is placed between the convolutions, so the
    stack composes linearly -- confirm this is intended.

    Args:
        in_channels: number of channels of the input image.
        n_feats: base feature width; the output has ``3 * n_feats`` channels.

    Returns:
        ``nn.Sequential`` holding the three ``nn.Conv2d`` layers in order.
    """
    widths = [in_channels, n_feats, 2 * n_feats, 3 * n_feats]
    strides = [1, 2, 2]
    layers = [
        nn.Conv2d(c_in, c_out, 5, stride=step, padding=2)
        for c_in, c_out, step in zip(widths[:-1], widths[1:], strides)
    ]
    return nn.Sequential(*layers)
|
| |
|
def decoder(out_channels, n_feats):
    """Build the RGB / IR / Depth decoder.

    Two 3x3 transposed convolutions (stride 2, padding 1, output_padding 1)
    each exactly double the spatial size while narrowing the channels
    ``3*n_feats -> 2*n_feats -> n_feats``. A final 5x5 convolution
    (stride 1, padding 2) then projects to ``out_channels`` at the same
    resolution.

    Args:
        out_channels: number of channels of the decoded output.
        n_feats: base feature width; the input must have ``3 * n_feats``
            channels.

    Returns:
        ``nn.Sequential`` with the two ``nn.ConvTranspose2d`` layers followed
        by one ``nn.Conv2d``.
    """
    layers = []
    for mult in (3, 2):
        layers.append(
            nn.ConvTranspose2d(
                mult * n_feats,
                (mult - 1) * n_feats,
                3,
                stride=2,
                padding=1,
                output_padding=1,
            )
        )
    layers.append(nn.Conv2d(n_feats, out_channels, 5, stride=1, padding=2))
    return nn.Sequential(*layers)
|
| |
|
| |
|
def ResNet(n_feats, kernel_size, n_blocks, in_channels=None, out_channels=None):
    """Build a plain sequential stack of residual blocks.

    Args:
        n_feats: number of feature channels carried through the residual
            blocks.
        kernel_size: kernel size passed to ``default_conv`` for the optional
            head/tail convolutions. It is NOT forwarded to the residual
            blocks (see note in the body).
        n_blocks: number of ``ResBlock`` modules to stack.
        in_channels: if not None, a ``default_conv`` mapping
            ``in_channels -> n_feats`` is prepended.
        out_channels: if not None, a ``default_conv`` mapping
            ``n_feats -> out_channels`` is appended.

    Returns:
        ``nn.Sequential`` containing, in order: the optional head conv,
        ``n_blocks`` independent ``ResBlock`` modules, and the optional
        tail conv.
    """
    m = []

    if in_channels is not None:
        m.append(default_conv(in_channels, n_feats, kernel_size))

    # Construct a fresh ResBlock on every iteration. The previous
    # ``[ResBlock(n_feats, 3)] * n_blocks`` repeated ONE module object
    # n_blocks times, so all "blocks" silently shared a single set of weights.
    # NOTE(review): the second ResBlock argument is hard-coded to 3 rather
    # than ``kernel_size`` -- presumably ResBlock's kernel size; confirm intended.
    m.extend(ResBlock(n_feats, 3) for _ in range(n_blocks))

    if out_channels is not None:
        m.append(default_conv(n_feats, out_channels, kernel_size))

    return nn.Sequential(*m)
|
| |
|
| |
|