Spaces:
Runtime error
Runtime error
| # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import paddle | |
| import paddle.nn as nn | |
| from arch.base_module import MiddleNet, ResBlock | |
| from arch.encoder import Encoder | |
| from arch.decoder import Decoder, DecoderUnet, SingleDecoder | |
| from utils.load_params import load_dygraph_pretrain | |
| from utils.logging import get_logger | |
class StyleTextRec(nn.Layer):
    """Top-level style-text predictor.

    Composes three sub-generators (text, background, fusion) from the
    "Predictor" section of *config* and loads each one's pretrained
    weights via ``load_dygraph_pretrain``.
    """

    def __init__(self, config):
        super(StyleTextRec, self).__init__()
        self.logger = get_logger()
        predictor_cfg = config["Predictor"]
        self.text_generator = TextGenerator(predictor_cfg["text_generator"])
        self.bg_generator = BgGeneratorWithMask(predictor_cfg["bg_generator"])
        self.fusion_generator = FusionGeneratorSimple(
            predictor_cfg["fusion_generator"])
        # Load pretrained weights for each sub-generator from the path
        # given in its own config section (same order as the original:
        # background, text, fusion).
        for module, cfg_key in ((self.bg_generator, "bg_generator"),
                                (self.text_generator, "text_generator"),
                                (self.fusion_generator, "fusion_generator")):
            load_dygraph_pretrain(
                module,
                self.logger,
                path=predictor_cfg[cfg_key]["pretrain"],
                load_static_weights=False)

    def forward(self, style_input, text_input):
        """Run the full pipeline and return the generated images.

        Args:
            style_input: style (reference) image tensor.
            text_input: target text image tensor.

        Returns:
            dict with keys "fake_fusion", "fake_text", "fake_sk", "fake_bg".
        """
        text_out = self.text_generator.forward(style_input, text_input)
        bg_out = self.bg_generator.forward(style_input)
        fusion_out = self.fusion_generator.forward(text_out["fake_text"],
                                                   bg_out["fake_bg"])
        return {
            "fake_fusion": fusion_out["fake_fusion"],
            "fake_text": text_out["fake_text"],
            "fake_sk": text_out["fake_sk"],
            "fake_bg": bg_out["fake_bg"],
        }
class TextGenerator(nn.Layer):
    """Generates the foreground text image and its skeleton mask.

    Two encoders (one for the style image, one for the text image) feed
    two decoders: ``decoder_text`` renders the styled text and
    ``decoder_sk`` predicts a 1-channel skeleton; a MiddleNet merges the
    two into the final text image.
    """

    def __init__(self, config):
        """Build encoders/decoders from *config*.

        Expected keys: module_name, encode_dim, norm_layer,
        conv_block_dropout, conv_block_num, conv_block_dilation.
        """
        super(TextGenerator, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False
        # Fix: use integer floor division instead of int(encode_dim / 2).
        # Same value for positive ints, avoids a float round-trip, and is
        # consistent with BgGeneratorWithMask's `encode_dim // 2`.
        half_dim = encode_dim // 2
        self.encoder_text = Encoder(
            name=name + "_encoder_text",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.encoder_style = Encoder(
            name=name + "_encoder_style",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.decoder_text = Decoder(
            name=name + "_decoder_text",
            encode_dim=encode_dim,
            out_channels=half_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)
        self.decoder_sk = Decoder(
            name=name + "_decoder_sk",
            encode_dim=encode_dim,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)
        # MiddleNet input = decoded text channels + 1 skeleton channel.
        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=half_dim + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input, text_input):
        """Return {"fake_sk": skeleton mask, "fake_text": styled text image}."""
        style_feature = self.encoder_style.forward(style_input)["res_blocks"]
        text_feature = self.encoder_text.forward(text_input)["res_blocks"]
        fake_c_temp = self.decoder_text.forward([text_feature,
                                                 style_feature])["out_conv"]
        fake_sk = self.decoder_sk.forward([text_feature,
                                           style_feature])["out_conv"]
        # Merge rendered text and skeleton along the channel axis.
        fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
        return {"fake_sk": fake_sk, "fake_text": fake_text}
class BgGeneratorWithMask(nn.Layer):
    """Background generator with an auxiliary text-region mask.

    Encodes the style image, decodes a text-free background image
    (``decoder_bg``) and a 1-channel soft mask (``decoder_mask``), then
    merges them with a MiddleNet to produce the final background.
    """

    def __init__(self, config):
        super(BgGeneratorWithMask, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        dropout = config["conv_block_dropout"]
        block_num = config["conv_block_num"]
        dilation = config["conv_block_dilation"]
        self.output_factor = config.get("output_factor", 1.0)
        # Convs keep a bias only when InstanceNorm2D is the norm layer.
        use_bias = norm_layer == "InstanceNorm2D"
        self.encoder_bg = Encoder(
            name=name + "_encoder_bg",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=dropout,
            conv_block_num=block_num,
            conv_block_dilation=dilation)
        self.decoder_bg = SingleDecoder(
            name=name + "_decoder_bg",
            encode_dim=encode_dim,
            out_channels=3,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=dropout,
            conv_block_num=block_num,
            conv_block_dilation=dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)
        self.decoder_mask = Decoder(
            name=name + "_decoder_mask",
            encode_dim=encode_dim // 2,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=dropout,
            conv_block_num=block_num,
            conv_block_dilation=dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)
        # MiddleNet input = 3 background channels + 1 mask channel.
        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=3 + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input):
        """Return background features, image, and text-region mask."""
        encoded = self.encoder_bg(style_input)
        decoded = self.decoder_bg(encoded["res_blocks"], encoded["down2"],
                                  encoded["down1"])
        fake_bg_mask = self.decoder_mask.forward(
            encoded["res_blocks"])["out_conv"]
        fake_bg = self.middle(
            paddle.concat((decoded["out_conv"], fake_bg_mask), axis=1))
        return {
            "bg_encode_feature": encoded["res_blocks"],
            "bg_decode_feature1": decoded["up1"],
            "bg_decode_feature2": decoded["up2"],
            "fake_bg": fake_bg,
            "fake_bg_mask": fake_bg_mask,
        }
class FusionGeneratorSimple(nn.Layer):
    """Fuses generated text and background images into the final result.

    A simple conv -> residual block -> conv stack over the 6-channel
    concatenation of the two 3-channel inputs.
    """

    def __init__(self, config):
        super(FusionGeneratorSimple, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        # Convs keep a bias only when InstanceNorm2D is the norm layer.
        use_bias = norm_layer == "InstanceNorm2D"
        # 6 input channels: fake_text (3) concatenated with fake_bg (3).
        self._conv = nn.Conv2D(
            in_channels=6,
            out_channels=encode_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
            bias_attr=False)
        self._res_block = ResBlock(
            name="{}_conv_block".format(name),
            channels=encode_dim,
            norm_layer=norm_layer,
            use_dropout=config["conv_block_dropout"],
            use_dilation=config["conv_block_dilation"],
            use_bias=use_bias)
        self._reduce_conv = nn.Conv2D(
            in_channels=encode_dim,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
            bias_attr=False)

    def forward(self, fake_text, fake_bg):
        """Return {"fake_fusion": fused 3-channel image}."""
        merged = self._conv(paddle.concat((fake_text, fake_bg), axis=1))
        fake_fusion = self._reduce_conv(self._res_block(merged))
        return {"fake_fusion": fake_fusion}