| | import torch.nn as nn |
| |
|
# Encoder layout read by make_block: the first entry is the input channel
# count, each int is the width of a 3x3 conv (+ ReLU), and "M" is a
# 2x2/stride-2 max-pool.
vgg19_cfg = [
    3,
    64, 64, "M",
    128, 128, "M",
    256, 256, 256, 256, "M",
    512, 512, 512, 512, "M",
    512, 512, 512, 512, "M",
]
# Decoder layout read by make_block: the first entry is the input channel
# count, each int is the width of a 3x3 conv (+ ReLU), and "U" is a 2x
# nearest-neighbour upsample.
decoder_cfg = [
    512,
    256, "U",
    256, 256, 256, 128, "U",
    128, 64, "U",
    64, 3,
]
| |
|
def vgg19(weights=None):
    """
    Construct the VGG-19 feature network, optionally loading weights.

    The returned network is a 1x1 ``nn.Conv2d(3, 3)`` followed by every
    layer that ``make_block`` builds from ``vgg19_cfg``, all flattened into
    a single ``nn.Sequential``.

    Args:
        weights (dict): vgg19 pretrained state dict; loading is skipped when
            it is falsy (e.g. None or empty).

    Return:
        layers (nn.Sequential): vgg19 layers
    """
    # Create the conv stack before the leading 1x1 conv; keeping this order
    # keeps random parameter initialisation reproducible under a fixed seed.
    feature_stack = make_block(vgg19_cfg)
    net = nn.Sequential(nn.Conv2d(3, 3, kernel_size=1), *feature_stack.children())

    if not weights:
        return net

    net.load_state_dict(weights)
    return net
| |
|
| |
|
def decoder(weights=None):
    """
    Construct the decoder network, optionally loading weights.

    The layers come from ``make_block(decoder_cfg)`` with the final module
    removed. ``make_block`` appends a ReLU after every convolution, so the
    removed module is the ReLU following the last (3-channel) conv. The
    network's output is therefore that conv's unrectified output.

    Args:
        weights (dict): decoder pretrained state dict; loading is skipped
            when it is falsy (e.g. None or empty).

    Return:
        layers (nn.Sequential): decoder layers
    """
    all_layers = list(make_block(decoder_cfg).children())
    without_last = all_layers[:-1]  # drop the trailing ReLU
    net = nn.Sequential(*without_last)

    if weights:
        net.load_state_dict(weights)

    return net
| |
|
| |
|
def make_block(config):
    """
    Helper function for building blocks of convolutional layers.

    Args:
        config (list): Layer specification. The first element is the number
            of input channels; every following element is one of:
            "M" - 2x2 max pooling layer with stride 2.
            "U" - 2x nearest-neighbour upsampling layer.
            i (int) - 3x3 convolution (padding 1) with i output filters,
                followed by an in-place ReLU activation.

    Return:
        layers (nn.Sequential): block layers, in the order given by config

    Raises:
        ValueError: if config is empty or contains an unsupported entry.
    """
    if not config:
        raise ValueError("config must start with the input channel count")

    layers = []
    in_channels = config[0]

    for spec in config[1:]:
        if spec == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
        elif spec == "U":
            layers.append(nn.Upsample(scale_factor=2, mode="nearest"))
        elif isinstance(spec, int) and not isinstance(spec, bool):
            layers.append(nn.Conv2d(in_channels, spec, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            # The next convolution consumes this one's output channels.
            in_channels = spec
        else:
            # Explicit raise rather than `assert`: asserts are stripped under
            # `python -O`, which would let bad specs through as conv widths.
            raise ValueError(f"unsupported layer spec in config: {spec!r}")

    return nn.Sequential(*layers)
| |
|