|
|
import torch
|
|
|
import yaml
|
|
|
import os
|
|
|
|
|
|
import safetensors
|
|
|
from safetensors.torch import save_file
|
|
|
from yacs.config import CfgNode as CN
|
|
|
import sys
|
|
|
|
|
|
sys.path.append('/apdcephfs/private_shadowcun/SadTalker')
|
|
|
|
|
|
from src.face3d.models import networks
|
|
|
|
|
|
from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
|
|
|
from src.facerender.modules.mapping import MappingNet
|
|
|
from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
|
|
|
|
|
|
from src.audio2pose_models.audio2pose import Audio2Pose
|
|
|
from src.audio2exp_models.networks import SimpleWrapperV2
|
|
|
from src.test_audio2coeff import load_cpk
|
|
|
|
|
|
# Resolution tag; it is interpolated into checkpoint file names further down.
size = 256

# Both locations are relative, so they resolve against the working directory.
config_path = os.path.join('src', 'config', 'facerender.yaml')
current_root_path = '.'

# 3D face reconstruction network: build it with net_recon='resnet50', then
# restore the 'net_recon' entry of the checkpoint on the CPU.
path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='')
checkpoint = torch.load(path_of_net_recon_model, map_location='cpu')
net_recon.load_state_dict(checkpoint['net_recon'])

# Face-render hyper-parameters come from the YAML config.
with open(config_path) as f:
    config = yaml.safe_load(f)

# Every face-render sub-model below (except the mapping net) receives its own
# parameter section plus the shared 'common_params' section.
model_params = config['model_params']
common_params = model_params['common_params']

generator = OcclusionAwareSPADEGenerator(**model_params['generator_params'],
                                         **common_params)
kp_extractor = KPDetector(**model_params['kp_detector_params'],
                          **common_params)
he_estimator = HEEstimator(**model_params['he_estimator_params'],
                           **common_params)
mapping = MappingNet(**model_params['mapping_params'])
|
|
|
|
|
|
def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None,
                         kp_detector=None, he_estimator=None, optimizer_generator=None,
                         optimizer_discriminator=None, optimizer_kp_detector=None,
                         optimizer_he_estimator=None, device="cpu"):
    """Restore face-vid2vid components from a ``torch.save`` checkpoint.

    Every component argument is optional. For each one that is not ``None``,
    ``load_state_dict`` is called with the checkpoint entry of the same role.

    Args:
        checkpoint_path: Path handed to ``torch.load``.
        generator: Receives ``checkpoint['generator']``.
        discriminator: Receives ``checkpoint['discriminator']``; a failure is
            reported with ``print`` and otherwise ignored.
        kp_detector: Receives ``checkpoint['kp_detector']``.
        he_estimator: Receives ``checkpoint['he_estimator']``.
        optimizer_generator: Receives ``checkpoint['optimizer_generator']``.
        optimizer_discriminator: Receives
            ``checkpoint['optimizer_discriminator']``; a missing entry or a
            ``RuntimeError`` is reported with ``print`` and otherwise ignored.
        optimizer_kp_detector: Receives ``checkpoint['optimizer_kp_detector']``.
        optimizer_he_estimator: Receives
            ``checkpoint['optimizer_he_estimator']``.
        device: Device name used as ``map_location`` for ``torch.load``.

    Returns:
        The value stored under ``checkpoint['epoch']``.

    Raises:
        KeyError: If ``'epoch'`` or the entry of any requested component other
            than the discriminator and its optimizer is missing.
    """
    # NOTE(review): depending on the torch version / weights_only setting,
    # torch.load may unpickle arbitrary objects; only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))

    if generator is not None:
        generator.load_state_dict(checkpoint['generator'])
    if kp_detector is not None:
        kp_detector.load_state_dict(checkpoint['kp_detector'])
    if he_estimator is not None:
        he_estimator.load_state_dict(checkpoint['he_estimator'])

    if discriminator is not None:
        # Best effort: the discriminator is allowed to be absent or unusable.
        # `Exception` (not a bare except) keeps KeyboardInterrupt and
        # SystemExit propagating.
        try:
            discriminator.load_state_dict(checkpoint['discriminator'])
        except Exception:
            print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')

    if optimizer_generator is not None:
        optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])

    if optimizer_discriminator is not None:
        # A missing entry surfaces as KeyError, which is exactly the case the
        # message describes, so it must be handled alongside RuntimeError.
        try:
            optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
        except (KeyError, RuntimeError):
            print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')

    if optimizer_kp_detector is not None:
        optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
    if optimizer_he_estimator is not None:
        optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])

    return checkpoint['epoch']
|
|
|
|
|
|
|
|
|
def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None,
                                    kp_detector=None, he_estimator=None,
                                    device="cpu"):
    """Restore face-render sub-models from a single ``.safetensors`` file.

    The file is expected to hold a flat mapping whose keys start with the
    owning sub-model's name followed by a dot. For every component that is
    not ``None`` the matching entries are selected, the prefix is removed,
    and the result is passed to that component's ``load_state_dict``.

    Args:
        checkpoint_path: Path of the ``.safetensors`` file.
        generator: Receives the entries prefixed with ``'generator.'``.
        kp_detector: Receives the entries prefixed with ``'kp_extractor.'``.
        he_estimator: Receives the entries prefixed with ``'he_estimator.'``.
        device: Device the tensors are loaded onto; forwarded to
            ``safetensors.torch.load_file``.

    Returns:
        None.
    """
    # Forward `device`; previously the argument was accepted but ignored.
    checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)

    def _sub_state(prefix):
        # Match on the leading prefix only and strip it exactly once, so a key
        # that merely contains the name somewhere else is neither selected
        # nor rewritten.
        return {key[len(prefix):]: tensor
                for key, tensor in checkpoint.items()
                if key.startswith(prefix)}

    if generator is not None:
        generator.load_state_dict(_sub_state('generator.'))
    if kp_detector is not None:
        kp_detector.load_state_dict(_sub_state('kp_extractor.'))
    if he_estimator is not None:
        he_estimator.load_state_dict(_sub_state('he_estimator.'))

    return None
|
|
|
|
|
|
# Face-vid2vid weights for the generator, keypoint detector and head-pose
# estimator built earlier in this script.
# NOTE(review): this is an absolute, machine-specific path, whereas the other
# checkpoints below are resolved against current_root_path - confirm intent.
free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar'
load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)

# Passed to the Audio2Pose constructor below.
wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')

audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')

audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
# Defined for symmetry with the pose config; it is not read in this section.
audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')

# Audio-to-pose model. The config file is opened in a `with` block so the
# handle is closed once the config has been parsed (it used to be leaked).
with open(audio2pose_yaml_path) as fcfg_pose:
    cfg_pose = CN.load_cfg(fcfg_pose)
cfg_pose.freeze()
audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint)
audio2pose_model.eval()
load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu')

# Audio-to-expression model.
netG = SimpleWrapperV2()
netG.eval()
load_cpk(audio2exp_checkpoint, model=netG, device='cpu')
|
|
|
|
|
|
class SadTalker(torch.nn.Module):
    """Container module that groups the sub-models under fixed attribute names.

    It defines no ``forward``; it only holds the sub-models so that
    ``state_dict()`` yields one flat mapping whose keys are prefixed with
    ``kp_extractor.``, ``generator.``, ``audio2exp.``, ``audio2pose.`` and
    ``face_3drecon.``.
    """

    def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon):
        super().__init__()

        # The assignment order fixes the key order of state_dict().
        self.kp_extractor = kp_extractor
        self.generator = generator
        # Stored as `audio2exp`, not under the parameter name `netG`.
        self.audio2exp = netG
        self.audio2pose = audio2pose
        self.face_3drecon = face_3drecon
|
|
|
|
|
|
|
|
|
# Group every loaded sub-model in one module so a single state dict covers
# all of them.
model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon)

# Write the combined weights to one safetensors file.
save_file(model.state_dict(), f"checkpoints/SadTalker_V0.0.2_{size}.safetensors")

# Load the file that was just written back into the keypoint detector and the
# generator; the head-pose estimator is skipped (he_estimator=None).
load_cpk_facevid2vid_safetensor(
    f'checkpoints/SadTalker_V0.0.2_{size}.safetensors',
    kp_detector=kp_extractor,
    generator=generator,
    he_estimator=None,
)