Spaces:
Configuration error
Configuration error
optimize zero-gpu usage
Browse files
app.py
CHANGED
|
@@ -22,6 +22,8 @@ from PIL import Image
|
|
| 22 |
import cv2
|
| 23 |
import torch # used for cuda sync & empty_cache
|
| 24 |
|
|
|
|
|
|
|
| 25 |
# ----------------- BASIC INFO -----------------
|
| 26 |
CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
|
| 27 |
CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"
|
|
@@ -299,6 +301,25 @@ def ensure_correlation_extension_installed():
|
|
| 299 |
|
| 300 |
# ----------------- GRADIO HANDLER -----------------
|
| 301 |
@spaces.GPU(duration=120) # 确保 CUDA 初始化在此函数体内
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
def gradio_infer(
|
| 303 |
debug_shapes,
|
| 304 |
bw_video, ref_image,
|
|
@@ -307,9 +328,7 @@ def gradio_infer(
|
|
| 307 |
num_proto, top_k, mem_every, deep_update,
|
| 308 |
save_scores, flip, size, reverse
|
| 309 |
):
|
| 310 |
-
|
| 311 |
-
ensure_correlation_extension_installed()
|
| 312 |
-
# --------------------------------------------------------------
|
| 313 |
|
| 314 |
# 1) 基本校验与临时目录
|
| 315 |
if bw_video is None:
|
|
@@ -403,10 +422,10 @@ def gradio_infer(
|
|
| 403 |
ensure_checkpoint()
|
| 404 |
|
| 405 |
# 8) 同进程调用 test.py
|
| 406 |
-
try:
|
| 407 |
-
|
| 408 |
-
except Exception as e:
|
| 409 |
-
|
| 410 |
|
| 411 |
args_list = build_args_list_for_test(
|
| 412 |
d16_batch_path=input_root, # 指向 input_video 根
|
|
@@ -415,17 +434,19 @@ def gradio_infer(
|
|
| 415 |
cfg=user_config
|
| 416 |
)
|
| 417 |
|
| 418 |
-
buf = io.StringIO()
|
| 419 |
-
try:
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
except Exception as e:
|
| 427 |
-
|
| 428 |
-
|
|
|
|
|
|
|
| 429 |
|
| 430 |
# 在合成 mp4 之前:清空 CUDA(防止显存占用)
|
| 431 |
try:
|
|
|
|
| 22 |
import cv2
|
| 23 |
import torch # used for cuda sync & empty_cache
|
| 24 |
|
| 25 |
+
import test
|
| 26 |
+
|
| 27 |
# ----------------- BASIC INFO -----------------
|
| 28 |
CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
|
| 29 |
CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"
|
|
|
|
| 301 |
|
| 302 |
# ----------------- GRADIO HANDLER -----------------
|
| 303 |
@spaces.GPU(duration=120) # 确保 CUDA 初始化在此函数体内
|
| 304 |
+
def inference_with_gpu(args_list):
|
| 305 |
+
# <<< ZeroGPU 分配后:按需安装 Pytorch-Correlation-extension >>>
|
| 306 |
+
ensure_correlation_extension_installed()
|
| 307 |
+
# --------------------------------------------------------------
|
| 308 |
+
|
| 309 |
+
buf = io.StringIO()
|
| 310 |
+
try:
|
| 311 |
+
with redirect_stdout(buf), redirect_stderr(buf):
|
| 312 |
+
entry = getattr(test, "run_cli", None)
|
| 313 |
+
if entry is None or not callable(entry):
|
| 314 |
+
raise RuntimeError("test.py 未提供可调用的 run_cli(args_list) 接口。")
|
| 315 |
+
entry(args_list)
|
| 316 |
+
log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}"
|
| 317 |
+
except Exception as e:
|
| 318 |
+
log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}\n\nERROR: {e}"
|
| 319 |
+
|
| 320 |
+
return log
|
| 321 |
+
|
| 322 |
+
|
| 323 |
def gradio_infer(
|
| 324 |
debug_shapes,
|
| 325 |
bw_video, ref_image,
|
|
|
|
| 328 |
num_proto, top_k, mem_every, deep_update,
|
| 329 |
save_scores, flip, size, reverse
|
| 330 |
):
|
| 331 |
+
|
|
|
|
|
|
|
| 332 |
|
| 333 |
# 1) 基本校验与临时目录
|
| 334 |
if bw_video is None:
|
|
|
|
| 422 |
ensure_checkpoint()
|
| 423 |
|
| 424 |
# 8) 同进程调用 test.py
|
| 425 |
+
# try:
|
| 426 |
+
# import test # 确保 test.py 同目录且提供 run_cli(args_list)
|
| 427 |
+
# except Exception as e:
|
| 428 |
+
# return None, f"导入 test.py 失败 / Failed to import test.py:\n{e}"
|
| 429 |
|
| 430 |
args_list = build_args_list_for_test(
|
| 431 |
d16_batch_path=input_root, # 指向 input_video 根
|
|
|
|
| 434 |
cfg=user_config
|
| 435 |
)
|
| 436 |
|
| 437 |
+
# buf = io.StringIO()
|
| 438 |
+
# try:
|
| 439 |
+
# with redirect_stdout(buf), redirect_stderr(buf):
|
| 440 |
+
# entry = getattr(test, "run_cli", None)
|
| 441 |
+
# if entry is None or not callable(entry):
|
| 442 |
+
# raise RuntimeError("test.py 未提供可调用的 run_cli(args_list) 接口。")
|
| 443 |
+
# entry(args_list)
|
| 444 |
+
# log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}"
|
| 445 |
+
# except Exception as e:
|
| 446 |
+
# log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}\n\nERROR: {e}"
|
| 447 |
+
# return None, log
|
| 448 |
+
|
| 449 |
+
log = inference_with_gpu(args_list)
|
| 450 |
|
| 451 |
# 在合成 mp4 之前:清空 CUDA(防止显存占用)
|
| 452 |
try:
|