jhj0517
commited on
Commit
·
d10784a
1
Parent(s):
ad23aea
Fix wrapper
Browse files
modules/image_restoration/real_esrgan/wrapper/real_esrganer.py
CHANGED
|
@@ -47,8 +47,7 @@ class RealESRGANer():
|
|
| 47 |
else:
|
| 48 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
| 49 |
|
| 50 |
-
|
| 51 |
-
loadnet = self.dni(model_path[0], model_path[1], dni_weight)
|
| 52 |
|
| 53 |
# prefer to use params_ema
|
| 54 |
if 'params_ema' in loadnet:
|
|
|
|
| 47 |
else:
|
| 48 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
|
| 49 |
|
| 50 |
+
loadnet = torch.load(model_path, map_location=torch.device('cpu'))
|
|
|
|
| 51 |
|
| 52 |
# prefer to use params_ema
|
| 53 |
if 'params_ema' in loadnet:
|
modules/live_portrait/live_portrait_inferencer.py
CHANGED
|
@@ -1,17 +1,11 @@
|
|
| 1 |
import logging
|
| 2 |
-
import os
|
| 3 |
-
import cv2
|
| 4 |
import time
|
| 5 |
import copy
|
| 6 |
import dill
|
| 7 |
-
import torch
|
| 8 |
from ultralytics import YOLO
|
| 9 |
import safetensors.torch
|
| 10 |
import gradio as gr
|
| 11 |
-
from gradio_i18n import Translate, gettext as _
|
| 12 |
from ultralytics.utils import LOGGER as ultralytics_logger
|
| 13 |
-
from enum import Enum
|
| 14 |
-
from typing import Union, List, Dict, Tuple
|
| 15 |
|
| 16 |
from modules.utils.paths import *
|
| 17 |
from modules.utils.image_helper import *
|
|
@@ -27,7 +21,7 @@ from modules.live_portrait.warping_network import WarpingNetwork
|
|
| 27 |
from modules.live_portrait.motion_extractor import MotionExtractor
|
| 28 |
from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
|
| 29 |
from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
|
| 30 |
-
from modules.image_restoration.real_esrgan_inferencer import RealESRGANInferencer
|
| 31 |
|
| 32 |
|
| 33 |
class LivePortraitInferencer:
|
|
|
|
| 1 |
import logging
|
|
|
|
|
|
|
| 2 |
import time
|
| 3 |
import copy
|
| 4 |
import dill
|
|
|
|
| 5 |
from ultralytics import YOLO
|
| 6 |
import safetensors.torch
|
| 7 |
import gradio as gr
|
|
|
|
| 8 |
from ultralytics.utils import LOGGER as ultralytics_logger
|
|
|
|
|
|
|
| 9 |
|
| 10 |
from modules.utils.paths import *
|
| 11 |
from modules.utils.image_helper import *
|
|
|
|
| 21 |
from modules.live_portrait.motion_extractor import MotionExtractor
|
| 22 |
from modules.live_portrait.appearance_feature_extractor import AppearanceFeatureExtractor
|
| 23 |
from modules.live_portrait.stitching_retargeting_network import StitchingRetargetingNetwork
|
| 24 |
+
from modules.image_restoration.real_esrgan.real_esrgan_inferencer import RealESRGANInferencer
|
| 25 |
|
| 26 |
|
| 27 |
class LivePortraitInferencer:
|