| | import torch |
| | import yaml |
| | import os |
| |
|
| | import safetensors |
| | from safetensors.torch import save_file |
| | from yacs.config import CfgNode as CN |
| | import sys |
| |
|
| | sys.path.append('/apdcephfs/private_shadowcun/SadTalker') |
| |
|
| | from src.face3d.models import networks |
| |
|
| | from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector |
| | from src.facerender.modules.mapping import MappingNet |
| | from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator |
| |
|
| | from src.audio2pose_models.audio2pose import Audio2Pose |
| | from src.audio2exp_models.networks import SimpleWrapperV2 |
| | from src.test_audio2coeff import load_cpk |
| |
|
# Size tag that is embedded in the checkpoint file names used by this script.
size = 256

config_path = os.path.join('src', 'config', 'facerender.yaml')
current_root_path = '.'

# Build the 'resnet50' reconstruction network and restore its weights from the
# 'net_recon' entry of the checkpoint file.
path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='')
checkpoint = torch.load(path_of_net_recon_model, map_location='cpu')
net_recon.load_state_dict(checkpoint['net_recon'])

# Parse the face-render YAML configuration.
with open(config_path) as f:
    config = yaml.safe_load(f)

# Each face-render module below (except the mapping net) is constructed from
# its own parameter section merged with the shared 'common_params' section.
model_params = config['model_params']
common_params = model_params['common_params']

generator = OcclusionAwareSPADEGenerator(**model_params['generator_params'], **common_params)
kp_extractor = KPDetector(**model_params['kp_detector_params'], **common_params)
he_estimator = HEEstimator(**model_params['he_estimator_params'], **common_params)
mapping = MappingNet(**model_params['mapping_params'])
| |
|
def load_cpk_facevid2vid(checkpoint_path, generator=None, discriminator=None,
                         kp_detector=None, he_estimator=None, optimizer_generator=None,
                         optimizer_discriminator=None, optimizer_kp_detector=None,
                         optimizer_he_estimator=None, device="cpu"):
    """Restore modules and optimizers from a ``torch.save`` checkpoint.

    Every argument that is not ``None`` receives the state stored under the
    checkpoint key of the same name (``kp_detector`` -> ``'kp_detector'``,
    ``optimizer_generator`` -> ``'optimizer_generator'``, ...).

    Args:
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
        generator, kp_detector, he_estimator: Objects exposing
            ``load_state_dict``; loading failures propagate to the caller.
        discriminator: Loaded on a best-effort basis; if it cannot be
            restored a message is printed and the object is left as it was.
        optimizer_generator, optimizer_kp_detector, optimizer_he_estimator:
            Optimizers to restore; loading failures propagate.
        optimizer_discriminator: Loaded on a best-effort basis, like
            ``discriminator``.
        device: Device name handed to ``torch.load`` as ``map_location``.

    Returns:
        The value stored under ``checkpoint['epoch']``.

    Raises:
        KeyError: If a required entry (including ``'epoch'``) is missing.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects; only use it on
    # checkpoint files from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))

    if generator is not None:
        generator.load_state_dict(checkpoint['generator'])
    if kp_detector is not None:
        kp_detector.load_state_dict(checkpoint['kp_detector'])
    if he_estimator is not None:
        he_estimator.load_state_dict(checkpoint['he_estimator'])

    if discriminator is not None:
        try:
            discriminator.load_state_dict(checkpoint['discriminator'])
        except Exception:
            # Stay best-effort for ordinary failures, but no longer swallow
            # KeyboardInterrupt / SystemExit the way a bare ``except`` does.
            print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')

    if optimizer_generator is not None:
        optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])

    if optimizer_discriminator is not None:
        try:
            optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
        except (KeyError, RuntimeError):
            # A missing checkpoint entry raises KeyError, which is exactly the
            # situation the message below describes.
            print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')

    if optimizer_kp_detector is not None:
        optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
    if optimizer_he_estimator is not None:
        optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])

    return checkpoint['epoch']
| |
|
| |
|
def _submodule_state(checkpoint, prefix):
    """Return the entries of ``checkpoint`` whose key starts with ``prefix``, prefix removed."""
    return {key[len(prefix):]: value
            for key, value in checkpoint.items()
            if key.startswith(prefix)}


def load_cpk_facevid2vid_safetensor(checkpoint_path, generator=None,
                                    kp_detector=None, he_estimator=None,
                                    device="cpu"):
    """Restore face-render modules from a flat safetensors checkpoint.

    The file is expected to hold one flat mapping in which each tensor name is
    prefixed with the owning sub-module: ``generator.``, ``kp_extractor.`` or
    ``he_estimator.``. For every module argument that is not ``None`` the
    matching entries are selected, the prefix is stripped, and the result is
    passed to that module's ``load_state_dict``.

    Args:
        checkpoint_path: Path of the ``.safetensors`` file.
        generator: Receives the ``generator.*`` entries.
        kp_detector: Receives the ``kp_extractor.*`` entries (note that the
            checkpoint prefix differs from the parameter name).
        he_estimator: Receives the ``he_estimator.*`` entries.
        device: Device the tensors are loaded onto.

    Returns:
        None.
    """
    # Forward ``device`` so the parameter is honoured; the default "cpu"
    # matches load_file's own default.
    checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)

    # Match on the leading "<name>." only: a substring test could pick up
    # unrelated keys, and str.replace would also strip the token if it
    # appeared again deeper inside a key.
    if generator is not None:
        generator.load_state_dict(_submodule_state(checkpoint, 'generator.'))
    if kp_detector is not None:
        kp_detector.load_state_dict(_submodule_state(checkpoint, 'kp_extractor.'))
    if he_estimator is not None:
        he_estimator.load_state_dict(_submodule_state(checkpoint, 'he_estimator.'))

    return None
| |
|
# NOTE(review): unlike the other checkpoints below, this path is absolute and
# machine-specific rather than derived from current_root_path.
free_view_checkpoint = '/apdcephfs/private_shadowcun/SadTalker/checkpoints/facevid2vid_'+str(size)+'-model.pth.tar'
load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)

wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')

audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')

audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')

# Parse the audio2pose config inside a context manager so the file handle is
# closed as soon as it has been read instead of staying open for the rest of
# the script.
with open(audio2pose_yaml_path) as fcfg_pose:
    cfg_pose = CN.load_cfg(fcfg_pose)
cfg_pose.freeze()

# Audio-to-pose model: construct, switch to eval mode, then load its weights.
audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint)
audio2pose_model.eval()
load_cpk(audio2pose_checkpoint, model=audio2pose_model, device='cpu')

# Audio-to-expression model: same sequence.
netG = SimpleWrapperV2()
netG.eval()
load_cpk(audio2exp_checkpoint, model=netG, device='cpu')
| |
|
class SadTalker(torch.nn.Module):
    """Container module that stores five sub-networks under fixed attribute names.

    The class defines no ``forward``; it only holds the networks so they can
    be handled as one module. When the arguments are ``nn.Module`` instances,
    the attribute names become the key prefixes of ``state_dict()``.
    """

    def __init__(self, kp_extractor, generator, netG, audio2pose, face_3drecon):
        super().__init__()
        # (attribute name, network) pairs, assigned in this fixed order.
        # Note that ``netG`` is stored as ``audio2exp``.
        named_parts = (
            ('kp_extractor', kp_extractor),
            ('generator', generator),
            ('audio2exp', netG),
            ('audio2pose', audio2pose),
            ('face_3drecon', face_3drecon),
        )
        for attr_name, network in named_parts:
            setattr(self, attr_name, network)
| |
|
| |
|
# Gather the loaded networks into a single module so that one state dict
# covers all of them.
model = SadTalker(kp_extractor, generator, netG, audio2pose_model, net_recon)

# The same file is written and then read back, so build its path once.
safetensor_path = 'checkpoints/SadTalker_V0.0.2_' + str(size) + '.safetensors'
save_file(model.state_dict(), safetensor_path)

# Load the freshly written file back into the keypoint detector and generator.
load_cpk_facevid2vid_safetensor(
    safetensor_path,
    kp_detector=kp_extractor,
    generator=generator,
    he_estimator=None,
)