# CR-Net: models/networks/rafael_generator.py
# (Hugging Face upload metadata: uploaded by datnguyentien204, "Upload 147 files", commit 0f52c9d verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import math
import random
from models.networks.base_network import BaseNetwork
from models.networks.rafael_generator_module.rafael import RAFAEL as RAFAEL_Content_Encoder
from .time_utils import get_day_night_weights
class AdaINLayer(nn.Module):
    """Adaptive Instance Normalization (AdaIN).

    Normalizes the content features per channel/instance, then re-scales and
    shifts them with style-derived ``gamma``/``beta`` tensors.
    """

    def __init__(self):
        super().__init__()
        # Kept for backward compatibility with the previous implementation,
        # which lazily cached an nn.InstanceNorm2d here. With affine=False and
        # track_running_stats=False instance norm is stateless, so the
        # functional form is used instead (no module cache, no .to(device)).
        self.instance_norm = None

    def forward(self, content_feat, style_gamma, style_beta):
        """Return ``style_gamma * IN(content_feat) + style_beta``.

        Pass-through guards: if any argument is None, or the feature map is
        empty (zero elements, which includes zero channels), ``content_feat``
        is returned unchanged.
        """
        if content_feat is None or style_gamma is None or style_beta is None:
            return content_feat
        if content_feat.numel() == 0 or content_feat.size(1) == 0:
            return content_feat
        # F.instance_norm with defaults (eps=1e-5, no running stats, no affine)
        # matches nn.InstanceNorm2d(C, affine=False, track_running_stats=False).
        normalized_content = F.instance_norm(content_feat)
        return style_gamma * normalized_content + style_beta
class AdaINResBlockWithLatent(nn.Module):
    """Residual block with two conv layers, each preceded by AdaIN modulation
    driven by a style latent ``z_style``.

    A per-layer linear head maps the latent to ``2 * channels`` values, which
    are split into per-channel gamma/beta for the AdaIN layer.
    """

    def __init__(self, channels, latent_dim, use_reflection_pad=True):
        super().__init__()
        self.use_reflection_pad = use_reflection_pad
        # NOTE: module creation order is deliberately kept stable — it fixes
        # both the state_dict key order and seeded parameter initialization.
        self.mlp_style1 = nn.Linear(latent_dim, channels * 2)
        self.pad1 = nn.ReflectionPad2d(1) if use_reflection_pad else nn.Identity()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=0)
        self.relu1 = nn.ReLU()
        self.mlp_style2 = nn.Linear(latent_dim, channels * 2)
        self.pad2 = nn.ReflectionPad2d(1) if use_reflection_pad else nn.Identity()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=0)
        self.adain1 = AdaINLayer()
        self.adain2 = AdaINLayer()

    def _modulate(self, feat, style_mlp, adain, z_style):
        # Map the latent to 2*C values, split into gamma/beta, apply AdaIN.
        style_params = style_mlp(z_style)
        n_ch = feat.size(1)
        gamma = style_params[:, :n_ch].unsqueeze(-1).unsqueeze(-1)
        beta = style_params[:, n_ch:].unsqueeze(-1).unsqueeze(-1)
        return adain(feat, gamma, beta)

    def forward(self, x, z_style):
        residual = x
        out = self._modulate(x, self.mlp_style1, self.adain1, z_style)
        out = self.relu1(self.conv1(self.pad1(out)))
        out = self._modulate(out, self.mlp_style2, self.adain2, z_style)
        out = self.conv2(self.pad2(out))
        return F.relu(out + residual)
class AdaINSynthesisNetworkWithLatent(BaseNetwork):
    """Decoder that turns a sequence of content features into an RGB image,
    modulated at every residual block by a style latent (AdaIN), with
    optional StyleGAN-style mixing between two latents.

    Fixes vs. the previous version:
      * the upsampling path is walked in explicit 5-op stages instead of a
        manual index/while loop with an implicit group size;
      * style mixing is skipped when there is only one AdaIN block, where
        ``torch.randint(1, 1, ...)`` would raise (low >= high).
    """

    def __init__(self, opt, content_encoder_output_dim, initial_content_spatial_res, style_latent_dim):
        super().__init__()
        self.opt = opt
        current_channels = content_encoder_output_dim
        self.initial_conv = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(content_encoder_output_dim, current_channels, kernel_size=3, padding=0),
            nn.ReLU(),
        )
        self.res_blocks = nn.ModuleList(
            AdaINResBlockWithLatent(current_channels, style_latent_dim)
            for _ in range(opt.synth_num_res_blocks)
        )
        # Number of 2x upsamples: exact log2 of the spatial ratio when it is a
        # power of two, otherwise fall back to 3 (the historical default).
        target_h = opt.rafael_img_size
        current_h = initial_content_spatial_res[0]
        num_upsamples_needed = 3
        if current_h > 0 and target_h > current_h:
            ratio = target_h / current_h
            if ratio == int(ratio) and (int(ratio) & (int(ratio) - 1)) == 0:
                num_upsamples_needed = int(math.log2(ratio))
        upsample_ops = []
        for _ in range(num_upsamples_needed):
            # Halve channels until reaching opt.ngf, never going below it
            # (and never <= 0, even for degenerate opt.ngf values).
            out_channels = current_channels // 2 if current_channels > opt.ngf * 2 else opt.ngf
            if out_channels < opt.ngf and opt.ngf > 0:
                out_channels = opt.ngf
            if out_channels <= 0:
                out_channels = opt.ngf if opt.ngf > 0 else 3
            # Each stage is exactly 5 ops; forward() relies on this grouping.
            upsample_ops.extend([
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.ReflectionPad2d(1),
                nn.Conv2d(current_channels, out_channels, kernel_size=3, padding=0),
                nn.ReLU(),
                AdaINResBlockWithLatent(out_channels, style_latent_dim),
            ])
            current_channels = out_channels
        self.upsample_blocks = nn.ModuleList(upsample_ops)
        self.final_conv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(current_channels, 3, kernel_size=7, padding=0),
            nn.Tanh(),
        )

    def forward(self, content_features_seq, content_spatial_res_tuple, z_style1, z_style2=None, mixing_prob=0.0,
                arbitrary_input=False):
        """Decode ``content_features_seq`` (B, L, C) into an image.

        ``z_style2``/``mixing_prob`` enable style mixing: AdaIN blocks from a
        randomly chosen index onward use ``z_style2`` instead of ``z_style1``.
        ``arbitrary_input`` is accepted for interface compatibility but unused
        here.
        """
        B, L, C_enc = content_features_seq.shape
        H_feat, W_feat = content_spatial_res_tuple
        x = content_features_seq.permute(0, 2, 1).contiguous().view(B, C_enc, H_feat, W_feat)
        x = self.initial_conv(x)
        num_upsample_adain = sum(isinstance(op, AdaINResBlockWithLatent) for op in self.upsample_blocks)
        num_adain_blocks = len(self.res_blocks) + num_upsample_adain
        # Default: mixing point past the last block, i.e. z_style1 everywhere.
        mixing_point = num_adain_blocks + 1
        if (z_style2 is not None and mixing_prob > 0 and num_adain_blocks > 1
                and random.random() < mixing_prob):
            # Crossover index in [1, num_adain_blocks - 1]; guarded above so
            # torch.randint never sees low >= high.
            mixing_point = torch.randint(1, num_adain_blocks, (1,)).item()
        adain_idx = 0
        for block in self.res_blocks:
            z_current = z_style2 if adain_idx >= mixing_point else z_style1
            x = block(x, z_current)
            adain_idx += 1
        # Walk the upsampling path in the fixed 5-op stages built in __init__:
        # Upsample -> ReflectionPad -> Conv -> ReLU -> AdaIN res block.
        for start in range(0, len(self.upsample_blocks), 5):
            upsample, pad, conv, relu, res_block = self.upsample_blocks[start:start + 5]
            x = relu(conv(pad(upsample(x))))
            z_current = z_style2 if adain_idx >= mixing_point else z_style1
            x = res_block(x, z_current)
            adain_idx += 1
        return self.final_conv(x)
class RafaelGenerator(BaseNetwork):
    """Day/night image translation generator.

    Encodes a content image (concatenated with a light-direction map) with a
    RAFAEL transformer encoder, builds a style latent by blending two learned
    day/night latents with time-of-day weights, and decodes with the
    AdaIN-based synthesis network.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register generator-specific command-line options on ``parser``."""
        parser.add_argument('--rafael_img_size', type=int, default=256)
        parser.add_argument('--rafael_patch_size', type=int, default=2)
        parser.add_argument('--rafael_embed_dim', type=int, default=192)
        parser.add_argument('--rafael_depths', type=str, default='2,2,2')
        parser.add_argument('--rafael_nheads', type=str, default='3,6,12')
        parser.add_argument('--rafael_strip_widths', type=str, default='2,4,7')
        parser.add_argument('--rafael_use_checkpoint', action='store_true')
        parser.add_argument('--style_enc_latent_dim', type=int, default=256, help='Output latent z dimension')
        parser.add_argument('--synth_num_res_blocks', type=int, default=4)
        if is_train:
            # Only registered for training; forward() reads opt.style_mixing_prob
            # only when self.training is True.
            parser.add_argument('--style_mixing_prob', type=float, default=0.9)
        parser.add_argument('--light_map_channels', type=int, default=1,
                            help='Số kênh cho bản đồ hướng sáng (1 hoặc 2).')
        parser.add_argument('--add_phi_to_latent', action='store_true', help='Thêm (cos(phi), sin(phi)) vào z_style')
        return parser

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # Parse comma-separated per-stage hyperparameters.
        rafael_depths = [int(d) for d in opt.rafael_depths.split(',') if d.strip()]
        rafael_nheads = [int(h) for h in opt.rafael_nheads.split(',') if h.strip()]
        rafael_strip_widths_str = opt.rafael_strip_widths.split(',')
        if len(rafael_strip_widths_str) == 1 and rafael_strip_widths_str[0].strip():
            # A single value is broadcast to every encoder stage.
            rafael_strip_widths = [int(rafael_strip_widths_str[0])] * len(rafael_depths)
        else:
            rafael_strip_widths = [int(sw) for sw in rafael_strip_widths_str if sw.strip()]
        # RGB plus the light-direction map channels.
        content_encoder_in_chans = 3 + opt.light_map_channels
        self.rafael_content_encoder = RAFAEL_Content_Encoder(
            img_size=opt.rafael_img_size, patch_size=opt.rafael_patch_size,
            in_chans=content_encoder_in_chans,
            embed_dim=opt.rafael_embed_dim, depths=rafael_depths, nhead=rafael_nheads,
            strip_width=rafael_strip_widths, use_checkpoint=opt.rafael_use_checkpoint
        )
        # NOTE(review): assumes the encoder doubles channels and halves spatial
        # resolution once per stage after the first — TODO confirm against
        # rafael.py.
        num_downsamples_in_enc = len(rafael_depths) - 1
        content_encoder_output_dim = opt.rafael_embed_dim * (2 ** num_downsamples_in_enc)
        initial_spatial_res_h = opt.rafael_img_size // opt.rafael_patch_size
        initial_spatial_res_w = opt.rafael_img_size // opt.rafael_patch_size
        self.content_encoder_spatial_res = (
            initial_spatial_res_h // (2 ** num_downsamples_in_enc),
            initial_spatial_res_w // (2 ** num_downsamples_in_enc)
        )
        self.latent_dim = opt.style_enc_latent_dim
        # Optionally append (cos(phi), sin(phi)) to the style latent.
        synthesis_latent_dim = self.latent_dim + 2 if opt.add_phi_to_latent else self.latent_dim
        # Learned anchor latents for the two extreme styles.
        self.z_day = nn.Parameter(torch.randn(1, self.latent_dim))
        self.z_night = nn.Parameter(torch.randn(1, self.latent_dim))
        self.synthesis_network = AdaINSynthesisNetworkWithLatent(
            opt=opt,
            content_encoder_output_dim=content_encoder_output_dim,
            initial_content_spatial_res=self.content_encoder_spatial_res,
            style_latent_dim=synthesis_latent_dim
        )

    def forward(self, content_image_with_light, w_day, w_night, cos_phi, sin_phi, phi2=None, arbitrary_input=False):
        """Translate ``content_image_with_light`` to the style given by the
        day/night weights (and optionally the sun angle ``phi``).

        Returns ``(i_cs, generator_returns)`` where ``generator_returns`` also
        carries identity reconstructions (day/night) during training.
        """
        current_device = content_image_with_light.device
        batch_size = content_image_with_light.size(0)
        # NOTE(review): encoder is assumed to return (feature_seq, ..., spatial
        # resolution) at indices 0 and 2 — verify against rafael.py.
        f_c_tuple = self.rafael_content_encoder(content_image_with_light, arbitrary_input=arbitrary_input)
        content_features_seq, content_spatial_res = f_c_tuple[0], f_c_tuple[2]
        z_day_learned = self.z_day.expand(batch_size, -1)
        z_night_learned = self.z_night.expand(batch_size, -1)
        # Primary style: convex-style blend of the learned day/night latents.
        w_day1_expanded = w_day.view(-1, 1).expand_as(z_day_learned)
        w_night1_expanded = w_night.view(-1, 1).expand_as(z_night_learned)
        z_style1_base = w_day1_expanded * z_day_learned + w_night1_expanded * z_night_learned
        if self.opt.add_phi_to_latent:
            z_style1 = torch.cat([z_style1_base, cos_phi.view(-1, 1), sin_phi.view(-1, 1)], dim=1)
        else:
            z_style1 = z_style1_base
        # Secondary style (training only): built from a second time-of-day
        # angle phi2 and used for style mixing in the synthesis network.
        z_style2 = None
        if phi2 is not None and self.training:
            phi2_tensor = phi2.to(current_device, dtype=z_day_learned.dtype)
            w_day2, w_night2 = get_day_night_weights(phi2_tensor, self.opt.daylight_curve_steepness)
            w_day2_expanded = w_day2.view(-1, 1).expand_as(z_day_learned)
            w_night2_expanded = w_night2.view(-1, 1).expand_as(z_night_learned)
            z_style2_base = w_day2_expanded * z_day_learned + w_night2_expanded * z_night_learned
            if self.opt.add_phi_to_latent:
                cos_phi2, sin_phi2 = torch.cos(phi2_tensor), torch.sin(phi2_tensor)
                z_style2 = torch.cat([z_style2_base, cos_phi2.view(-1, 1), sin_phi2.view(-1, 1)], dim=1)
            else:
                z_style2 = z_style2_base
        mixing_prob = self.opt.style_mixing_prob if self.training else 0.0
        i_cs = self.synthesis_network(content_features_seq, content_spatial_res, z_style1, z_style2, mixing_prob,
                                      arbitrary_input=arbitrary_input)
        generator_returns = {}
        if self.training:
            # Identity reconstructions with the pure day / pure night latents.
            z_day_for_identity = z_day_learned
            z_night_for_identity = z_night_learned
            if self.opt.add_phi_to_latent:
                # NOTE(review): (cos, sin) = (-1, 0) for day and (1, 0) for
                # night encode the extreme sun angles — the exact phi
                # convention is defined in time_utils; TODO confirm.
                day_phi_components = torch.tensor([[-1.0, 0.0]], device=current_device).expand(batch_size, -1)
                night_phi_components = torch.tensor([[1.0, 0.0]], device=current_device).expand(batch_size, -1)
                z_day_for_identity = torch.cat([z_day_learned, day_phi_components], dim=1)
                z_night_for_identity = torch.cat([z_night_learned, night_phi_components], dim=1)
            i_cc = self.synthesis_network(content_features_seq, content_spatial_res, z_day_for_identity,
                                          mixing_prob=0.0, arbitrary_input=arbitrary_input)
            i_nn = self.synthesis_network(content_features_seq, content_spatial_res, z_night_for_identity,
                                          mixing_prob=0.0, arbitrary_input=arbitrary_input)
            generator_returns['identity_content'] = i_cc
            generator_returns['identity_night'] = i_nn
        return i_cs, generator_returns