Spaces: Running on Zero
# app.py
import os
import pathlib
import subprocess
import sys

import gradio as gr
import spaces
import torch
from PIL import Image

# Repository layout: the depth-inference code and its weight-download
# script live in code_depth/ next to this file.
BASE_DIR = pathlib.Path(__file__).resolve().parent
SCRIPT_DIR = BASE_DIR / "code_depth"
GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"

# Make code_depth/depth_infer.py importable.
if str(SCRIPT_DIR) not in sys.path:
    sys.path.append(str(SCRIPT_DIR))
from depth_infer import DepthModel  # noqa
| def _ensure_executable(p: pathlib.Path): | |
| if not p.exists(): | |
| raise FileNotFoundError(f"Not found: {p}") | |
| os.chmod(p, os.stat(p).st_mode | 0o111) | |
def ensure_weights():
    """Run get_weights.sh inside code_depth/ and return the checkpoints dir.

    Returns:
        str: path to the checkpoints/ directory created by the script.

    Raises:
        subprocess.CalledProcessError: if the download script exits non-zero.
        RuntimeError: if the script succeeds but checkpoints/ is missing.
    """
    _ensure_executable(GET_WEIGHTS_SH)

    # Inherit the current environment, but opt out of HF Hub telemetry.
    script_env = dict(os.environ)
    script_env["HF_HUB_DISABLE_TELEMETRY"] = "1"

    subprocess.run(
        ["bash", str(GET_WEIGHTS_SH)],
        check=True,
        cwd=str(SCRIPT_DIR),
        env=script_env,
    )

    ckpt_dir = SCRIPT_DIR / "checkpoints"
    if ckpt_dir.exists():
        return str(ckpt_dir)
    raise RuntimeError("weights download script ran but checkpoints/ not found")
# Download weights at startup. Without persistent storage, a rebuilt
# environment will download them again on every cold start.
try:
    CKPT_DIR = ensure_weights()
    print(f"✅ Weights ready in: {CKPT_DIR}")
except Exception as e:
    # Best-effort: keep the app booting so the failure is visible in the UI/logs.
    print(f"⚠️ Failed to prepare weights: {e}")

# Model cache keyed by encoder name, so each backbone is constructed once.
_MODELS: dict[str, DepthModel] = {}
def get_model(encoder: str) -> DepthModel:
    """Return the cached DepthModel for *encoder*, building it on first use."""
    try:
        return _MODELS[encoder]
    except KeyError:
        model = _MODELS[encoder] = DepthModel(BASE_DIR, encoder=encoder)
        return model
@spaces.GPU
def infer_depth(
    image: Image.Image,
    encoder: str = "vitl",
    max_res: int = 1280,
    input_size: int = 518,
    fp32: bool = False,
    grayscale: bool = False,
) -> Image.Image:
    """Run depth estimation on *image* and return the depth visualization.

    Decorated with @spaces.GPU so that ZeroGPU attaches a CUDA device for the
    duration of this call — without it, torch.cuda.is_available() is False on
    a "Running on Zero" Space and inference silently falls back to CPU.
    CUDA is first touched here, not at import time.

    Args:
        image: input PIL image.
        encoder: backbone name passed through to DepthModel ("vits"/"vitl").
        max_res: cap on the longer image side before inference.
        input_size: model input resolution.
        fp32: run in full precision instead of the FP16 default.
        grayscale: render the depth map in grayscale instead of a colormap.

    Returns:
        PIL image with the rendered depth map.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[infer] device={device}, encoder={encoder}, max_res={max_res}, input_size={input_size}, fp32={fp32}, gray={grayscale}")
    model = get_model(encoder)
    return model.infer(image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=grayscale)
# ---------------------------------------------------------------------------
# Gradio UI: image + inference options on the left, depth output on the right.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## GeoRemover · Depth Preview (Video-Depth-Anything)")
    with gr.Row():
        with gr.Column():
            inp = gr.Image(label="Upload image", type="pil")
            encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Encoder")
            max_res = gr.Slider(512, 2048, value=1280, step=64, label="Max resolution")
            input_size = gr.Slider(256, 1024, value=518, step=2, label="Model input_size")
            fp32 = gr.Checkbox(False, label="Use FP32 (default FP16)")
            gray = gr.Checkbox(False, label="Grayscale depth")
            btn = gr.Button("Run")
        with gr.Column():
            out = gr.Image(label="Depth visualization")

    btn.click(
        fn=infer_depth,
        inputs=[inp, encoder, max_res, input_size, fp32, gray],
        outputs=[out],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)