# CR-Net: models/networks/dstream.py
# (Upload metadata: author datnguyentien204, "Upload 147 files", commit 0f52c9d.)
import torch
import torch.nn.functional as F
import torch.nn as nn
from models.networks.base_network import BaseNetwork
from models.networks.DenseArchitecture import _DenseBlock, _Transition
from collections import OrderedDict
import torch.nn.utils.spectral_norm as spectral_norm
# Content/style stream.
# The content and style streams are symmetric and share the same network structure.
# Each extracts corresponding feature representations at several different levels.
class Stream(BaseNetwork):
    """DenseNet-style multi-level feature extractor (content or style stream).

    The network is a spectral-normalised 3x3 conv stem followed by five dense
    stages. Each of the first four stages ends in a transition layer. ``forward``
    returns the stem output plus the output of every stage.

    Args:
        opt: Parsed options object. It is stored on ``self.opt`` but not read
            anywhere in this class.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        growth_rate = 32
        bn_size = 4
        drop_rate = 0.0
        block_config = [2, 4, 8, 16, 16]
        num_init_features = 64
        self.features = nn.ModuleList()
        # Stem: 3x3 conv (stride 1, padding 1 -> spatial size preserved) + ReLU,
        # then a stride-2 max-pool. No normalisation layer is used here
        # (an InstanceNorm2d "norm0" was disabled in the original code).
        self.block0 = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", spectral_norm(nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False))),
                    ("relu0", nn.ReLU(inplace=True)),
                    ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )
        # Dense stages. num_features tracks the channel count:
        # - a dense block adds num_layers * growth_rate channels;
        # - each transition (all but the last stage) halves the count.
        # Resulting stage outputs: 64, 96, 176, 344, 856 channels.
        # NOTE(review): this assumes _DenseBlock/_Transition honour this bookkeeping.
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=False,
            )
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module("transition%d" % (i + 1), trans)
                num_features = num_features // 2
        # Weight initialisation following torchvision's DenseNet.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # spectral_norm registers the trainable weight as ``weight_orig``.
                # A forward pre-hook then recomputes ``weight`` from it on every
                # call, so initialising ``weight`` in place would be discarded.
                nn.init.kaiming_normal_(getattr(m, "weight_orig", m.weight))
            elif isinstance(m, nn.BatchNorm2d):
                # weight/bias are None when affine=False.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def down(self, input):
        """Downsample ``input`` by a factor of 2 with F.interpolate's default (nearest) mode."""
        return F.interpolate(input, scale_factor=0.5)

    def forward(self, input):
        """Extract features at every level of the stream.

        Args:
            input: Image batch; presumably (N, 3, H, W), since ``conv0`` takes
                3 input channels.

        Returns:
            List ``[x0, x1, x2, x3, x4, x5]`` containing:
            - x0: the stem output (64 channels, roughly H/2 x W/2 after the
              stride-2 max-pool);
            - x1..x4: the outputs of dense stages 1-4, each taken after its
              transition layer;
            - x5: the output of dense stage 5, downsampled 2x by ``down``.
        """
        x0 = self.block0(input)
        x1 = self.features.denseblock1(x0)
        x1 = self.features.transition1(x1)
        x2 = self.features.denseblock2(x1)
        x2 = self.features.transition2(x2)
        x3 = self.features.denseblock3(x2)
        x3 = self.features.transition3(x3)
        x4 = self.features.denseblock4(x3)
        x4 = self.features.transition4(x4)
        x5 = self.features.denseblock5(x4)
        x5 = self.down(x5)
        return [x0, x1, x2, x3, x4, x5]
# Smoke test: run a random image through the stream and print feature shapes.
if __name__ == "__main__":
    # torch is already imported at module level.
    from options.train_options import TrainOptions

    # Parse command-line options (Stream only stores them on self.opt).
    opt = TrainOptions().parse()
    stream = Stream(opt=opt)
    dummy = torch.randn(1, 3, 512, 512)
    # Gradients are not needed just to inspect output shapes.
    with torch.no_grad():
        x0, x1, x2, x3, x4, x5 = stream(dummy)
    print(x0.shape, x1.shape, x2.shape, x3.shape, x4.shape, x5.shape)