jhj0517
commited on
Commit
·
7f7cda2
1
Parent(s):
f0d2b3d
Refactor to create videos
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
|
@@ -14,6 +14,7 @@ from typing import Union, List, Dict, Tuple
|
|
| 14 |
|
| 15 |
from modules.utils.paths import *
|
| 16 |
from modules.utils.image_helper import *
|
|
|
|
| 17 |
from modules.live_portrait.model_downloader import *
|
| 18 |
from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper
|
| 19 |
from modules.utils.camera import get_rotation_matrix
|
|
@@ -241,15 +242,21 @@ class LivePortraitInferencer:
|
|
| 241 |
raise
|
| 242 |
|
| 243 |
def create_video(self,
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
|
|
|
| 249 |
src_image_list: Optional[List[np.ndarray]] = None,
|
| 250 |
driving_images: Optional[List[np.ndarray]] = None,
|
| 251 |
progress: gr.Progress = gr.Progress()
|
| 252 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
src_length = 1
|
| 254 |
|
| 255 |
if src_image_list is not None:
|
|
@@ -322,7 +329,14 @@ class LivePortraitInferencer:
|
|
| 322 |
return None
|
| 323 |
|
| 324 |
out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
|
| 325 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
def download_if_no_models(self,
|
| 328 |
model_type: str = ModelType.HUMAN.value,
|
|
|
|
| 14 |
|
| 15 |
from modules.utils.paths import *
|
| 16 |
from modules.utils.image_helper import *
|
| 17 |
+
from modules.utils.video_helper import *
|
| 18 |
from modules.live_portrait.model_downloader import *
|
| 19 |
from modules.live_portrait.live_portrait_wrapper import LivePortraitWrapper
|
| 20 |
from modules.utils.camera import get_rotation_matrix
|
|
|
|
| 242 |
raise
|
| 243 |
|
| 244 |
def create_video(self,
|
| 245 |
+
model_type: str = ModelType.HUMAN.value,
|
| 246 |
+
retargeting_eyes: bool = True,
|
| 247 |
+
retargeting_mouth: bool = True,
|
| 248 |
+
tracking_src_vid: bool = True,
|
| 249 |
+
animate_without_vid: bool = False,
|
| 250 |
+
crop_factor: float = 1.5,
|
| 251 |
src_image_list: Optional[List[np.ndarray]] = None,
|
| 252 |
driving_images: Optional[List[np.ndarray]] = None,
|
| 253 |
progress: gr.Progress = gr.Progress()
|
| 254 |
):
|
| 255 |
+
if self.pipeline is None or model_type != self.model_type:
|
| 256 |
+
self.load_models(
|
| 257 |
+
model_type=model_type
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
src_length = 1
|
| 261 |
|
| 262 |
if src_image_list is not None:
|
|
|
|
| 329 |
return None
|
| 330 |
|
| 331 |
out_imgs = torch.cat([pil2tensor(img_rgb) for img_rgb in out_list])
|
| 332 |
+
out_imgs = [tensor.permute(1, 2, 0).cpu().numpy() for tensor in out_imgs]
|
| 333 |
+
for img in out_imgs:
|
| 334 |
+
out_frame_path = get_auto_incremental_file_path(TEMP_VIDEO_OUT_FRAMES_DIR, "png")
|
| 335 |
+
save_image(img, out_frame_path)
|
| 336 |
+
|
| 337 |
+
video_path = create_video_from_frames(TEMP_VIDEO_OUT_FRAMES_DIR)
|
| 338 |
+
|
| 339 |
+
return video_path
|
| 340 |
|
| 341 |
def download_if_no_models(self,
|
| 342 |
model_type: str = ModelType.HUMAN.value,
|