jhj0517
commited on
Commit
·
0e976e4
1
Parent(s):
ccfe1b7
Update animal model name
Browse files
modules/live_portrait/live_portrait_inferencer.py
CHANGED
|
@@ -129,7 +129,8 @@ class LivePortraitInferencer:
|
|
| 129 |
self.stitching_retargeting_module
|
| 130 |
)
|
| 131 |
|
| 132 |
-
|
|
|
|
| 133 |
|
| 134 |
def edit_expression(self,
|
| 135 |
model_type: str = ModelType.HUMAN.value,
|
|
@@ -375,8 +376,8 @@ class LivePortraitInferencer:
|
|
| 375 |
for model_name, model_url in models_urls_dic.items():
|
| 376 |
if model_url.endswith(".pt"):
|
| 377 |
model_name += ".pt"
|
| 378 |
-
|
| 379 |
-
|
| 380 |
else:
|
| 381 |
model_name += ".safetensors"
|
| 382 |
model_path = os.path.join(model_dir, model_name)
|
|
|
|
| 129 |
self.stitching_retargeting_module
|
| 130 |
)
|
| 131 |
|
| 132 |
+
det_model_name = "yolo_v5s_animal_det" if model_type == ModelType.ANIMAL else "face_yolov8n"
|
| 133 |
+
self.detect_model = YOLO(MODEL_PATHS[det_model_name]).to(self.device)
|
| 134 |
|
| 135 |
def edit_expression(self,
|
| 136 |
model_type: str = ModelType.HUMAN.value,
|
|
|
|
| 376 |
for model_name, model_url in models_urls_dic.items():
|
| 377 |
if model_url.endswith(".pt"):
|
| 378 |
model_name += ".pt"
|
| 379 |
+
elif model_url.endswith(".n2x"):
|
| 380 |
+
model_name += ".n2x"
|
| 381 |
else:
|
| 382 |
model_name += ".safetensors"
|
| 383 |
model_path = os.path.join(model_dir, model_name)
|