jhj0517
commited on
Commit
·
3b34b71
1
Parent(s):
3f75a4f
Refactor model loading
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
|
@@ -58,11 +58,6 @@ class LivePortraitInferencer:
|
|
| 58 |
|
| 59 |
def load_models(self,
|
| 60 |
progress=gr.Progress()):
|
| 61 |
-
def filter_stitcher(checkpoint, prefix):
|
| 62 |
-
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
| 63 |
-
key.startswith(prefix)}
|
| 64 |
-
return filtered_checkpoint
|
| 65 |
-
|
| 66 |
self.download_if_no_models()
|
| 67 |
|
| 68 |
total_models_num = 5
|
|
@@ -100,11 +95,12 @@ class LivePortraitInferencer:
|
|
| 100 |
|
| 101 |
progress(4/total_models_num, desc="Loading Stitcher model...")
|
| 102 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
| 103 |
-
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching'))
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
|
|
|
| 108 |
self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
|
| 109 |
|
| 110 |
if self.pipeline is None:
|
|
@@ -350,8 +346,16 @@ class LivePortraitInferencer:
|
|
| 350 |
download_model(model_path, model_url)
|
| 351 |
|
| 352 |
@staticmethod
|
| 353 |
-
def load_safe_tensor(model, file_path):
|
| 354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
model.eval()
|
| 356 |
return model
|
| 357 |
|
|
|
|
| 58 |
|
| 59 |
def load_models(self,
|
| 60 |
progress=gr.Progress()):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
self.download_if_no_models()
|
| 62 |
|
| 63 |
total_models_num = 5
|
|
|
|
| 95 |
|
| 96 |
progress(4/total_models_num, desc="Loading Stitcher model...")
|
| 97 |
stitcher_config = self.model_config["stitching_retargeting_module_params"]
|
| 98 |
+
self.stitching_retargeting_module = StitchingRetargetingNetwork(**stitcher_config.get('stitching')).to(self.device)
|
| 99 |
+
self.stitching_retargeting_module = self.load_safe_tensor(
|
| 100 |
+
self.stitching_retargeting_module,
|
| 101 |
+
os.path.join(self.model_dir, "stitching_retargeting_module.safetensors"),
|
| 102 |
+
True
|
| 103 |
+
)
|
| 104 |
self.stitching_retargeting_module = {"stitching": self.stitching_retargeting_module}
|
| 105 |
|
| 106 |
if self.pipeline is None:
|
|
|
|
| 346 |
download_model(model_path, model_url)
|
| 347 |
|
| 348 |
@staticmethod
|
| 349 |
+
def load_safe_tensor(model, file_path, is_stitcher=False):
|
| 350 |
+
def filter_stitcher(checkpoint, prefix):
|
| 351 |
+
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
| 352 |
+
key.startswith(prefix)}
|
| 353 |
+
return filtered_checkpoint
|
| 354 |
+
|
| 355 |
+
if is_stitcher:
|
| 356 |
+
model.load_state_dict(filter_stitcher(safetensors.torch.load_file(file_path), 'retarget_shoulder'))
|
| 357 |
+
else:
|
| 358 |
+
model.load_state_dict(safetensors.torch.load_file(file_path))
|
| 359 |
model.eval()
|
| 360 |
return model
|
| 361 |
|