# GeoRemover / app.py — depth-estimator demo Space.
# Author: zixinz · revision 69b2678 · ~3.05 kB
# (The lines above were Hugging Face Space page residue — "raw / history blame"
# links and file metadata — converted to comments so the file parses.)
# app.py
import os
import pathlib
import subprocess
import gradio as gr
import spaces
import torch
from PIL import Image
# Locate the bundled depth-inference code and its weight-download script
# relative to this file.
BASE_DIR = pathlib.Path(__file__).resolve().parent
SCRIPT_DIR = BASE_DIR / "code_depth"
GET_WEIGHTS_SH = SCRIPT_DIR / "get_weights.sh"

# Make code_depth/depth_infer.py importable.
import sys

if str(SCRIPT_DIR) not in sys.path:
    sys.path.append(str(SCRIPT_DIR))

from depth_infer import DepthModel  # noqa
def _ensure_executable(p: pathlib.Path):
if not p.exists():
raise FileNotFoundError(f"Not found: {p}")
os.chmod(p, os.stat(p).st_mode | 0o111)
def ensure_weights() -> str:
    """Run code_depth/get_weights.sh to download the model checkpoints.

    Returns:
        Path (as ``str``) to the populated ``checkpoints`` directory.

    Raises:
        FileNotFoundError: If the download script is missing.
        subprocess.CalledProcessError: If the script exits non-zero.
        RuntimeError: If the script succeeds but checkpoints/ is absent.
    """
    _ensure_executable(GET_WEIGHTS_SH)
    subprocess.run(
        ["bash", str(GET_WEIGHTS_SH)],
        check=True,
        cwd=str(SCRIPT_DIR),
        # Inherit the environment but opt out of Hugging Face Hub telemetry.
        env={**os.environ, "HF_HUB_DISABLE_TELEMETRY": "1"},
    )
    ckpt_dir = SCRIPT_DIR / "checkpoints"
    if not ckpt_dir.exists():
        raise RuntimeError("weights download script ran but checkpoints/ not found")
    return str(ckpt_dir)
# Download weights at startup. Without persistent storage, a rebuilt
# environment will download them again.
try:
    CKPT_DIR = ensure_weights()
    print(f"✅ Weights ready in: {CKPT_DIR}")
except Exception as e:
    # Best-effort: log and keep the app alive so the UI can still load.
    print(f"⚠️ Failed to prepare weights: {e}")

# Model cache keyed by encoder name, so each encoder is instantiated once.
_MODELS: dict[str, DepthModel] = {}
def get_model(encoder: str) -> DepthModel:
    """Return the cached DepthModel for *encoder*, creating it on first use.

    Args:
        encoder: Encoder variant name (e.g. "vits" or "vitl").
    """
    if encoder not in _MODELS:
        _MODELS[encoder] = DepthModel(BASE_DIR, encoder=encoder)
    return _MODELS[encoder]
@spaces.GPU
def infer_depth(
    image: Image.Image,
    encoder: str = "vitl",
    max_res: int = 1280,
    input_size: int = 518,
    fp32: bool = False,
    grayscale: bool = False,
) -> Image.Image:
    """Run depth inference on *image* and return the depth visualization.

    Args:
        image: Input image.
        encoder: Encoder variant to use ("vits" or "vitl").
        max_res: Maximum resolution the input is processed at.
        input_size: Model input size.
        fp32: Run in FP32 instead of the default FP16.
        grayscale: Render the depth map in grayscale.
    """
    # The CUDA device is only claimed here, inside the @spaces.GPU call.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"[infer] device={device}, encoder={encoder}, max_res={max_res}, input_size={input_size}, fp32={fp32}, gray={grayscale}")
    model = get_model(encoder)
    return model.infer(image, max_res=max_res, input_size=input_size, fp32=fp32, grayscale=grayscale)
# Gradio UI: image + options on the left, depth visualization on the right.
with gr.Blocks() as demo:
    gr.Markdown("## GeoRemover · Depth Preview (Video-Depth-Anything)")
    with gr.Row():
        with gr.Column():
            inp = gr.Image(label="Upload image", type="pil")
            encoder = gr.Dropdown(["vits", "vitl"], value="vitl", label="Encoder")
            max_res = gr.Slider(512, 2048, value=1280, step=64, label="Max resolution")
            input_size = gr.Slider(256, 1024, value=518, step=2, label="Model input_size")
            fp32 = gr.Checkbox(False, label="Use FP32 (default FP16)")
            gray = gr.Checkbox(False, label="Grayscale depth")
            btn = gr.Button("Run")
        with gr.Column():
            out = gr.Image(label="Depth visualization")
    btn.click(fn=infer_depth, inputs=[inp, encoder, max_res, input_size, fp32, gray], outputs=[out])

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)