jhj0517
commited on
Commit
·
6ecdb23
1
Parent(s):
dffd539
Add progress during model loading
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
|
@@ -56,7 +56,8 @@ class LivePortraitInferencer:
|
|
| 56 |
self.psi_list = None
|
| 57 |
self.d_info = None
|
| 58 |
|
| 59 |
-
def load_models(self
|
|
|
|
| 60 |
def filter_stitcher(checkpoint, prefix):
|
| 61 |
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
| 62 |
key.startswith(prefix)}
|
|
@@ -64,6 +65,8 @@ class LivePortraitInferencer:
|
|
| 64 |
|
| 65 |
self.download_if_no_models()
|
| 66 |
|
|
|
|
|
|
|
| 67 |
appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
|
| 68 |
self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
|
| 69 |
self.appearance_feature_extractor = self.load_safe_tensor(
|
|
@@ -71,6 +74,7 @@ class LivePortraitInferencer:
|
|
| 71 |
os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
|
| 72 |
)
|
| 73 |
|
|
|
|
| 74 |
motion_ext_config = self.model_config["motion_extractor_params"]
|
| 75 |
self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
|
| 76 |
self.motion_extractor = self.load_safe_tensor(
|
|
@@ -78,6 +82,7 @@ class LivePortraitInferencer:
|
|
| 78 |
os.path.join(self.model_dir, "motion_extractor.safetensors")
|
| 79 |
)
|
| 80 |
|
|
|
|
| 81 |
warping_module_config = self.model_config["warping_module_params"]
|
| 82 |
self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
|
| 83 |
self.warping_module = self.load_safe_tensor(
|
|
@@ -85,6 +90,7 @@ class LivePortraitInferencer:
|
|
| 85 |
os.path.join(self.model_dir, "warping_module.safetensors")
|
| 86 |
)
|
| 87 |
|
|
|
|
| 88 |
spaded_decoder_config = self.model_config["spade_generator_params"]
|
| 89 |
self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
|
| 90 |
self.spade_generator = self.load_safe_tensor(
|
|
@@ -92,6 +98,7 @@ class LivePortraitInferencer:
|
|
| 92 |
os.path.join(self.model_dir, "spade_generator.safetensors")
|
| 93 |
)
|
| 94 |
|
|
|
|
| 95 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
| 96 |
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
| 97 |
stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
|
|
|
|
| 56 |
self.psi_list = None
|
| 57 |
self.d_info = None
|
| 58 |
|
| 59 |
+
def load_models(self,
|
| 60 |
+
progress=gr.Progress()):
|
| 61 |
def filter_stitcher(checkpoint, prefix):
|
| 62 |
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
| 63 |
key.startswith(prefix)}
|
|
|
|
| 65 |
|
| 66 |
self.download_if_no_models()
|
| 67 |
|
| 68 |
+
total_models_num = 5
|
| 69 |
+
progress(0/total_models_num, desc="Loading Appearance Feature Extractor model...")
|
| 70 |
appearance_feat_config = self.model_config["appearance_feature_extractor_params"]
|
| 71 |
self.appearance_feature_extractor = AppearanceFeatureExtractor(**appearance_feat_config).to(self.device)
|
| 72 |
self.appearance_feature_extractor = self.load_safe_tensor(
|
|
|
|
| 74 |
os.path.join(self.model_dir, "appearance_feature_extractor.safetensors")
|
| 75 |
)
|
| 76 |
|
| 77 |
+
progress(1/total_models_num, desc="Loading Motion Extractor model...")
|
| 78 |
motion_ext_config = self.model_config["motion_extractor_params"]
|
| 79 |
self.motion_extractor = MotionExtractor(**motion_ext_config).to(self.device)
|
| 80 |
self.motion_extractor = self.load_safe_tensor(
|
|
|
|
| 82 |
os.path.join(self.model_dir, "motion_extractor.safetensors")
|
| 83 |
)
|
| 84 |
|
| 85 |
+
progress(2/total_models_num, desc="Loading Warping Module model...")
|
| 86 |
warping_module_config = self.model_config["warping_module_params"]
|
| 87 |
self.warping_module = WarpingNetwork(**warping_module_config).to(self.device)
|
| 88 |
self.warping_module = self.load_safe_tensor(
|
|
|
|
| 90 |
os.path.join(self.model_dir, "warping_module.safetensors")
|
| 91 |
)
|
| 92 |
|
| 93 |
+
progress(3/total_models_num, desc="Loading Spade generator model...")
|
| 94 |
spaded_decoder_config = self.model_config["spade_generator_params"]
|
| 95 |
self.spade_generator = SPADEDecoder(**spaded_decoder_config).to(self.device)
|
| 96 |
self.spade_generator = self.load_safe_tensor(
|
|
|
|
| 98 |
os.path.join(self.model_dir, "spade_generator.safetensors")
|
| 99 |
)
|
| 100 |
|
| 101 |
+
progress(4/total_models_num, desc="Loading Stitcher model...")
|
| 102 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
| 103 |
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
| 104 |
stitcher_model_path = os.path.join(self.model_dir, "stitching_retargeting_module.safetensors")
|