|
|
import os |
|
|
from abc import abstractmethod |
|
|
from typing import Dict |
|
|
from typing import Tuple |
|
|
|
|
|
import kornia |
|
|
import lpips |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from loguru import logger |
|
|
|
|
|
from Deep3DFaceRecon_pytorch.models.bfm import ParametricFaceModel |
|
|
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper |
|
|
from HRNet.hrnet import HighResolutionNet |
|
|
from arcface_torch.backbones.iresnet import iresnet100 |
|
|
from models.discriminator import Discriminator |
|
|
from models.gan_loss import GANLoss |
|
|
from models.generator import Generator |
|
|
from models.init_weight import init_net |
|
|
|
|
|
class HifiFace:
    """HifiFace face-swapping model.

    Wraps the swap generator and, in training mode, additionally builds the
    discriminator, the frozen auxiliary networks (3D face reconstruction,
    ArcFace identity embedder, HRNet landmark-heatmap model) and every loss
    term of the HifiFace objective.

    NOTE(review): ``TrainConfig`` is used throughout this class but is not
    imported anywhere in this file — confirm it is made available elsewhere,
    otherwise any ``is_training=True`` instantiation raises ``NameError``.
    """

    def __init__(
        self,
        identity_extractor_config,
        generator_path,
        is_training=False,
        device="cpu"
    ):
        """
        Parameters:
        -----------
        identity_extractor_config: dict of checkpoint locations
            ("f_3d_checkpoint_path", "bfm_folder", "f_id_checkpoint_path"
            and, when heatmap losses are enabled, "hrnet_path").
        generator_path: checkpoint path consumed by ``load_checkpoint``.
        is_training: build discriminator / losses / optimizers when True.
        device: device everything is moved to by ``setup``.
        """
        super(HifiFace, self).__init__()
        self.d_optimizer = None
        self.g_optimizer = None
        self.generator = Generator(identity_extractor_config)
        self.is_training = is_training
        self.device = device
        self.generator_path = generator_path
        # Fix: default for inference mode.  Originally ``use_ddp`` was only
        # assigned inside the training branch, so calling ``train()`` or
        # ``save()`` on an inference-mode instance raised AttributeError.
        self.use_ddp = False

        if self.is_training:
            self.lr = TrainConfig().lr
            self.use_ddp = TrainConfig().use_ddp
            self.grad_clip = TrainConfig().grad_clip if TrainConfig().grad_clip is not None else 100.0

            self.discriminator = init_net(Discriminator(3))

            self.l1_loss = nn.L1Loss()
            if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
                self.mse_loss = nn.MSELoss()
            self.loss_fn_vgg = lpips.LPIPS(net="vgg")
            self.adv_loss = GANLoss()

            # Frozen 3D face reconstruction network: produces 3DMM
            # coefficients for the shape loss.
            self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
            self.f_3d.load_state_dict(
                torch.load(identity_extractor_config["f_3d_checkpoint_path"], map_location="cpu")["net_recon"]
            )
            self.f_3d.eval()
            self.face_model = ParametricFaceModel(bfm_folder=identity_extractor_config["bfm_folder"])
            self.face_model.to("cpu")

            # Frozen ArcFace embedder: identity vectors for the id loss.
            self.f_id = iresnet100(pretrained=False, fp16=False)
            self.f_id.load_state_dict(torch.load(identity_extractor_config["f_id_checkpoint_path"], map_location="cpu"))
            self.f_id.eval()

            # Frozen HRNet: landmark heatmaps for the eye / mouth losses.
            if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
                self.model_mouth = HighResolutionNet()
                checkpoint = torch.load(identity_extractor_config["hrnet_path"], map_location="cpu")
                self.model_mouth.load_state_dict(checkpoint)
                self.model_mouth.eval()

            # Loss weights (realism group).
            self.lambda_adv = 1
            self.lambda_seg = 100
            self.lambda_rec = 20
            self.lambda_cyc = 1
            self.lambda_lpips = 5

            # Loss weights (shape / identity / heatmap group).
            self.lambda_shape = 0.5
            self.lambda_id = 5
            self.lambda_eye_hm = 10000.0
            self.lambda_mouth_hm = 10000.0

            # Kernel used to dilate the target face mask before the
            # segmentation loss is computed.
            self.dilation_kernel = torch.ones(5, 5)

        self.load_checkpoint()

        self.setup(self.device)

    def save(self, path, idx=None):
        """Save generator and discriminator weights under ``path``.

        When ``idx`` is given it is appended to the file names so that
        several snapshots can coexist.
        """
        os.makedirs(path, exist_ok=True)
        if idx is None:
            g_path = os.path.join(path, "generator.pth")
            d_path = os.path.join(path, "discriminator.pth")
        else:
            g_path = os.path.join(path, f"generator_{idx}.pth")
            d_path = os.path.join(path, f"discriminator_{idx}.pth")
        if self.use_ddp:
            # Unwrap the DDP container so checkpoint keys stay plain.
            torch.save(self.generator.module.state_dict(), g_path)
            torch.save(self.discriminator.module.state_dict(), d_path)
        else:
            torch.save(self.generator.state_dict(), g_path)
            torch.save(self.discriminator.state_dict(), d_path)

    @abstractmethod
    def load_checkpoint(self):
        # Subclasses load the generator weights here (called at the end of
        # ``__init__``).  NOTE(review): the class does not inherit from
        # ``abc.ABC``, so the decorator is not actually enforced.
        pass

    def setup(self, device):
        """Move every module to ``device``.

        In training mode this also freezes the auxiliary networks,
        optionally wraps generator/discriminator in DistributedDataParallel
        (synchronising initial weights across ranks), and builds the AdamW
        optimizers.
        """
        self.generator.to(device)

        if self.is_training:
            self.discriminator.to(device)
            self.l1_loss.to(device)
            if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
                self.mse_loss.to(device)
            self.f_3d.to(device)
            self.f_id.to(device)

            self.loss_fn_vgg.to(device)
            self.face_model.to(device)
            self.adv_loss.to(device)

            if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
                self.model_mouth.to(device)
            # Auxiliary networks are fixed feature extractors: never trained.
            self.f_3d.requires_grad_(False)
            self.f_id.requires_grad_(False)
            self.loss_fn_vgg.requires_grad_(False)
            if TrainConfig().mouth_hm_loss or TrainConfig().eye_hm_loss:
                self.model_mouth.requires_grad_(False)
            self.dilation_kernel = self.dilation_kernel.to(device)
            if self.use_ddp:
                from torch.nn.parallel import DistributedDataParallel as DDP
                import torch.distributed as dist

                self.generator = DDP(self.generator, device_ids=[device])
                self.discriminator = DDP(self.discriminator, device_ids=[device])

                # Rank 0 writes its weights and every rank reloads them so
                # all workers start from identical parameters.
                # NOTE(review): hard-coded /tmp paths can collide if two
                # jobs run on the same host — confirm this is acceptable.
                if dist.get_rank() == 0:
                    torch.save(self.generator.state_dict(), "/tmp/generator.pth")
                    torch.save(self.discriminator.state_dict(), "/tmp/discriminator.pth")

                dist.barrier()
                self.generator.load_state_dict(torch.load("/tmp/generator.pth", map_location=device))
                self.discriminator.load_state_dict(torch.load("/tmp/discriminator.pth", map_location=device))

            self.g_optimizer = torch.optim.AdamW(self.generator.parameters(), lr=self.lr, betas=[0, 0.999])
            self.d_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr=self.lr, betas=[0, 0.999])

    def train(self):
        """Switch to training mode while keeping the identity extractor frozen."""
        self.generator.train()
        # Fix: only touch the discriminator when it exists (training mode),
        # mirroring the guard already used in ``eval``.
        if self.is_training:
            self.discriminator.train()

        if self.use_ddp:
            self.generator.module.id_extractor.eval()
        else:
            self.generator.id_extractor.eval()

    def eval(self):
        """Switch generator (and discriminator, when present) to eval mode."""
        self.generator.eval()
        if self.is_training:
            self.discriminator.eval()

    def train_forward_generator(self, source_img, target_img, target_mask, same_id_mask):
        """
        Compute the generator losses for one training step.

        Parameters:
        -----------
        source_img: torch.Tensor
        target_img: torch.Tensor
        target_mask: torch.Tensor, [B, 1, H, W]
        same_id_mask: torch.Tensor, [B, 1]

        Returns:
        --------
        source_img: torch.Tensor
        target_img: torch.Tensor
        i_cycle: torch.Tensor, cycle image
        i_r: torch.Tensor
        m_r: torch.Tensor
        loss: Dict[torch.Tensor], contain pairs of loss name and loss values
        """
        # Broadcast the per-sample same-identity flag to [B, 1, 1, 1].
        same = same_id_mask.unsqueeze(-1).unsqueeze(-1)
        i_r, i_low, m_r, m_low = self.generator(source_img, target_img, need_id_grad=False)
        # Swap the target back onto the result to obtain the cycle image.
        i_cycle, _, _, _ = self.generator(target_img, i_r, need_id_grad=True)
        d_r = self.discriminator(i_r)

        # 3DMM coefficients.  Gradients are only needed through the
        # generated images i_r / i_low, so source/target go under no_grad.
        with torch.no_grad():
            c_s = self.f_3d(F.interpolate(source_img, size=224, mode="bilinear"))
            c_t = self.f_3d(F.interpolate(target_img, size=224, mode="bilinear"))
        c_r = self.f_3d(F.interpolate(i_r, size=224, mode="bilinear"))
        c_low = self.f_3d(F.interpolate(i_low, size=224, mode="bilinear"))
        with torch.no_grad():
            # Fused target shape: source identity coefficients (first 80)
            # combined with the target's remaining coefficients.
            c_fuse = torch.cat((c_s[:, :80], c_t[:, 80:]), dim=1)
            _, _, _, q_fuse = self.face_model.compute_for_render(c_fuse)
        _, _, _, q_r = self.face_model.compute_for_render(c_r)
        _, _, _, q_low = self.face_model.compute_for_render(c_low)
        with torch.no_grad():
            v_id_i_s = F.normalize(
                self.f_id(F.interpolate((source_img - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2
            )

        v_id_i_r = F.normalize(self.f_id(F.interpolate((i_r - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
        v_id_i_low = F.normalize(self.f_id(F.interpolate((i_low - 0.5) / 0.5, size=112, mode="bicubic")), dim=-1, p=2)
        loss_shape = self.l1_loss(q_fuse, q_r) + self.l1_loss(q_fuse, q_low)
        # Clamp guards against exploding shape loss early in training.
        loss_shape = torch.clamp(loss_shape, min=0.0, max=10.0)

        # Cosine similarity between source and swapped identity embeddings
        # (both are L2-normalized, so the dot product is the cosine).
        inner_product_r = torch.bmm(v_id_i_s.unsqueeze(1), v_id_i_r.unsqueeze(2)).squeeze()
        inner_product_low = torch.bmm(v_id_i_s.unsqueeze(1), v_id_i_low.unsqueeze(2)).squeeze()
        loss_id = self.l1_loss(torch.ones_like(inner_product_r), inner_product_r) + self.l1_loss(
            torch.ones_like(inner_product_low), inner_product_low
        )
        loss_sid = self.lambda_shape * loss_shape + self.lambda_id * loss_id

        loss_cycle = self.l1_loss(target_img, i_cycle)

        # Dilate the ground-truth mask before supervising the predicted masks.
        target_mask = kornia.morphology.dilation(target_mask, self.dilation_kernel)

        loss_segmentation = self.l1_loss(
            F.interpolate(target_mask, scale_factor=0.25, mode="bilinear"), m_low
        ) + self.l1_loss(target_mask, m_r)

        # Pixel reconstruction is only supervised on same-identity pairs
        # (``same`` zeroes out cross-identity samples).
        loss_reconstruction = self.l1_loss(i_r * same, target_img * same) + self.l1_loss(
            i_low * same, F.interpolate(target_img, scale_factor=0.25, mode="bilinear") * same
        )

        loss_perceptual = self.loss_fn_vgg(target_img * same, i_r * same).mean()

        loss_adversarial = self.adv_loss(d_r, True, for_discriminator=False)

        loss_realism = (
            self.lambda_adv * loss_adversarial
            + self.lambda_seg * loss_segmentation
            + self.lambda_rec * loss_reconstruction
            + self.lambda_cyc * loss_cycle
            + self.lambda_lpips * loss_perceptual
        )

        loss_eye_hm = 0

        loss_mouth_hm = 0
        if TrainConfig().eye_hm_loss or TrainConfig().mouth_hm_loss:
            target_hm = self.model_mouth(target_img)
            r_hm = self.model_mouth(i_r)

            if TrainConfig().eye_hm_loss:
                # Channels 96:98 — assumed to be the eye landmark heatmaps
                # of the HRNet output; TODO confirm against the HRNet config.
                target_eye_hm = target_hm[:, 96:98, :, :]
                r_eye_hm = r_hm[:, 96:98, :, :]
                loss_eye_hm = self.mse_loss(r_eye_hm, target_eye_hm)
                loss_realism = loss_realism + self.lambda_eye_hm * loss_eye_hm

            if TrainConfig().mouth_hm_loss:
                # Channels 76:96 — assumed mouth landmark heatmaps; TODO confirm.
                target_mouth_hm = target_hm[:, 76:96, :, :]
                r_mouth_hm = r_hm[:, 76:96, :, :]
                loss_mouth_hm = self.mse_loss(r_mouth_hm, target_mouth_hm)
                loss_realism = loss_realism + self.lambda_mouth_hm * loss_mouth_hm

        loss_generator = loss_sid + loss_realism

        loss_dict = {
            "loss_shape": loss_shape,
            "loss_id": loss_id,
            "loss_sid": loss_sid,
            "loss_cycle": loss_cycle,
            "loss_segmentation": loss_segmentation,
            "loss_reconstruction": loss_reconstruction,
            "loss_perceptual": loss_perceptual,
            "loss_adversarial": loss_adversarial,
            "loss_realism": loss_realism,
            "loss_generator": loss_generator,
        }
        if TrainConfig().eye_hm_loss:
            loss_dict.update({"loss_eye_hm": loss_eye_hm})
        if TrainConfig().mouth_hm_loss:
            loss_dict.update({"loss_mouth_hm": loss_mouth_hm})
        # Images are detached: only the losses carry gradients outward.
        return (
            source_img,
            target_img,
            i_cycle.detach(),
            i_r.detach(),
            m_r.detach(),
            loss_dict,
        )

    def train_forward_discriminator(self, target_img, i_r):
        """
        Compute the discriminator losses for one training step.

        Parameters:
        -----------
        target_img: torch.Tensor, target face image (real sample)
        i_r: torch.Tensor, face-swap result (fake sample)

        Returns:
        --------
        Dict[str]: contains pair of loss name and loss values
        """
        d_gt = self.discriminator(target_img)
        # Detach so no gradients flow back into the generator.
        d_fake = self.discriminator(i_r.detach())
        loss_real = self.adv_loss(d_gt, True)
        loss_fake = self.adv_loss(d_fake, False)

        loss_discriminator = loss_real + loss_fake
        return {
            "loss_real": loss_real,
            "loss_fake": loss_fake,

            "loss_discriminator": loss_discriminator,
        }

    def forward(
        self, source_img: torch.Tensor, target_img: torch.Tensor, shape_rate=None, id_rate=None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Swap the source face onto the target image.

        Parameters:
        -----------
        source_img: torch.Tensor, source face image
        target_img: torch.Tensor, target face image
        *_rate: interpolation coefficients; when either is given, the
            generator's interpolation path is used (a missing one
            defaults to 1.0)

        Returns:
        --------
        i_r: torch.Tensor, swapped result
        m_r: torch.Tensor, predicted face mask
        """
        # Fix: the original annotated the return as a single torch.Tensor,
        # but two tensors (image and mask) are returned.
        if shape_rate is None and id_rate is None:
            i_r, _, m_r, _ = self.generator(source_img, target_img)
        else:
            if shape_rate is None:
                shape_rate = 1.0
            if id_rate is None:
                id_rate = 1.0
            i_r, _, m_r, _ = self.generator.interp(source_img, target_img, shape_rate, id_rate)
        return i_r, m_r

    def optimize(
        self,
        source_img: torch.Tensor,
        target_img: torch.Tensor,
        target_mask: torch.Tensor,
        same_id_mask: torch.Tensor,
    ) -> Tuple[Dict, Dict[str, torch.Tensor]]:
        """
        Run one full optimization step (generator, then discriminator).

        Training mode only: performs backward passes, gradient clipping and
        optimizer steps, then returns loss info and the produced images.

        Parameters:
        -----------
        source_img: torch.Tensor, source face image
        target_img: torch.Tensor, target face image
        target_mask: torch.Tensor, target face mask
        same_id_mask: torch.Tensor, flags whether source and target are
            the same person

        Returns:
        --------
        Tuple[Dict, Dict[str, torch.Tensor]]:
            loss dict (all generator/discriminator losses plus gradient
            norms) and an image dict (source/target/swapped/mask/cycle).
        """
        src_img, tgt_img, i_cycle, i_r, m_r, loss_G_dict = self.train_forward_generator(
            source_img, target_img, target_mask, same_id_mask
        )
        loss_G = loss_G_dict["loss_generator"]
        self.g_optimizer.zero_grad()
        loss_G.backward()
        global_norm_G = torch.nn.utils.clip_grad_norm_(self.generator.parameters(), self.grad_clip)
        self.g_optimizer.step()

        # Discriminator step uses the detached i_r from the generator pass.
        loss_D_dict = self.train_forward_discriminator(tgt_img, i_r)
        loss_D = loss_D_dict["loss_discriminator"]
        self.d_optimizer.zero_grad()
        loss_D.backward()
        global_norm_D = torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), self.grad_clip)
        self.d_optimizer.step()

        total_loss_dict = {"global_norm_G": global_norm_G, "global_norm_D": global_norm_D}
        total_loss_dict.update(loss_G_dict)
        total_loss_dict.update(loss_D_dict)

        return total_loss_dict, {
            "source face": src_img,
            "target face": tgt_img,
            "swapped face": torch.clamp(i_r, min=0.0, max=1.0),
            "pred face mask": m_r,
            "cycle face": i_cycle,
        }
|
|
|
|
|
class HifiFaceST(HifiFace):
    """HifiFace variant whose checkpoint holds a plain generator state dict."""

    def __init__(self, identity_extractor_config, device, generator_path):
        super().__init__(identity_extractor_config, generator_path=generator_path, device=device)

    def load_checkpoint(self):
        # Weights are mapped straight onto the configured device.
        state_dict = torch.load(self.generator_path, map_location=self.device)
        self.generator.load_state_dict(state_dict)
        logger.info(f"Loading generator from {self.generator_path}")
|
|
class HifiFaceWGM(HifiFace):
    """HifiFace variant that restores the generator from a raw state dict."""

    def __init__(self, identity_extractor_config, device, generator_path):
        super().__init__(
            identity_extractor_config,
            device=device,
            generator_path=generator_path,
        )

    def load_checkpoint(self):
        checkpoint = torch.load(self.generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint)
        logger.info(f"Loading generator from {self.generator_path}")