yyang181 commited on
Commit
ad6add2
·
1 Parent(s): b9d7dc4

fix package error

Browse files
Files changed (1) hide show
  1. app.py +62 -0
app.py CHANGED
@@ -12,6 +12,9 @@ import urllib.request
12
  from os import path
13
  import io
14
  from contextlib import redirect_stdout, redirect_stderr
 
 
 
15
 
16
  import gradio as gr
17
  import spaces
@@ -239,6 +242,61 @@ def build_args_list_for_test(d16_batch_path: str,
239
  args.extend([flag, str(v)])
240
  return args
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  # ----------------- GRADIO HANDLER -----------------
243
  @spaces.GPU(duration=600) # 确保 CUDA 初始化在此函数体内
244
  def gradio_infer(
@@ -249,6 +307,10 @@ def gradio_infer(
249
  num_proto, top_k, mem_every, deep_update,
250
  save_scores, flip, size, reverse
251
  ):
 
 
 
 
252
  # 1) 基本校验与临时目录
253
  if bw_video is None:
254
  return None, "请上传黑白视频 / Please upload a B&W video."
 
12
  from os import path
13
  import io
14
  from contextlib import redirect_stdout, redirect_stderr
15
+ import subprocess
16
+ import tempfile
17
+ import importlib
18
 
19
  import gradio as gr
20
  import spaces
 
242
  args.extend([flag, str(v)])
243
  return args
244
 
245
+ # ===== 新增:ZeroGPU 后按需安装 Pytorch-Correlation-extension =====
246
+ _CORR_OK_STAMP = path.join(os.getcwd(), ".corr_ext_installed")
247
+
248
+ def ensure_correlation_extension_installed():
249
+ """
250
+ 在 ZeroGPU 分配后调用(位于 @spaces.GPU 函数体内):
251
+ - 若已能 import 或存在本地 stamp,则直接返回
252
+ - 否则执行:
253
+ git clone https://github.com/ClementPinard/Pytorch-Correlation-extension.git
254
+ cd Pytorch-Correlation-extension && python setup.py install && cd ..
255
+ """
256
+ # 1) 尝试直接导入
257
+ try:
258
+ import spatial_correlation_sampler # noqa: F401
259
+ return
260
+ except Exception:
261
+ pass
262
+ # 2) 之前成功过(打过 stamp)
263
+ if path.exists(_CORR_OK_STAMP):
264
+ return
265
+
266
+ repo_url = "https://github.com/ClementPinard/Pytorch-Correlation-extension.git"
267
+ workdir = tempfile.mkdtemp(prefix="corr_ext_")
268
+ repo_dir = path.join(workdir, "Pytorch-Correlation-extension")
269
+ try:
270
+ print("[INFO] Installing Pytorch-Correlation-extension ...")
271
+ # clone
272
+ subprocess.run(
273
+ ["git", "clone", "--depth", "1", repo_url],
274
+ cwd=workdir, check=True
275
+ )
276
+ # build & install
277
+ subprocess.run(
278
+ [sys.executable, "setup.py", "install"],
279
+ cwd=repo_dir, check=True
280
+ )
281
+ # 验证
282
+ importlib.invalidate_caches()
283
+ import spatial_correlation_sampler # noqa: F401
284
+ # 打 stamp,避免下次重复
285
+ with open(_CORR_OK_STAMP, "w") as f:
286
+ f.write("ok")
287
+ print("[INFO] Pytorch-Correlation-extension installed successfully.")
288
+ except subprocess.CalledProcessError as e:
289
+ print(f"[WARN] Failed to build/install correlation extension: {e}")
290
+ print("You can still proceed if your pipeline doesn't use it.")
291
+ except Exception as e:
292
+ print(f"[WARN] Correlation extension install check failed: {e}")
293
+ finally:
294
+ # 清理临时目录
295
+ try:
296
+ shutil.rmtree(workdir, ignore_errors=True)
297
+ except Exception:
298
+ pass
299
+
300
  # ----------------- GRADIO HANDLER -----------------
301
  @spaces.GPU(duration=600) # 确保 CUDA 初始化在此函数体内
302
  def gradio_infer(
 
307
  num_proto, top_k, mem_every, deep_update,
308
  save_scores, flip, size, reverse
309
  ):
310
+ # <<< ZeroGPU 分配后:按需安装 Pytorch-Correlation-extension >>>
311
+ ensure_correlation_extension_installed()
312
+ # --------------------------------------------------------------
313
+
314
  # 1) 基本校验与临时目录
315
  if bw_video is None:
316
  return None, "请上传黑白视频 / Please upload a B&W video."