# CR-Net: models/networks/style_encoder.py
import torch
import torch.nn as nn
from models.networks.base_network import BaseNetwork
import functools
class SimpleStyleEncoder(BaseNetwork):
    """Convolutional style encoder: image batch -> compact latent style code.

    A stack of stride-2 4x4 convolutions downsamples the input, a stride-1
    3x3 convolution refines the features, adaptive average pooling collapses
    the spatial dimensions to 1x1, and a final linear layer projects to the
    latent space of size ``opt.style_enc_latent_dim``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        # No encoder-specific CLI options; hand the parser back untouched.
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt

        kernel_size = 4
        padding = 1

        # Select the normalization layer. BatchNorm carries its own affine
        # bias, so the preceding conv can drop its bias term in that case;
        # every other configuration keeps the conv bias.
        norm_layer = None
        conv_bias = True
        if opt.norm_SE == 'instance':
            norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        elif opt.norm_SE == 'batch':
            norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
            conv_bias = False
        elif not (opt.norm_SE == 'none' or opt.norm_SE is None):
            raise NotImplementedError(f"Normalization layer {opt.norm_SE} is not implemented for SimpleStyleEncoder")

        def conv_stage(in_ch, out_ch, k, stride, pad):
            # One stage: Conv2d -> (optional norm) -> ReLU, as a list of modules.
            stage = [nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=stride, padding=pad, bias=conv_bias)]
            if norm_layer is not None:
                stage.append(norm_layer(out_ch))
            stage.append(nn.ReLU(True))
            return stage

        # Initial stride-2 stage: input channels -> base feature width.
        layers = conv_stage(opt.style_enc_nc, opt.style_enc_nef, kernel_size, 2, padding)

        # Further stride-2 stages, doubling the channel multiplier up to 8x.
        mult = 1
        for stage_idx in range(1, opt.style_enc_n_downsampling):
            prev_mult, mult = mult, min(2 ** stage_idx, 8)
            layers += conv_stage(opt.style_enc_nef * prev_mult, opt.style_enc_nef * mult,
                                 kernel_size, 2, padding)

        final_channels = opt.style_enc_nef * mult
        # Stride-1 refinement conv, then pool every spatial map down to 1x1,
        # yielding features of shape [B, final_channels, 1, 1].
        layers += conv_stage(final_channels, final_channels, 3, 1, 1)
        layers.append(nn.AdaptiveAvgPool2d(1))

        self.cnn_body = nn.Sequential(*layers)
        self.fc_out = nn.Linear(final_channels, opt.style_enc_latent_dim)

    def forward(self, input_image):
        """Encode *input_image* into latent codes of shape [B, latent_dim]."""
        pooled = self.cnn_body(input_image)
        flat = pooled.flatten(1)  # [B, C, 1, 1] -> [B, C]
        return self.fc_out(flat)