yyang181 commited on
Commit
874b84a
·
1 Parent(s): 2da3e7c

optimize zero-gpu usage

Browse files
Files changed (1) hide show
  1. app.py +39 -18
app.py CHANGED
@@ -22,6 +22,8 @@ from PIL import Image
22
  import cv2
23
  import torch # used for cuda sync & empty_cache
24
 
 
 
25
  # ----------------- BASIC INFO -----------------
26
  CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
27
  CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"
@@ -299,6 +301,25 @@ def ensure_correlation_extension_installed():
299
 
300
  # ----------------- GRADIO HANDLER -----------------
301
  @spaces.GPU(duration=120) # 确保 CUDA 初始化在此函数体内
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  def gradio_infer(
303
  debug_shapes,
304
  bw_video, ref_image,
@@ -307,9 +328,7 @@ def gradio_infer(
307
  num_proto, top_k, mem_every, deep_update,
308
  save_scores, flip, size, reverse
309
  ):
310
- # <<< ZeroGPU 分配后:按需安装 Pytorch-Correlation-extension >>>
311
- ensure_correlation_extension_installed()
312
- # --------------------------------------------------------------
313
 
314
  # 1) 基本校验与临时目录
315
  if bw_video is None:
@@ -403,10 +422,10 @@ def gradio_infer(
403
  ensure_checkpoint()
404
 
405
  # 8) 同进程调用 test.py
406
- try:
407
- import test # 确保 test.py 同目录且提供 run_cli(args_list)
408
- except Exception as e:
409
- return None, f"导入 test.py 失败 / Failed to import test.py:\n{e}"
410
 
411
  args_list = build_args_list_for_test(
412
  d16_batch_path=input_root, # 指向 input_video 根
@@ -415,17 +434,19 @@ def gradio_infer(
415
  cfg=user_config
416
  )
417
 
418
- buf = io.StringIO()
419
- try:
420
- with redirect_stdout(buf), redirect_stderr(buf):
421
- entry = getattr(test, "run_cli", None)
422
- if entry is None or not callable(entry):
423
- raise RuntimeError("test.py 未提供可调用的 run_cli(args_list) 接口。")
424
- entry(args_list)
425
- log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}"
426
- except Exception as e:
427
- log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}\n\nERROR: {e}"
428
- return None, log
 
 
429
 
430
  # 在合成 mp4 之前:清空 CUDA(防止显存占用)
431
  try:
 
22
  import cv2
23
  import torch # used for cuda sync & empty_cache
24
 
25
+ import test
26
+
27
  # ----------------- BASIC INFO -----------------
28
  CHECKPOINT_URL = "https://github.com/yyang181/colormnet/releases/download/v0.1/DINOv2FeatureV6_LocalAtten_s2_154000.pth"
29
  CHECKPOINT_LOCAL = "DINOv2FeatureV6_LocalAtten_s2_154000.pth"
 
301
 
302
  # ----------------- GRADIO HANDLER -----------------
303
  @spaces.GPU(duration=120) # 确保 CUDA 初始化在此函数体内
304
+ def inference_with_gpu(args_list):
305
+ # <<< ZeroGPU 分配后:按需安装 Pytorch-Correlation-extension >>>
306
+ ensure_correlation_extension_installed()
307
+ # --------------------------------------------------------------
308
+
309
+ buf = io.StringIO()
310
+ try:
311
+ with redirect_stdout(buf), redirect_stderr(buf):
312
+ entry = getattr(test, "run_cli", None)
313
+ if entry is None or not callable(entry):
314
+ raise RuntimeError("test.py 未提供可调用的 run_cli(args_list) 接口。")
315
+ entry(args_list)
316
+ log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}"
317
+ except Exception as e:
318
+ log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}\n\nERROR: {e}"
319
+
320
+ return log
321
+
322
+
323
  def gradio_infer(
324
  debug_shapes,
325
  bw_video, ref_image,
 
328
  num_proto, top_k, mem_every, deep_update,
329
  save_scores, flip, size, reverse
330
  ):
331
+
 
 
332
 
333
  # 1) 基本校验与临时目录
334
  if bw_video is None:
 
422
  ensure_checkpoint()
423
 
424
  # 8) 同进程调用 test.py
425
+ # try:
426
+ # import test # 确保 test.py 同目录且提供 run_cli(args_list)
427
+ # except Exception as e:
428
+ # return None, f"导入 test.py 失败 / Failed to import test.py:\n{e}"
429
 
430
  args_list = build_args_list_for_test(
431
  d16_batch_path=input_root, # 指向 input_video 根
 
434
  cfg=user_config
435
  )
436
 
437
+ # buf = io.StringIO()
438
+ # try:
439
+ # with redirect_stdout(buf), redirect_stderr(buf):
440
+ # entry = getattr(test, "run_cli", None)
441
+ # if entry is None or not callable(entry):
442
+ # raise RuntimeError("test.py 未提供可调用的 run_cli(args_list) 接口。")
443
+ # entry(args_list)
444
+ # log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}"
445
+ # except Exception as e:
446
+ # log = f"Args: {' '.join(args_list)}\n\n{buf.getvalue()}\n\nERROR: {e}"
447
+ # return None, log
448
+
449
+ log = inference_with_gpu(args_list)
450
 
451
  # 在合成 mp4 之前:清空 CUDA(防止显存占用)
452
  try: