yjh001's picture
Update app.py
5f7fb5b verified
#!/usr/bin/env python3
"""Gradio demo for MetricAnything DepthMap."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Tuple
import gradio as gr
import matplotlib
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2
from depth_model import MetricAnythingDepthMap
# ``spaces`` is an optional dependency: when it cannot be imported the demo
# still starts and ``run_inference`` is left undecorated (see its decorator).
# NOTE(review): presumably the Hugging Face Spaces helper package - confirm.
try:
    import spaces
    SPACES_AVAILABLE = True
except Exception:
    # Any failure while importing ``spaces`` just disables the GPU decorator.
    SPACES_AVAILABLE = False

# Directory next to this file that is scanned for example images.
EXAMPLES_DIR = Path(__file__).parent / "examples"
# Identifier and file name handed to ``MetricAnythingDepthMap.from_pretrained``.
MODEL_ID = "yjh001/metricanything_student_depthmap"
MODEL_FILENAME = "student_depthmap.pt"
# Upper bound (meters) of the depth range used for visualization and stats.
MAX_DEPTH = 200.0
def list_examples() -> list[Path]:
    """Return the example image paths found in ``EXAMPLES_DIR``, sorted.

    Only entries whose suffix is ``.png``, ``.jpg`` or ``.jpeg`` (compared
    case-insensitively) are kept. A missing directory yields an empty list.
    """
    if not EXAMPLES_DIR.exists():
        return []
    image_suffixes = (".png", ".jpg", ".jpeg")
    return sorted(
        entry
        for entry in EXAMPLES_DIR.iterdir()
        if entry.suffix.lower() in image_suffixes
    )
def read_intrinsics(json_path: Path) -> float | None:
    """Read a focal length (pixels) from a JSON sidecar file.

    The file is expected to contain a JSON object with a ``cam_in`` entry
    that is either a non-empty list/tuple (its first element is used) or a
    mapping holding one of the keys ``fx``, ``f_x``, ``focal_length`` or
    ``focal_length_px`` (checked in that order).

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The focal length as a float, or ``None`` when the file is missing or
        unreadable, is not valid JSON, or holds no usable ``cam_in`` value.
    """
    try:
        data = json.loads(json_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return None
    if not isinstance(data, dict):
        return None

    cam_in = data.get("cam_in")
    if isinstance(cam_in, (list, tuple)):
        candidate = cam_in[0] if cam_in else None
    elif isinstance(cam_in, dict):
        focal_keys = ("fx", "f_x", "focal_length", "focal_length_px")
        candidate = next((cam_in[key] for key in focal_keys if key in cam_in), None)
    else:
        candidate = None

    if candidate is None:
        return None
    try:
        return float(candidate)
    except (TypeError, ValueError):
        # A non-numeric entry is treated the same as missing intrinsics so a
        # bad sidecar file cannot crash the UI callback.
        return None
def make_transform() -> v2.Compose:
    """Build the preprocessing pipeline applied to an image before inference.

    Steps, in order: conversion to an image tensor, cast to ``float32`` with
    value scaling enabled, then per-channel normalization using the mean/std
    constants below (the values commonly used as ImageNet statistics).
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    steps = [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=channel_mean, std=channel_std),
    ]
    return v2.Compose(steps)
def colorize_depth(depth: np.ndarray, max_depth: float = MAX_DEPTH, cmap: str = "turbo_r") -> np.ndarray:
    """Render a depth map as an RGB image based on inverse depth.

    A pixel is valid when its depth is finite and lies in ``(0, max_depth]``.
    Valid pixels are converted to inverse depth, normalized between the 0.1%
    and 99% quantiles of the valid inverse depths, and mapped through the
    colormap. All other pixels are painted white.

    Args:
        depth: Depth values (meters).
        max_depth: Largest depth still considered valid.
        cmap: Name of the matplotlib colormap to look up.

    Returns:
        A C-contiguous ``uint8`` array of shape ``(*depth.shape, 3)``.
    """
    valid = np.isfinite(depth) & (depth > 0) & (depth <= max_depth)
    if not np.any(valid):
        return np.full((*depth.shape, 3), 255, dtype=np.uint8)

    # np.where evaluates ``1.0 / depth`` for every pixel, including the zero
    # and non-finite ones that are masked out right after; silence the
    # resulting divide/invalid warnings instead of letting them spam the log.
    with np.errstate(divide="ignore", invalid="ignore"):
        disp = np.where(valid, 1.0 / depth, np.nan)

    min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99)
    if max_disp > min_disp:
        disp = (disp - min_disp) / (max_disp - min_disp)
    else:
        # Degenerate range: map every valid pixel to the same color.
        disp = disp * 0.0

    rgb = matplotlib.colormaps[cmap](1.0 - disp)[..., :3]
    # ``nan`` must be passed by keyword: the second positional parameter of
    # np.nan_to_num is ``copy``, not the NaN replacement value.
    rgb = np.nan_to_num(rgb, nan=0.0)
    colored = (rgb.clip(0.0, 1.0) * 255).astype(np.uint8)
    colored[~valid] = 255
    return np.ascontiguousarray(colored)
def prepare_focal(image: Image.Image, image_path: Path | None) -> Tuple[float, str, gr.Slider]:
    """Pick the initial focal length for *image* and describe where it came from.

    When *image_path* is given, a ``.json`` file with the same stem is probed
    through ``read_intrinsics``. If that yields no usable value, the image
    width in pixels is used as the focal length.

    Args:
        image: Input image; only its ``width`` is read.
        image_path: Path the image was loaded from, or ``None`` when there is
            no file to look for intrinsics next to.

    Returns:
        A ``(fx, info, slider)`` tuple: the focal length in pixels, a Markdown
        status message, and the ``gr.update(...)`` result that sets the focal
        slider's value and range.
    """
    width = image.width
    fx = None
    if image_path is not None:
        fx = read_intrinsics(image_path.with_suffix(".json"))
    # The slider starts at 1, so anything below that (or NaN/inf, which fail
    # these comparisons) cannot be shown and is treated as "no intrinsics".
    if fx is not None and not (1 <= fx < float("inf")):
        fx = None

    if fx is None:
        fx = float(width)
        info = f"No intrinsics found. Using image width (W={width}) as focal length (pixels)."
    else:
        info = f"Intrinsics found. Using focal length (pixels): **{fx:.2f}**."

    # Widen the range when needed so the chosen value is never above the
    # slider maximum (a long focal length can exceed both 2000 and 2 * width).
    upper = max(2000, int(width * 2), int(np.ceil(fx)))
    slider = gr.update(value=fx, minimum=1, maximum=upper, step=1)
    return fx, info, slider
def select_example(example_paths: list[str], evt: gr.SelectData):
    """Load the clicked gallery example and prepare the focal-length controls.

    Args:
        example_paths: File paths of the gallery items.
        evt: Gradio selection event; ``evt.index`` is used to index into
            *example_paths*.

    Returns:
        ``(image, slider_update, info, "example")``. The trailing marker tells
        ``on_input_change`` that the image came from the gallery.
    """
    path = Path(example_paths[evt.index])
    # ``convert`` produces an independent in-memory image, so the underlying
    # file can be closed right away instead of waiting for garbage collection.
    with Image.open(path) as source:
        image = source.convert("RGB")
    _, info, slider = prepare_focal(image, path)
    return image, slider, info, "example"
def on_input_change(image: Image.Image | None, source: str):
    """React to a change of the input image.

    Returns a ``(slider, info, source)`` triple. When the image was cleared,
    or was just set by a gallery selection (``source == "example"``), the
    slider and info are left untouched. Otherwise they are recomputed from
    the image alone, with no intrinsics file. The source marker is reset to
    ``""`` in every case.
    """
    needs_refresh = image is not None and source != "example"
    if not needs_refresh:
        return gr.update(), gr.update(), ""
    _, message, slider_update = prepare_focal(image, None)
    return slider_update, message, ""
def load_model() -> MetricAnythingDepthMap:
    """Load the pretrained depth model and switch it to eval mode.

    The checkpoint named by ``MODEL_ID`` / ``MODEL_FILENAME`` is requested
    with ``"cuda"`` as the device when CUDA is available, ``"cpu"`` otherwise.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    depth_model = MetricAnythingDepthMap.from_pretrained(
        MODEL_ID,
        filename=MODEL_FILENAME,
        model_kwargs={"device": target_device},
    )
    depth_model.eval()
    return depth_model
# Module-level singletons, created once at import time.
# Preprocessing pipeline shared by every inference call.
TRANSFORM = make_transform()
# Pretrained model; load_model() picks "cuda" when available, else "cpu".
MODEL = load_model()
# Device that run_inference moves input tensors to (same rule as load_model).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): torch.no_grad() wraps the spaces.GPU wrapper rather than the
# raw function - confirm grad mode is actually disabled where the body runs.
@torch.no_grad()
@(spaces.GPU if SPACES_AVAILABLE else (lambda f: f))
def run_inference(image: Image.Image | None, focal_px: float):
    """Estimate depth for *image* and return a visualization plus statistics.

    Args:
        image: Input image, or ``None`` when nothing has been provided.
        focal_px: Focal length in pixels, forwarded to the model as ``f_px``.

    Returns:
        ``(visualization, message)``. The visualization is the colorized
        depth image (``None`` when no image was given). The message reports
        the min/max depth among pixels that are finite and within
        ``(0, MAX_DEPTH]``, or states that no such pixel exists.
    """
    if image is None:
        return None, "Please provide an input image."

    batch = TRANSFORM(image).unsqueeze(0).to(DEVICE)
    outputs = MODEL.infer(batch, f_px=float(focal_px))
    depth = outputs["depth"].detach().cpu().numpy().squeeze()
    vis = colorize_depth(depth, max_depth=MAX_DEPTH)

    in_range = np.isfinite(depth) & (depth > 0) & (depth <= MAX_DEPTH)
    if not np.any(in_range):
        return vis, f"No valid depth in 0–{MAX_DEPTH:.0f} m range."

    kept = depth[in_range]
    min_d = float(kept.min())
    max_d = float(kept.max())
    return vis, f"Depth range (0–{MAX_DEPTH:.0f} m): min={min_d:.2f} m, max={max_d:.2f} m"
def build_demo() -> gr.Blocks:
    """Assemble the Gradio UI and wire its event handlers.

    Layout: a left column with the example gallery, input image, focal-length
    slider, info text and Run button; a right column with the depth
    visualization and its statistics. Gallery selection is only wired when at
    least one example image exists.
    """
    example_files = list_examples()
    path_strings = [str(item) for item in example_files]
    gallery_items = list(zip(path_strings, (item.name for item in example_files)))

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# MetricAnything DepthMap")
        gr.Markdown("Select an example or upload your own image to estimate metric depth.")

        with gr.Row():
            with gr.Column(scale=3):
                gallery = gr.Gallery(
                    value=gallery_items,
                    label="Examples",
                    columns=4,
                    rows=2,
                    height=220,
                )
                input_image = gr.Image(type="pil", label="Input", height=320)
                focal_slider = gr.Slider(
                    label="Focal length (pixels)",
                    minimum=1,
                    maximum=4000,
                    step=1,
                    value=1000,
                )
                info = gr.Markdown("Select an example or upload an image.")
                run_btn = gr.Button("Run")
            with gr.Column(scale=4):
                output_image = gr.Image(type="numpy", label="Depth (visualized)", height=420)
                output_stats = gr.Markdown("")

        # Hidden state: the gallery's file paths, and a marker recording
        # whether the current input image came from the gallery.
        example_state = gr.State(path_strings)
        source_state = gr.State("")

        if example_files:
            gallery.select(
                fn=select_example,
                inputs=[example_state],
                outputs=[input_image, focal_slider, info, source_state],
            )

        input_image.change(
            fn=on_input_change,
            inputs=[input_image, source_state],
            outputs=[focal_slider, info, source_state],
        )
        run_btn.click(
            fn=run_inference,
            inputs=[input_image, focal_slider],
            outputs=[output_image, output_stats],
        )

    return demo
# The UI is built at import time and exposed as the module-level ``demo``.
demo = build_demo()
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()