Spaces:
Running
Running
| import os | |
| import cv2 | |
| import yaml | |
| import tarfile | |
| import tempfile | |
| import numpy as np | |
| import warnings | |
| from skimage import img_as_ubyte | |
| import safetensors | |
| import safetensors.torch | |
| warnings.filterwarnings('ignore') | |
| import imageio | |
| import torch | |
| import torchvision | |
| from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector | |
| from src.facerender.modules.mapping import MappingNet | |
| from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator | |
| from src.facerender.modules.make_animation import make_animation | |
| from pydub import AudioSegment | |
| from src.utils.face_enhancer import enhancer_generator_with_len, enhancer_list | |
| from src.utils.paste_pic import paste_pic | |
| from src.utils.videoio import save_video_with_watermark | |
# Detect whether this module runs inside the webui host: the bare import of
# ``webui`` succeeds only when that package is importable (presumably only
# inside the webui process — confirm with the deployment setup).
try:
    import webui  # noqa: F401  # in webui
except ImportError:
    in_webui = False
else:
    in_webui = True
class AnimateFromCoeff:
    """Face-render wrapper holding the FaceVid2Vid networks and the mapping net.

    The constructor builds the generator, keypoint detector, head-pose
    estimator and MappingNet from a YAML config, loads their pretrained
    weights, freezes every parameter and puts all four modules in eval mode.
    """

    def __init__(self, sadtalker_path, device):
        """Build the face-render networks and load their pretrained weights.

        Args:
            sadtalker_path: Mapping of resource paths. Keys read here:
                ``'facerender_yaml'`` (model config), either ``'checkpoint'``
                (safetensors bundle) or ``'free_view_checkpoint'`` (torch
                checkpoint), and ``'mappingnet_checkpoint'``.
            device: Target device, passed to ``Module.to`` and used as the
                ``map_location`` for torch checkpoints.

        Raises:
            AttributeError: if ``'mappingnet_checkpoint'`` is missing or None.
        """
        with open(sadtalker_path['facerender_yaml']) as f:
            config = yaml.safe_load(f)
        model_params = config['model_params']
        common_params = model_params['common_params']

        generator = OcclusionAwareSPADEGenerator(**model_params['generator_params'],
                                                 **common_params)
        kp_extractor = KPDetector(**model_params['kp_detector_params'],
                                  **common_params)
        he_estimator = HEEstimator(**model_params['he_estimator_params'],
                                   **common_params)
        mapping = MappingNet(**model_params['mapping_params'])

        modules = (generator, kp_extractor, he_estimator, mapping)

        # Inference only: move every network to the device and freeze it.
        for module in modules:
            module.to(device)
            for param in module.parameters():
                param.requires_grad = False

        # Load the FaceVid2Vid weights. The safetensors bundle is preferred
        # when provided.
        if 'checkpoint' in sadtalker_path:
            # he_estimator is not loaded on this path (None is passed).
            # NOTE(review): confirm the head-pose estimator is unused when the
            # safetensors bundle is used.
            self.load_cpk_facevid2vid_safetensor(
                sadtalker_path['checkpoint'],
                kp_detector=kp_extractor,
                generator=generator,
                he_estimator=None,
                device=device,
            )
        else:
            self.load_cpk_facevid2vid(
                sadtalker_path['free_view_checkpoint'],
                kp_detector=kp_extractor,
                generator=generator,
                he_estimator=he_estimator,
                device=device,
            )

        # Load the MappingNet weights, which are mandatory.
        if sadtalker_path.get('mappingnet_checkpoint') is not None:
            self.load_cpk_mapping(
                sadtalker_path['mappingnet_checkpoint'],
                mapping=mapping,
                device=device,
            )
        else:
            raise AttributeError("mappingnet_checkpoint path belirtmelisiniz.")

        self.kp_extractor = kp_extractor
        self.generator = generator
        self.he_estimator = he_estimator
        self.mapping = mapping
        self.device = device
        for module in modules:
            module.eval()

    @staticmethod
    def _strip_prefix(state, prefix):
        """Return the entries of ``state`` whose key starts with ``prefix``.

        Only that leading prefix is removed from each key. Any later
        occurrences of the same text inside the key are preserved.
        """
        return {key[len(prefix):]: value
                for key, value in state.items() if key.startswith(prefix)}

    def load_cpk_facevid2vid_safetensor(self, checkpoint_path,
                                        generator=None, kp_detector=None,
                                        he_estimator=None, device="cpu"):
        """Load FaceVid2Vid weights from a single safetensors file.

        The file is a flat tensor dict whose keys are namespaced by module:
        ``generator.*``, ``kp_extractor.*`` (loaded into ``kp_detector``) and
        ``he_estimator.*``. Modules passed as None are skipped.

        ``device`` is accepted for signature parity with the torch loaders but
        is unused. Tensors are read on the default device of
        ``safetensors.torch.load_file``, and ``load_state_dict`` copies them
        into the module's parameters.

        Returns:
            None.
        """
        checkpoint = safetensors.torch.load_file(checkpoint_path)
        targets = ((generator, 'generator.'),
                   (kp_detector, 'kp_extractor.'),
                   (he_estimator, 'he_estimator.'))
        for module, prefix in targets:
            if module is not None:
                module.load_state_dict(self._strip_prefix(checkpoint, prefix))
        return None

    def load_cpk_facevid2vid(self, checkpoint_path,
                             generator=None, discriminator=None,
                             kp_detector=None, he_estimator=None,
                             optimizer_generator=None, optimizer_discriminator=None,
                             optimizer_kp_detector=None, optimizer_he_estimator=None,
                             device="cpu"):
        """Load FaceVid2Vid weights (and optional training state) via torch.load.

        Behavior depends on which arguments are supplied:

        - Requested core modules (``generator``, ``kp_detector``,
          ``he_estimator``) must have their key in the checkpoint, otherwise
          ``KeyError`` is raised.
        - The discriminator and optimizers are restored only when both
          supplied and present in the checkpoint.

        Returns:
            ``checkpoint['epoch']`` if present, else 0.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))

        if generator is not None:
            generator.load_state_dict(checkpoint['generator'])
        if kp_detector is not None:
            kp_detector.load_state_dict(checkpoint['kp_detector'])
        if he_estimator is not None:
            he_estimator.load_state_dict(checkpoint['he_estimator'])

        # Optional training state: restore the discriminator and optimizers
        # only when present.
        optional_targets = (
            (discriminator, 'discriminator'),
            (optimizer_generator, 'optimizer_generator'),
            (optimizer_discriminator, 'optimizer_discriminator'),
            (optimizer_kp_detector, 'optimizer_kp_detector'),
            (optimizer_he_estimator, 'optimizer_he_estimator'),
        )
        for target, key in optional_targets:
            if target is not None and key in checkpoint:
                target.load_state_dict(checkpoint[key])

        return checkpoint.get('epoch', 0)

    @staticmethod
    def _safe_extractall(tar, path):
        """Extract ``tar`` into ``path``, refusing members that escape ``path``."""
        if hasattr(tarfile, 'data_filter'):
            # Python 3.12+ and security backports: rejects absolute paths,
            # '..' traversal, links pointing outside ``path`` and special files.
            tar.extractall(path=path, filter='data')
            return
        # Fallback for older interpreters: reject members whose resolved
        # path lies outside the extraction directory.
        base = os.path.realpath(path)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(base, member.name))
            if os.path.commonpath([base, target]) != base:
                raise tarfile.TarError(
                    f"unsafe member path in archive: {member.name}")
        tar.extractall(path=path)

    @staticmethod
    def _find_checkpoint_in_dir(root_dir):
        """Return the first ``.pth`` file under ``root_dir``, else the first ``.pkl``.

        Returns None when neither kind of file is found.
        """
        candidate_pth = None
        candidate_pkl = None
        for root, _, files in os.walk(root_dir):
            for name in files:
                if candidate_pth is None and name.endswith('.pth'):
                    candidate_pth = os.path.join(root, name)
                if candidate_pkl is None and name.endswith('.pkl'):
                    candidate_pkl = os.path.join(root, name)
            if candidate_pth:
                break
        return candidate_pth or candidate_pkl

    def load_cpk_mapping(self,
                         checkpoint_path,
                         mapping=None,
                         discriminator=None,
                         optimizer_mapping=None,
                         optimizer_discriminator=None,
                         device='cpu'):
        """Load MappingNet weights (and optional training state).

        ``checkpoint_path`` may take three forms:

        - **A plain torch checkpoint**, including the conventional ``.pth.tar``
          suffix used for ``torch.save`` output. It is loaded directly.
        - **A real tar archive** whose name ends in ``.tar``. It is extracted
          into a temporary directory that is removed afterwards. The first
          ``.pth`` found inside is loaded, or else the first ``.pkl``.
        - **A directory.** If ``<dir>/archive/data.pkl`` exists, that file is
          loaded.

        Args:
            mapping: Module that receives ``checkpoint['mapping']``. The key
                is required when this is given.
            discriminator, optimizer_mapping, optimizer_discriminator:
                Restored only when supplied and present in the checkpoint.
            device: ``map_location`` for ``torch.load``.

        Returns:
            ``checkpoint['epoch']`` if present, else 0.

        Raises:
            FileNotFoundError: A tar archive contains no ``.pth``/``.pkl`` file.
            KeyError: ``mapping`` was given but the checkpoint has no
                ``'mapping'`` entry.
        """
        checkpoint_path = os.fspath(checkpoint_path)
        map_location = torch.device(device)

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # A '.pth.tar' name is usually a plain torch.save file, not a tar
        # archive, so only extract when the file really is a tar archive.
        if checkpoint_path.endswith('.tar') and tarfile.is_tarfile(checkpoint_path):
            with tempfile.TemporaryDirectory() as tmpdir:
                with tarfile.open(checkpoint_path, 'r') as tar:
                    self._safe_extractall(tar, tmpdir)
                inner_path = self._find_checkpoint_in_dir(tmpdir)
                if inner_path is None:
                    raise FileNotFoundError(
                        f"{checkpoint_path} içinden ne .pth ne de .pkl dosyası bulunabildi."
                    )
                # Load while the extracted files still exist.
                checkpoint = torch.load(inner_path, map_location=map_location)
        else:
            # A directory: look for archive/data.pkl inside it.
            # NOTE(review): torch.load may be unable to read a bare data.pkl
            # taken out of a zip-format checkpoint — confirm this path works.
            if os.path.isdir(checkpoint_path):
                possible = os.path.join(checkpoint_path, 'archive', 'data.pkl')
                if os.path.isfile(possible):
                    checkpoint_path = possible
            checkpoint = torch.load(checkpoint_path, map_location=map_location)

        # The mapping weights are the point of this call. Fail loudly rather
        # than silently leaving MappingNet randomly initialised.
        if mapping is not None:
            mapping.load_state_dict(checkpoint['mapping'])

        optional_targets = (
            (discriminator, 'discriminator'),
            (optimizer_mapping, 'optimizer_mapping'),
            (optimizer_discriminator, 'optimizer_discriminator'),
        )
        for target, key in optional_targets:
            if target is not None and key in checkpoint:
                target.load_state_dict(checkpoint[key])

        return checkpoint.get('epoch', 0)