zixinz commited on
Commit
69b2678
·
1 Parent(s): a52705d

depth estimator

Browse files
Files changed (2) hide show
  1. app.py +56 -25
  2. code_depth/depth_infer.py +87 -0
app.py CHANGED
@@ -1,28 +1,31 @@
 
1
  import os
2
  import pathlib
3
  import subprocess
4
  import gradio as gr
5
  import spaces
6
  import torch
 
7
 
8
- # ---------- 权重下载:强制在 code_depth 下执行你的脚本 ----------
9
  BASE_DIR = pathlib.Path(__file__).resolve().parent
10
  SCRIPT_DIR = BASE_DIR / "code_depth"
11
  GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"
12
 
13
- def ensure_executable(path: pathlib.Path):
14
- if not path.exists():
15
- raise FileNotFoundError(f"Download script not found: {path}")
16
- os.chmod(path, os.stat(path).st_mode | 0o111)
17
-
18
- def ensure_weights() -> str:
19
- """
20
- code_depth 目录下运行 get_weights.sh。
21
- 该脚本会在 code_depth/ 下创建 checkpoints/ 并下载权重。
22
- 返回绝对路径:<repo_root>/code_depth/checkpoints
23
- """
24
- ensure_executable(GET_WEIGHTS_SH)
25
- # 你脚本的工作目录需要是 code_depth
 
 
26
  subprocess.run(
27
  ["bash", str(GET_WEIGHTS_SH)],
28
  check=True,
@@ -30,27 +33,55 @@ def ensure_weights() -> str:
30
  env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
31
  )
32
  ckpt_dir = SCRIPT_DIR / "checkpoints"
 
 
33
  return str(ckpt_dir)
34
 
35
- # 启动时先拉权重(不开 Persistent Storage 时,重建环境会清空;重启后会自动再拉一次)
36
  try:
37
  CKPT_DIR = ensure_weights()
38
  print(f"✅ Weights ready in: {CKPT_DIR}")
39
  except Exception as e:
40
  print(f"⚠️ Failed to prepare weights: {e}")
41
- CKPT_DIR = str(SCRIPT_DIR / "checkpoints") # 仍然给个路径,后续可检查是否存在
42
 
43
- # ---------- Gradio 推理函数 ----------
 
 
 
 
 
 
 
44
  @spaces.GPU
45
- def greet(n: float):
46
- # 在 GPU worker 里拿 device
 
 
 
 
 
 
 
47
  device = "cuda" if torch.cuda.is_available() else "cpu"
48
- zero = torch.tensor([0.0], device=device)
49
- # 仅示例输出,你可以在这里用 CKPT_DIR 加载你的模型
50
- print(f"Device in greet(): {device}")
51
- print(f"Using checkpoints from: {CKPT_DIR}")
52
- return f"Hello {(zero + n).item()} Tensor (device={device})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- demo = gr.Interface(fn=greet, inputs=gr.Number(label="n"), outputs=gr.Text())
55
  if __name__ == "__main__":
56
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ # app.py
2
  import os
3
  import pathlib
4
  import subprocess
5
  import gradio as gr
6
  import spaces
7
  import torch
8
+ from PIL import Image
9
 
 
10
  BASE_DIR = pathlib.Path(__file__).resolve().parent
11
  SCRIPT_DIR = BASE_DIR / "code_depth"
12
  GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"
13
 
14
+ # 让我们能 import 到 code_depth/depth_infer.py
15
+ import sys
16
+ if str(SCRIPT_DIR) not in sys.path:
17
+ sys.path.append(str(SCRIPT_DIR))
18
+
19
+ from depth_infer import DepthModel # noqa
20
+
21
+ def _ensure_executable(p: pathlib.Path):
22
+ if not p.exists():
23
+ raise FileNotFoundError(f"Not found: {p}")
24
+ os.chmod(p, os.stat(p).st_mode | 0o111)
25
+
26
+ def ensure_weights():
27
+ """在 code_depth 目录下运行你的 get_weights.sh。"""
28
+ _ensure_executable(GET_WEIGHTS_SH)
29
  subprocess.run(
30
  ["bash", str(GET_WEIGHTS_SH)],
31
  check=True,
 
33
  env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
34
  )
35
  ckpt_dir = SCRIPT_DIR / "checkpoints"
36
+ if not ckpt_dir.exists():
37
+ raise RuntimeError("weights download script ran but checkpoints/ not found")
38
  return str(ckpt_dir)
39
 
40
+ # 启动时下载权重(不开持久化时,若环境重建会再次下载)
41
  try:
42
  CKPT_DIR = ensure_weights()
43
  print(f"✅ Weights ready in: {CKPT_DIR}")
44
  except Exception as e:
45
  print(f"⚠️ Failed to prepare weights: {e}")
 
46
 
47
+ # 模型缓存(按 encoder 复用)
48
+ _MODELS: dict[str, DepthModel] = {}
49
+
50
+ def get_model(encoder: str) -> DepthModel:
51
+ if encoder not in _MODELS:
52
+ _MODELS[encoder] = DepthModel(BASE_DIR, encoder=encoder)
53
+ return _MODELS[encoder]
54
+
55
  @spaces.GPU
56
+ def infer_depth(
57
+ image: Image.Image,
58
+ encoder: str = "vitl",
59
+ max_res: int = 1280,
60
+ input_size: int = 518,
61
+ fp32: bool = False,
62
+ grayscale: bool = False,
63
+ ) -> Image.Image:
64
+ # 这里才真正触发 CUDA 设备占用
65
  device = "cuda" if torch.cuda.is_available() else "cpu"
66
+ print(f"[infer] device={device}, encoder={encoder}, max_res={max_res}, input_size={input_size}, fp32={fp32}, gray={grayscale}")
67
+ model = get_model(encoder)
68
+ return model.infer(image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=grayscale)
69
+
70
+ with gr.Blocks() as demo:
71
+ gr.Markdown("## GeoRemover · Depth Preview (Video-Depth-Anything)")
72
+ with gr.Row():
73
+ with gr.Column():
74
+ inp = gr.Image(label="Upload image", type="pil")
75
+ encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Encoder")
76
+ max_res = gr.Slider(512, 2048, value=1280, step=64, label="Max resolution")
77
+ input_size = gr.Slider(256, 1024, value=518, step=2, label="Model input_size")
78
+ fp32 = gr.Checkbox(False, label="Use FP32 (default FP16)")
79
+ gray = gr.Checkbox(False, label="Grayscale depth")
80
+ btn = gr.Button("Run")
81
+ with gr.Column():
82
+ out = gr.Image(label="Depth visualization")
83
+
84
+ btn.click(fn=infer_depth, inputs=[inp, encoder, max_res, input_size, fp32, gray], outputs=[out])
85
 
 
86
  if __name__ == "__main__":
87
  demo.launch(server_name="0.0.0.0", server_port=7860)
code_depth/depth_infer.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py  — NOTE(review): this file was committed as code_depth/depth_infer.py but contains app.py's content and defines no DepthModel; confirm the intended depth_infer.py was committed.
2
+ import os
3
+ import pathlib
4
+ import subprocess
5
+ import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from PIL import Image
9
+
10
# Resolve paths relative to this file so the app works regardless of CWD.
BASE_DIR = pathlib.Path(__file__).resolve().parent
SCRIPT_DIR = BASE_DIR / "code_depth"            # model code + helper scripts
GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"  # weight-download script

# Make code_depth/depth_infer.py importable as a top-level module.
import sys
if str(SCRIPT_DIR) not in sys.path:
    sys.path.append(str(SCRIPT_DIR))

# NOTE(review): the committed code_depth/depth_infer.py appears to contain
# app.py's own content and defines no DepthModel class, so this import will
# fail at startup — confirm the intended module was committed.
from depth_infer import DepthModel  # noqa
20
+
21
def _ensure_executable(p: pathlib.Path) -> None:
    """Ensure *p* exists and carries the execute permission bits.

    Raises:
        FileNotFoundError: if *p* does not exist.
    """
    if not p.exists():
        raise FileNotFoundError(f"Not found: {p}")
    # Add u+x, g+x, o+x on top of whatever mode the file already has.
    current_mode = os.stat(p).st_mode
    os.chmod(p, current_mode | 0o111)
25
+
26
def ensure_weights():
    """Run get_weights.sh with code_depth/ as the working directory.

    The script creates checkpoints/ under code_depth/ and downloads the
    model weights into it.

    Returns:
        str: absolute path to <repo_root>/code_depth/checkpoints.

    Raises:
        RuntimeError: if the script completed but checkpoints/ is missing.
        subprocess.CalledProcessError: if the script exits non-zero.
    """
    _ensure_executable(GET_WEIGHTS_SH)
    env = dict(os.environ)
    env["HF_HUB_DISABLE_TELEMETRY"] = "1"  # keep the download quiet/private
    subprocess.run(
        ["bash", str(GET_WEIGHTS_SH)],
        check=True,
        cwd=str(SCRIPT_DIR),  # the script expects code_depth/ as CWD
        env=env,
    )
    ckpt_dir = SCRIPT_DIR / "checkpoints"
    if not ckpt_dir.exists():
        raise RuntimeError("weights download script ran but checkpoints/ not found")
    return str(ckpt_dir)
39
+
40
# Fetch weights at startup (without persistent storage the environment may be
# rebuilt, in which case the download simply runs again).
try:
    CKPT_DIR = ensure_weights()
    print(f"✅ Weights ready in: {CKPT_DIR}")
except Exception as e:
    print(f"⚠️ Failed to prepare weights: {e}")
    # Bug fix: on failure CKPT_DIR was left undefined (the previous revision
    # assigned a fallback), so any later reference would raise NameError.
    # Keep a best-guess path; callers can check whether it actually exists.
    CKPT_DIR = str(SCRIPT_DIR / "checkpoints")
46
+
47
# Per-encoder model cache so repeated requests reuse already-loaded weights.
_MODELS: dict[str, DepthModel] = {}

def get_model(encoder: str) -> DepthModel:
    """Return the cached DepthModel for *encoder*, building it on first use."""
    cached = _MODELS.get(encoder)
    if cached is None:
        # NOTE(review): constructed with BASE_DIR while the checkpoints live
        # under SCRIPT_DIR — confirm DepthModel resolves the subpath itself.
        cached = DepthModel(BASE_DIR, encoder=encoder)
        _MODELS[encoder] = cached
    return cached
54
+
55
@spaces.GPU
def infer_depth(
    image: Image.Image,
    encoder: str = "vitl",
    max_res: int = 1280,
    input_size: int = 518,
    fp32: bool = False,
    grayscale: bool = False,
) -> Image.Image:
    """Run depth estimation on *image* and return the depth visualization.

    Runs inside the @spaces.GPU worker, so CUDA is only touched here.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[infer] device={device}, encoder={encoder}, max_res={max_res}, input_size={input_size}, fp32={fp32}, gray={grayscale}")
    return get_model(encoder).infer(
        image,
        max_res=max_res,
        input_size=input_size,
        fp32=fp32,
        grayscale=grayscale,
    )
69
+
70
# ---------- Gradio UI ----------
with gr.Blocks() as demo:
    gr.Markdown("## GeoRemover · Depth Preview (Video-Depth-Anything)")
    with gr.Row():
        with gr.Column():
            # Inputs mirror infer_depth()'s parameters one-to-one.
            inp = gr.Image(label="Upload image", type="pil")
            encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Encoder")
            max_res = gr.Slider(512, 2048, value=1280, step=64, label="Max resolution")
            input_size = gr.Slider(256, 1024, value=518, step=2, label="Model input_size")
            fp32 = gr.Checkbox(False, label="Use FP32 (default FP16)")
            gray = gr.Checkbox(False, label="Grayscale depth")
            btn = gr.Button("Run")
        with gr.Column():
            out = gr.Image(label="Depth visualization")

    # Run inference on click; the result is the rendered depth image.
    btn.click(fn=infer_depth, inputs=[inp, encoder, max_res, input_size, fp32, gray], outputs=[out])
85
+
86
if __name__ == "__main__":
    # 0.0.0.0:7860 is the standard binding for a Hugging Face Spaces container.
    demo.launch(server_name="0.0.0.0", server_port=7860)