jhj0517
commited on
Commit
·
e52a9ca
1
Parent(s):
72a8d5d
Support zero gpu
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
|
@@ -11,6 +11,7 @@ from gradio_i18n import Translate, gettext as _
|
|
| 11 |
from ultralytics.utils import LOGGER as ultralytics_logger
|
| 12 |
from enum import Enum
|
| 13 |
from typing import Union
|
|
|
|
| 14 |
|
| 15 |
from modules.utils.paths import *
|
| 16 |
from modules.utils.image_helper import *
|
|
@@ -58,6 +59,7 @@ class LivePortraitInferencer:
|
|
| 58 |
self.psi_list = None
|
| 59 |
self.d_info = None
|
| 60 |
|
|
|
|
| 61 |
def load_models(self,
|
| 62 |
model_type: str = ModelType.HUMAN.value,
|
| 63 |
progress=gr.Progress()):
|
|
@@ -132,6 +134,7 @@ class LivePortraitInferencer:
|
|
| 132 |
det_model_name = "yolo_v5s_animal_det" if model_type == ModelType.ANIMAL else "face_yolov8n"
|
| 133 |
self.detect_model = YOLO(MODEL_PATHS[det_model_name]).to(self.device)
|
| 134 |
|
|
|
|
| 135 |
def edit_expression(self,
|
| 136 |
model_type: str = ModelType.HUMAN.value,
|
| 137 |
rotate_pitch=0,
|
|
@@ -240,6 +243,7 @@ class LivePortraitInferencer:
|
|
| 240 |
except Exception as e:
|
| 241 |
raise
|
| 242 |
|
|
|
|
| 243 |
def create_video(self,
|
| 244 |
retargeting_eyes,
|
| 245 |
retargeting_mouth,
|
|
@@ -385,6 +389,7 @@ class LivePortraitInferencer:
|
|
| 385 |
download_model(model_path, model_url)
|
| 386 |
|
| 387 |
@staticmethod
|
|
|
|
| 388 |
def load_safe_tensor(model, file_path, is_stitcher=False):
|
| 389 |
def filter_stitcher(checkpoint, prefix):
|
| 390 |
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
|
@@ -399,6 +404,7 @@ class LivePortraitInferencer:
|
|
| 399 |
return model
|
| 400 |
|
| 401 |
@staticmethod
|
|
|
|
| 402 |
def get_device():
|
| 403 |
if torch.cuda.is_available():
|
| 404 |
return "cuda"
|
|
@@ -443,6 +449,7 @@ class LivePortraitInferencer:
|
|
| 443 |
|
| 444 |
return cmd_list, total_length
|
| 445 |
|
|
|
|
| 446 |
def get_face_bboxes(self, image_rgb):
|
| 447 |
pred = self.detect_model(image_rgb, conf=0.7, device=self.device)
|
| 448 |
return pred[0].boxes.xyxy.cpu().numpy()
|
|
@@ -551,6 +558,7 @@ class LivePortraitInferencer:
|
|
| 551 |
cv2.INTER_LINEAR)
|
| 552 |
return new_img
|
| 553 |
|
|
|
|
| 554 |
def prepare_src_image(self, img):
|
| 555 |
h, w = img.shape[:2]
|
| 556 |
input_shape = [256,256]
|
|
|
|
| 11 |
from ultralytics.utils import LOGGER as ultralytics_logger
|
| 12 |
from enum import Enum
|
| 13 |
from typing import Union
|
| 14 |
+
import spaces
|
| 15 |
|
| 16 |
from modules.utils.paths import *
|
| 17 |
from modules.utils.image_helper import *
|
|
|
|
| 59 |
self.psi_list = None
|
| 60 |
self.d_info = None
|
| 61 |
|
| 62 |
+
@spaces.GPU
|
| 63 |
def load_models(self,
|
| 64 |
model_type: str = ModelType.HUMAN.value,
|
| 65 |
progress=gr.Progress()):
|
|
|
|
| 134 |
det_model_name = "yolo_v5s_animal_det" if model_type == ModelType.ANIMAL else "face_yolov8n"
|
| 135 |
self.detect_model = YOLO(MODEL_PATHS[det_model_name]).to(self.device)
|
| 136 |
|
| 137 |
+
@spaces.GPU
|
| 138 |
def edit_expression(self,
|
| 139 |
model_type: str = ModelType.HUMAN.value,
|
| 140 |
rotate_pitch=0,
|
|
|
|
| 243 |
except Exception as e:
|
| 244 |
raise
|
| 245 |
|
| 246 |
+
@spaces.GPU
|
| 247 |
def create_video(self,
|
| 248 |
retargeting_eyes,
|
| 249 |
retargeting_mouth,
|
|
|
|
| 389 |
download_model(model_path, model_url)
|
| 390 |
|
| 391 |
@staticmethod
|
| 392 |
+
@spaces.GPU
|
| 393 |
def load_safe_tensor(model, file_path, is_stitcher=False):
|
| 394 |
def filter_stitcher(checkpoint, prefix):
|
| 395 |
filtered_checkpoint = {key.replace(prefix + "_module.", ""): value for key, value in checkpoint.items() if
|
|
|
|
| 404 |
return model
|
| 405 |
|
| 406 |
@staticmethod
|
| 407 |
+
@spaces.GPU
|
| 408 |
def get_device():
|
| 409 |
if torch.cuda.is_available():
|
| 410 |
return "cuda"
|
|
|
|
| 449 |
|
| 450 |
return cmd_list, total_length
|
| 451 |
|
| 452 |
+
@spaces.GPU
|
| 453 |
def get_face_bboxes(self, image_rgb):
|
| 454 |
pred = self.detect_model(image_rgb, conf=0.7, device=self.device)
|
| 455 |
return pred[0].boxes.xyxy.cpu().numpy()
|
|
|
|
| 558 |
cv2.INTER_LINEAR)
|
| 559 |
return new_img
|
| 560 |
|
| 561 |
+
@spaces.GPU
|
| 562 |
def prepare_src_image(self, img):
|
| 563 |
h, w = img.shape[:2]
|
| 564 |
input_shape = [256,256]
|