Spaces:
Runtime error
Runtime error
modify
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ from enum import Enum
|
|
| 5 |
import db_examples
|
| 6 |
import cv2
|
| 7 |
|
| 8 |
-
|
| 9 |
|
| 10 |
from misc_utils.train_utils import unit_test_create_model
|
| 11 |
from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
|
|
@@ -28,7 +28,10 @@ from tqdm import tqdm
|
|
| 28 |
|
| 29 |
# 下载文件
|
| 30 |
os.makedirs('models', exist_ok=True)
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# if not os.path.exists(filename):
|
| 34 |
# original_path = os.getcwd()
|
|
@@ -157,15 +160,16 @@ config_path = 'configs/instruct_v2v_ic_gradio.yaml'
|
|
| 157 |
diffusion_model = unit_test_create_model(config_path).cuda()
|
| 158 |
|
| 159 |
# 加载模型检查点
|
| 160 |
-
ckpt_path = 'models/
|
| 161 |
-
|
|
|
|
| 162 |
diffusion_model.load_state_dict(ckpt, strict=False)
|
| 163 |
|
| 164 |
# import pdb; pdb.set_trace()
|
| 165 |
|
| 166 |
-
#
|
| 167 |
-
|
| 168 |
-
|
| 169 |
|
| 170 |
# import pdb; pdb.set_trace()
|
| 171 |
|
|
|
|
| 5 |
import db_examples
|
| 6 |
import cv2
|
| 7 |
|
| 8 |
+
from demo_utils1 import *
|
| 9 |
|
| 10 |
from misc_utils.train_utils import unit_test_create_model
|
| 11 |
from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images
|
|
|
|
| 28 |
|
| 29 |
# 下载文件
|
| 30 |
os.makedirs('models', exist_ok=True)
|
| 31 |
+
model_path = "models/relvid_mm_sd15_fbc_unet.pth"
|
| 32 |
+
|
| 33 |
+
if not os.path.exists(filename):
|
| 34 |
+
download_url_to_file(url='https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc_unet.pth', dst=model_path)
|
| 35 |
|
| 36 |
# if not os.path.exists(filename):
|
| 37 |
# original_path = os.getcwd()
|
|
|
|
| 160 |
diffusion_model = unit_test_create_model(config_path).cuda()
|
| 161 |
|
| 162 |
# 加载模型检查点
|
| 163 |
+
# ckpt_path = 'models/relvid_mm_sd15_fbc_unet.pth' #! change
|
| 164 |
+
# ckpt_path = 'tmp/pytorch_model.bin'
|
| 165 |
+
ckpt = torch.load(model_path, map_location='cpu')
|
| 166 |
diffusion_model.load_state_dict(ckpt, strict=False)
|
| 167 |
|
| 168 |
# import pdb; pdb.set_trace()
|
| 169 |
|
| 170 |
+
# 更改全局临时目录
|
| 171 |
+
new_tmp_dir = "./demo/gradio_bg"
|
| 172 |
+
os.makedirs(new_tmp_dir, exist_ok=True)
|
| 173 |
|
| 174 |
# import pdb; pdb.set_trace()
|
| 175 |
|