|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from arcface_torch.backbones.iresnet import iresnet100 |
|
|
from Deep3DFaceRecon_pytorch.models.networks import ReconNetWrapper |
|
|
|
|
|
|
|
|
class ShapeAwareIdentityExtractor(nn.Module):
    """Fuse 3D-reconstruction coefficients with a face-recognition embedding.

    Two pretrained networks are wrapped:

    * ``f_3d`` -- a ``ReconNetWrapper`` (resnet50) producing a coefficient
      vector per image.
    * ``f_id`` -- an ``iresnet100`` producing an identity embedding per image.

    Both are switched to evaluation mode at construction time and are kept
    in evaluation mode even when ``train()`` is called on this module or on
    a parent module.
    """

    def __init__(self, identity_extractor_config: dict):
        """
        Shape Aware Identity Extractor

        Parameters:
        ----------
        identity_extractor_config: Dict[str, str]
            Must contain the following entries:
            f_3d_checkpoint_path: str
                Path to the 3D face reconstruction checkpoint, e.g.
                "model/Deep3DFaceRecon_pytorch/checkpoints/epoch_20.pth"
            f_id_checkpoint_path: str
                Path to the arcface face recognition checkpoint.
                The unofficial implementation uses
                https://onedrive.live.com/?authkey=%21AFZjr283nwZHqbA&id=4A83B6B633B029CC%215585&cid=4A83B6B633B029CC/backbone.pth

        Raises:
        -------
        KeyError
            If either checkpoint path is missing from the config.
        """
        super().__init__()
        f_3d_checkpoint_path = identity_extractor_config["f_3d_checkpoint_path"]
        f_id_checkpoint_path = identity_extractor_config["f_id_checkpoint_path"]

        # NOTE(security): torch.load unpickles the checkpoint file, so both
        # paths must point at trusted files.
        self.f_3d = ReconNetWrapper(net_recon="resnet50", use_last_fc=False)
        # The reconstruction checkpoint stores its weights under "net_recon".
        self.f_3d.load_state_dict(torch.load(f_3d_checkpoint_path, map_location="cpu")["net_recon"])
        self.f_3d.eval()

        self.f_id = iresnet100(pretrained=False, fp16=False)
        self.f_id.load_state_dict(torch.load(f_id_checkpoint_path, map_location="cpu"))
        self.f_id.eval()

    def train(self, mode: bool = True) -> "ShapeAwareIdentityExtractor":
        """Set the training flag on this module while keeping both backbones in eval mode.

        ``nn.Module.train`` recurses into child modules, which would undo the
        ``eval()`` calls made in ``__init__`` as soon as a parent module is
        put into training mode. Any mode-dependent layers the backbones
        contain (e.g. batch norm or dropout) would then behave differently
        from how the pretrained checkpoints are meant to be used.

        Parameters:
        ----------
        mode: bool
            Training flag applied to this module itself.

        Returns:
        --------
        self
        """
        super().train(mode)
        for pretrained_net in (self.f_3d, self.f_id):
            pretrained_net.eval()
        return self

    def _embed_identity(self, image: torch.Tensor) -> torch.Tensor:
        """Return the L2-normalised ``f_id`` embedding of ``image``."""
        # Rescale from [0, 1] to [-1, 1], then resize to 112x112 before
        # feeding the recognition network.
        rescaled = (image - 0.5) / 0.5
        resized = F.interpolate(rescaled, size=112, mode="bicubic")
        return F.normalize(self.f_id(resized), dim=-1, p=2)

    @torch.no_grad()
    def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
        """
        Interpolate shape and id information between source and target.

        Runs entirely under ``torch.no_grad()``.

        Parameters:
        -----------
        i_source: torch.Tensor, source face image (same format as in ``forward``)
        i_target: torch.Tensor, target face image (same format as in ``forward``)
        shape_rate: float, weight of the source in the coefficient blend;
            the target gets ``1 - shape_rate``
        id_rate: float, weight of the source in the identity-embedding blend;
            the target gets ``1 - id_rate``

        Returns:
        --------
        v_sid: torch.Tensor, fused shape and id features
        """
        c_s = self.f_3d(i_source)
        c_t = self.f_3d(i_target)
        c_interp = shape_rate * c_s + (1 - shape_rate) * c_t
        # Only the first 80 coefficients come from the blend; everything
        # from index 80 onwards is taken from the target unchanged.
        c_fuse = torch.cat((c_interp[:, :80], c_t[:, 80:]), dim=1)

        v_s = self._embed_identity(i_source)
        v_t = self._embed_identity(i_target)
        # The blended embedding is not re-normalised after mixing.
        v_id = id_rate * v_s + (1 - id_rate) * v_t

        v_sid = torch.cat((c_fuse, v_id), dim=1)
        return v_sid

    def forward(self, i_source, i_target):
        """
        Parameters:
        -----------
        i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
        i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image

        Returns:
        --------
        v_sid: torch.Tensor, fused shape and id features
        """
        c_s = self.f_3d(i_source)
        c_t = self.f_3d(i_target)

        # First 80 coefficients from the source, the remainder from the target.
        c_fuse = torch.cat((c_s[:, :80], c_t[:, 80:]), dim=1)

        # The identity embedding is computed from the source image only.
        v_id = self._embed_identity(i_source)

        v_sid = torch.cat((c_fuse, v_id), dim=1)
        return v_sid
|
|
|