| |
| """Gradio demo for MetricAnything DepthMap.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import Any, Tuple |
|
|
| import gradio as gr |
| import matplotlib |
| import numpy as np |
| import torch |
| from PIL import Image |
| from torchvision.transforms import v2 |
|
|
| from depth_model import MetricAnythingDepthMap |
|
|
| try: |
| import spaces |
|
|
| SPACES_AVAILABLE = True |
| except Exception: |
| SPACES_AVAILABLE = False |
|
|
|
|
| EXAMPLES_DIR = Path(__file__).parent / "examples" |
| MODEL_ID = "yjh001/metricanything_student_depthmap" |
| MODEL_FILENAME = "student_depthmap.pt" |
| MAX_DEPTH = 200.0 |
|
|
|
|
| def list_examples() -> list[Path]: |
| exts = {".png", ".jpg", ".jpeg"} |
| if not EXAMPLES_DIR.exists(): |
| return [] |
| return sorted([p for p in EXAMPLES_DIR.iterdir() if p.suffix.lower() in exts]) |
|
|
|
|
| def read_intrinsics(json_path: Path) -> float | None: |
| if not json_path.exists(): |
| return None |
| data = json.loads(json_path.read_text()) |
| cam_in = data.get("cam_in") |
| if cam_in is None: |
| return None |
| if isinstance(cam_in, (list, tuple)) and len(cam_in) > 0: |
| return float(cam_in[0]) |
| if isinstance(cam_in, dict): |
| for key in ("fx", "f_x", "focal_length", "focal_length_px"): |
| if key in cam_in: |
| return float(cam_in[key]) |
| return None |
|
|
|
|
| def make_transform() -> v2.Compose: |
| return v2.Compose([ |
| v2.ToImage(), |
| v2.ToDtype(torch.float32, scale=True), |
| v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), |
| ]) |
|
|
|
|
| def colorize_depth(depth: np.ndarray, max_depth: float = MAX_DEPTH, cmap: str = "turbo_r") -> np.ndarray: |
| """Inverse-depth visualization in a 0–max_depth meter range; invalid/far pixels are white.""" |
| valid = np.isfinite(depth) & (depth > 0) & (depth <= max_depth) |
| if not np.any(valid): |
| return np.full((*depth.shape, 3), 255, dtype=np.uint8) |
|
|
| disp = np.where(valid, 1.0 / depth, np.nan) |
| min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99) |
| disp = (disp - min_disp) / (max_disp - min_disp) if max_disp > min_disp else disp * 0.0 |
|
|
| colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], 0.0) |
| colored = (colored.clip(0.0, 1.0) * 255).astype(np.uint8) |
| colored[~valid] = 255 |
| return np.ascontiguousarray(colored) |
|
|
|
|
| def prepare_focal(image: Image.Image, image_path: Path | None) -> Tuple[float, str, gr.Slider]: |
| width = image.width |
| fx = None |
| info = "" |
| if image_path is not None: |
| fx = read_intrinsics(image_path.with_suffix(".json")) |
| if fx is not None: |
| info = f"Intrinsics found. Using focal length (pixels): **{fx:.2f}**." |
| else: |
| info = f"No intrinsics found. Using image width (W={width}) as focal length (pixels)." |
| else: |
| info = f"No intrinsics found. Using image width (W={width}) as focal length (pixels)." |
|
|
| if fx is None: |
| fx = float(width) |
|
|
| |
| slider = gr.update(value=fx, minimum=1, maximum=max(2000, int(width * 2)), step=1) |
| return fx, info, slider |
|
|
|
|
| def select_example(example_paths: list[str], evt: gr.SelectData): |
| path = Path(example_paths[evt.index]) |
| image = Image.open(path).convert("RGB") |
| _, info, slider = prepare_focal(image, path) |
| return image, slider, info, "example" |
|
|
|
|
| def on_input_change(image: Image.Image | None, source: str): |
| if image is None: |
| |
| return gr.update(), gr.update(), "" |
| if source == "example": |
| |
| return gr.update(), gr.update(), "" |
| _, info, slider = prepare_focal(image, None) |
| return slider, info, "" |
|
|
|
|
| def load_model() -> MetricAnythingDepthMap: |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| model = MetricAnythingDepthMap.from_pretrained( |
| MODEL_ID, |
| model_kwargs={"device": device}, |
| filename=MODEL_FILENAME, |
| ) |
| model.eval() |
| return model |
|
|
|
|
| TRANSFORM = make_transform() |
| MODEL = load_model() |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
| @torch.no_grad() |
| @(spaces.GPU if SPACES_AVAILABLE else (lambda f: f)) |
| def run_inference(image: Image.Image | None, focal_px: float): |
| if image is None: |
| return None, "Please provide an input image." |
|
|
| tensor = TRANSFORM(image).unsqueeze(0).to(DEVICE) |
| prediction = MODEL.infer(tensor, f_px=float(focal_px)) |
| depth = prediction["depth"].detach().cpu().numpy().squeeze() |
|
|
| vis = colorize_depth(depth, max_depth=MAX_DEPTH) |
| valid = np.isfinite(depth) & (depth > 0) & (depth <= MAX_DEPTH) |
| if np.any(valid): |
| min_d = float(depth[valid].min()) |
| max_d = float(depth[valid].max()) |
| stats = f"Depth range (0–{MAX_DEPTH:.0f} m): min={min_d:.2f} m, max={max_d:.2f} m" |
| else: |
| stats = f"No valid depth in 0–{MAX_DEPTH:.0f} m range." |
|
|
| return vis, stats |
|
|
|
|
| def build_demo() -> gr.Blocks: |
| example_paths = list_examples() |
| gallery_items = [(str(p), p.name) for p in example_paths] |
|
|
| with gr.Blocks(theme=gr.themes.Soft()) as demo: |
| gr.Markdown("# MetricAnything DepthMap") |
| gr.Markdown("Select an example or upload your own image to estimate metric depth.") |
|
|
| with gr.Row(): |
| with gr.Column(scale=3): |
| gallery = gr.Gallery( |
| value=gallery_items, |
| label="Examples", |
| columns=4, |
| rows=2, |
| height=220, |
| ) |
| input_image = gr.Image(type="pil", label="Input", height=320) |
| focal_slider = gr.Slider(label="Focal length (pixels)", minimum=1, maximum=4000, step=1, value=1000) |
| info = gr.Markdown("Select an example or upload an image.") |
| run_btn = gr.Button("Run") |
|
|
| with gr.Column(scale=4): |
| output_image = gr.Image(type="numpy", label="Depth (visualized)", height=420) |
| output_stats = gr.Markdown("") |
|
|
| example_state = gr.State([str(p) for p in example_paths]) |
| source_state = gr.State("") |
|
|
| if example_paths: |
| gallery.select( |
| select_example, |
| inputs=[example_state], |
| outputs=[input_image, focal_slider, info, source_state], |
| ) |
|
|
| input_image.change( |
| on_input_change, |
| inputs=[input_image, source_state], |
| outputs=[focal_slider, info, source_state], |
| ) |
|
|
| run_btn.click( |
| run_inference, |
| inputs=[input_image, focal_slider], |
| outputs=[output_image, output_stats], |
| ) |
|
|
| return demo |
|
|
|
|
| demo = build_demo() |
|
|
| if __name__ == "__main__": |
| demo.launch() |