diff --git a/Depth-Anything-3-anysize/.flake8 b/Depth-Anything-3-anysize/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..57d083009957dd371fb67cee54b80d356b10fe7e --- /dev/null +++ b/Depth-Anything-3-anysize/.flake8 @@ -0,0 +1,3 @@ +[flake8] +max-line-length = 100 +ignore = E203 E741 W503 E731 diff --git a/Depth-Anything-3-anysize/.gitattributes b/Depth-Anything-3-anysize/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Depth-Anything-3-anysize/.gitignore b/Depth-Anything-3-anysize/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..98bfa6d1d52615d3a87d1cac9c61a6f837d58444 --- /dev/null +++ b/Depth-Anything-3-anysize/.gitignore @@ -0,0 +1,36 @@ +# Python cache +__pycache__/ +*.py[cod] + + +# Distribution / packaging +workspace/ +build/ +dist/ +*.egg-info/ +.gradio/ + +# Test/coverage +.coverage +.pytest_cache/ +htmlcov/ +.tox/ +gallery*/ +debug*/ +DA3HF*/ +gradio_workspace/ +eval_workspace/ +FILTER*/ +input_images*/ +*.gradio/ + +# Jupyter notebooks +.ipynb_checkpoints + +# OS files +.DS_Store + +.vscode +src/debug_main.py +temp*.png +/outputs diff --git a/Depth-Anything-3-anysize/.pre-commit-config.yaml b/Depth-Anything-3-anysize/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d9935935aee777bc71316f2e999759f8e612df2 --- /dev/null +++ b/Depth-Anything-3-anysize/.pre-commit-config.yaml @@ -0,0 +1,59 @@ +repos: + - repo: 'https://github.com/pre-commit/pre-commit-hooks' + rev: v4.5.0 + hooks: + - id: check-added-large-files + args: + - '--maxkb=125' + - id: check-ast + - id: check-executables-have-shebangs + - id: check-merge-conflict + - id: check-symlinks + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: detect-private-key + - id: end-of-file-fixer + - id: no-commit-to-branch + args: + - '--branch' + - 'master' + - id: pretty-format-json + exclude: '.*\.ipynb$' + args: + - '--autofix' + - '--indent' + - '4' + - id: trailing-whitespace + args: + - '--markdown-linebreak-ext=md' + - repo: 'https://github.com/pycqa/isort' + rev: 5.13.2 + hooks: + - id: isort + args: + - '--settings-file' + - 'pyproject.toml' + - '--filter-files' + - repo: 'https://github.com/asottile/pyupgrade' + rev: v3.15.2 + hooks: + - id: pyupgrade + args: [--py38-plus, --keep-runtime-typing] + - repo: 'https://github.com/psf/black.git' + rev: 24.3.0 + hooks: + - id: black + args: + - '--config=pyproject.toml' + - repo: 'https://github.com/PyCQA/flake8' + rev: 7.0.0 + hooks: + - id: flake8 + args: + - '--config=.flake8' + - repo: 'https://github.com/myint/autoflake' + rev: v1.4 + hooks: + - id: autoflake + args: [ '--remove-all-unused-imports', '--recursive', '--remove-unused-variables', '--in-place'] diff --git a/Depth-Anything-3-anysize/README.md b/Depth-Anything-3-anysize/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d1be067d75b704e49d6072288f98a8075d51afb9 --- /dev/null +++ b/Depth-Anything-3-anysize/README.md @@ -0,0 +1,10 @@ +# Depth Anything 3 AnySize + +## πŸ”„ Key Modifications from the [Original Repo](https://github.com/ByteDance-Seed/Depth-Anything-3) +- **Native-Resolution Inputs:** Images are now processed at their original resolution by default. During inference, inputs are padded to the ViT patch size, and outputs (depth/confidence/sky maps and processed images) are cropped back to the source height and width. 
Using larger inputs now will increase memory and compute requirements. +- **Updated Defaults:** The CLI defaults to `--process-res None --process-res-method keep`, and the API uses `process_res=None, process_res_method="keep"`. See `docs/CLI.md` and `docs/API.md` for details. +- **Optional Downscaling:** For faster inference and lower memory usage, set `process_res` (e.g., `720`) with a resize strategy like `--process-res-method upper_bound_resize`. +- **Original Baseline:** Previously, images were resized to 504 px on the long side. +- **Implementation Details:** Input padding is handled in `src/depth_anything_3/utils/io/input_processor.py`, and output cropping is managed in `src/depth_anything_3/api.py`. + +-------------------------------------- diff --git a/Depth-Anything-3-anysize/app.py b/Depth-Anything-3-anysize/app.py new file mode 100644 index 0000000000000000000000000000000000000000..7688d7ad912130490e4851c78ad5cd7e9c9b6931 --- /dev/null +++ b/Depth-Anything-3-anysize/app.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from typing import Dict, Optional, Tuple + +import gradio as gr +import numpy as np +import torch +from PIL import Image + +from depth_anything_3.api import DepthAnything3 +from depth_anything_3.utils.visualize import visualize_depth + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +MODEL_SOURCES: Dict[str, str] = { + "Depth Anything v3 Nested Giant Large": "depth-anything/DA3NESTED-GIANT-LARGE", + "Depth Anything v3 Giant": "depth-anything/DA3-GIANT", + "Depth Anything v3 Large": "depth-anything/DA3-LARGE", + "Depth Anything v3 Base": "depth-anything/DA3-BASE", + "Depth Anything v3 Small": "depth-anything/DA3-SMALL", + "Depth Anything v3 Metric Large": "depth-anything/DA3METRIC-LARGE", + "Depth Anything v3 Mono Large": "depth-anything/DA3MONO-LARGE", +} +_MODEL_CACHE: Dict[str, DepthAnything3] = {} + + +def _load_model(model_label: str) -> DepthAnything3: + repo_id = MODEL_SOURCES[model_label] + if repo_id not in _MODEL_CACHE: + model = DepthAnything3.from_pretrained(repo_id) + model = model.to(device=DEVICE) + model.eval() + _MODEL_CACHE[repo_id] = model + return _MODEL_CACHE[repo_id] + + +def _prep_image(image: np.ndarray) -> np.ndarray: + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + if image.dtype != np.uint8: + image = np.clip(image, 0, 255).astype(np.uint8) + return image + + +def run_inference( + model_label: str, + image: Optional[np.ndarray], +) -> tuple[Tuple[np.ndarray, np.ndarray], str]: + if image is None: + raise gr.Error("Upload an image before running inference.") + rgb = _prep_image(image) + model = _load_model(model_label) + prediction = model.inference( + image=[Image.fromarray(rgb)], + process_res=None, + process_res_method="keep", + ) + depth_map = prediction.depth[0] + depth_vis = visualize_depth(depth_map, cmap="Spectral") + processed_rgb = ( + prediction.processed_images[0] + if prediction.processed_images is not None + else rgb + ) + slider_value: Tuple[np.ndarray, np.ndarray] = (processed_rgb, depth_vis) + lines = [ + f"**Model:** `{MODEL_SOURCES[model_label]}`", + f"**Device:** `{DEVICE}`", + f"**Depth shape:** `{tuple(prediction.depth.shape)}`", + ] + if prediction.extrinsics is not None: + lines.append(f"**Extrinsics shape:** `{prediction.extrinsics.shape}`") + if prediction.intrinsics is not None: + lines.append(f"**Intrinsics shape:** `{prediction.intrinsics.shape}`") + return slider_value, "\n".join(lines) + + +def build_app() -> gr.Blocks: + with gr.Blocks(title="Depth Anything v3 
- Any Size Demo") as demo: + gr.Markdown( + """ + ## Depth Anything v3 (Any-Size Demo) + Upload an image, pick a pretrained model, and compare RGB against the inferred depth. + """ + ) + with gr.Row(): + model_dropdown = gr.Dropdown( + choices=list(MODEL_SOURCES.keys()), + value="Depth Anything v3 Large", + label="Model", + ) + image_input = gr.Image(type="numpy", label="Input Image", image_mode="RGB") + run_button = gr.Button("Run Inference", variant="primary") + with gr.Row(): + comparison_slider = gr.ImageSlider(label="RGB vs Depth") + info_panel = gr.Markdown() + run_button.click( + fn=run_inference, + inputs=[model_dropdown, image_input], + outputs=[comparison_slider, info_panel], + ) + return demo + + +def main() -> None: + app = build_app() + app.queue(max_size=8).launch() + + +if __name__ == "__main__": + main() diff --git a/Depth-Anything-3-anysize/depth3_anysize.py b/Depth-Anything-3-anysize/depth3_anysize.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8ff3b7798b703daf201ffa579135e2fe9c9137 --- /dev/null +++ b/Depth-Anything-3-anysize/depth3_anysize.py @@ -0,0 +1,52 @@ +import os +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image +import torch +from depth_anything_3.api import DepthAnything3 +from depth_anything_3.utils.visualize import visualize_depth + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = DepthAnything3.from_pretrained("depth-anything/DA3-LARGE") +model = model.to(device) +model.eval() +print(f"Model loaded on {device}") + +# Load sample images and run inference +image_paths = [ + "assets/examples/SOH/demo.png", +] + +# Run inference +prediction = model.inference( + image=image_paths, + # export_dir=None, + # export_format="glb" +) +print(f"Depth shape: {prediction.depth.shape}") +print(f"Extrinsics: {prediction.extrinsics.shape if prediction.extrinsics is not None else 'None'}") +print(f"Intrinsics: {prediction.intrinsics.shape if prediction.intrinsics is not None else 'None'}") + +# Visualize input images and depth maps +n_images = prediction.depth.shape[0] + +fig, axes = plt.subplots(2, n_images, figsize=(12, 6)) + +if n_images == 1: + axes = axes.reshape(2, 1) + +for i in range(n_images): + # Show original image + if prediction.processed_images is not None: + axes[0, i].imshow(prediction.processed_images[i]) + axes[0, i].set_title(f"Input {i+1}") + axes[0, i].axis('off') + + # Show depth map + depth_vis = visualize_depth(prediction.depth[i], cmap="Spectral") + axes[1, i].imshow(depth_vis) + axes[1, i].set_title(f"Depth {i+1}") + axes[1, i].axis('off') + +plt.tight_layout() +plt.show() diff --git a/Depth-Anything-3-anysize/pyproject.toml b/Depth-Anything-3-anysize/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..baf889772078c6f90039020fefc2fa94cd7cf877 --- /dev/null +++ b/Depth-Anything-3-anysize/pyproject.toml @@ -0,0 +1,93 @@ +[build-system] +requires = ["hatchling>=1.25", "hatch-vcs>=0.4"] +build-backend = "hatchling.build" + +[project] +name = "depth-anything-3" +version = "0.0.0" +description = "Depth Anything 3" +readme = "README.md" +requires-python = ">=3.9, <=3.13" +license = { text = "Apache-2.0" } +authors = [{ name = "Your Name" }] + +dependencies = [ + "pre-commit", + "trimesh", + "torch>=2", + "torchvision", + "einops", + "huggingface_hub", + "imageio", + "numpy<2", + "opencv-python", + "open3d", + "fastapi", + "uvicorn", + "requests", + "typer", + "pillow", + "omegaconf", + "evo", + "e3nn", + "moviepy", + "plyfile", + "pillow_heif", 
+ "safetensors", + "uvicorn", + "moviepy==1.0.3", + "typer>=0.9.0", + "pycolmap", +] + +[project.optional-dependencies] +app = ["gradio>=5", "pillow>=9.0"] # requires that python3>=3.10 +gs = ["gsplat @ git+https://github.com/nerfstudio-project/gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70"] +all = ["depth-anything-3[app,gs]"] + + +[project.scripts] +da3 = "depth_anything_3.cli:app" + +[project.urls] +Homepage = "https://github.com/ByteDance-Seed/Depth-Anything-3" + +[tool.hatch.version] +source = "vcs" + +[tool.hatch.build.targets.wheel] +packages = ["src/depth_anything_3"] + +[tool.hatch.build.targets.sdist] +include = [ + "README.md", + "pyproject.toml", + "src/depth_anything_3", +] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.mypy] +plugins = ["jaxtyping.mypy_plugin"] + +[tool.black] +line-length = 99 +target-version = ['py37', 'py38', 'py39', 'py310', 'py311'] +include = '\.pyi?$' +exclude = ''' +/( + | \.git +)/ +''' + +[tool.isort] +profile = "black" +multi_line_output = 3 +include_trailing_comma = true +known_third_party = ["bson","cruise","cv2","dataloader","diffusers","omegaconf","tensorflow","torch","torchvision","transformers","gsplat"] +known_first_party = ["common", "data", "models", "projects"] +sections = ["FUTURE","STDLIB","THIRDPARTY","FIRSTPARTY","LOCALFOLDER"] +skip_gitignore = true +line_length = 99 +no_lines_before="THIRDPARTY" diff --git a/Depth-Anything-3-anysize/requirements.txt b/Depth-Anything-3-anysize/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4aed1baf1811fcdb9ec2055f8cedf02663efc4b6 --- /dev/null +++ b/Depth-Anything-3-anysize/requirements.txt @@ -0,0 +1,24 @@ +torchvision +einops +huggingface_hub +imageio +opencv-python +open3d +fastapi +requests +evo +e3nn +moviepy==1.0.3 +plyfile +pillow_heif +safetensors +pycolmap +torch>=2 +uvicorn +typer>=0.9.0 +pillow +pre-commit +trimesh +numpy<2 +omegaconf +-e .[all] diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/api.py b/Depth-Anything-3-anysize/src/depth_anything_3/api.py new file mode 100644 index 0000000000000000000000000000000000000000..7b269f09dda0710ed333130561f5cda8bd918bcb --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/api.py @@ -0,0 +1,483 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Depth Anything 3 API module. + +This module provides the main API for Depth Anything 3, including model loading, +inference, and export capabilities. It supports both single and nested model architectures. 
+""" + +from __future__ import annotations + +import time +from typing import Optional, Sequence +import numpy as np +import torch +import torch.nn as nn +from huggingface_hub import PyTorchModelHubMixin +from PIL import Image + +from depth_anything_3.cfg import create_object, load_config +from depth_anything_3.registry import MODEL_REGISTRY +from depth_anything_3.specs import Prediction +from depth_anything_3.utils.export import export +from depth_anything_3.utils.geometry import affine_inverse +from depth_anything_3.utils.io.input_processor import InputProcessor +from depth_anything_3.utils.io.output_processor import OutputProcessor +from depth_anything_3.utils.logger import logger +from depth_anything_3.utils.pose_align import align_poses_umeyama + +torch.backends.cudnn.benchmark = False +# logger.info("CUDNN Benchmark Disabled") + +SAFETENSORS_NAME = "model.safetensors" +CONFIG_NAME = "config.json" + + +class DepthAnything3(nn.Module, PyTorchModelHubMixin): + """ + Depth Anything 3 main API class. + + This class provides a high-level interface for depth estimation using Depth Anything 3. + It supports both single and nested model architectures with metric scaling capabilities. + + Features: + - Hugging Face Hub integration via PyTorchModelHubMixin + - Support for multiple model presets (vitb, vitg, nested variants) + - Automatic mixed precision inference + - Export capabilities for various formats (GLB, PLY, NPZ, etc.) + - Camera pose estimation and metric depth scaling + + Usage: + # Load from Hugging Face Hub + model = DepthAnything3.from_pretrained("huggingface/model-name") + + # Or create with specific preset + model = DepthAnything3(preset="vitg") + + # Run inference + prediction = model.inference(images, export_dir="output", export_format="glb") + """ + + _commit_hash: str | None = None # Set by mixin when loading from Hub + + def __init__(self, model_name: str = "da3-large", **kwargs): + """ + Initialize DepthAnything3 with specified preset. + + Args: + model_name: The name of the model preset to use. + Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'. + **kwargs: Additional keyword arguments (currently unused). + """ + super().__init__() + self.model_name = model_name + + # Build the underlying network + self.config = load_config(MODEL_REGISTRY[self.model_name]) + self.model = create_object(self.config) + self.model.eval() + + # Initialize processors + self.input_processor = InputProcessor() + self.output_processor = OutputProcessor() + + # Device management (set by user) + self.device = None + + @torch.inference_mode() + def forward( + self, + image: torch.Tensor, + extrinsics: torch.Tensor | None = None, + intrinsics: torch.Tensor | None = None, + export_feat_layers: list[int] | None = None, + infer_gs: bool = False, + ) -> dict[str, torch.Tensor]: + """ + Forward pass through the model. + + Args: + image: Input batch with shape ``(B, N, 3, H, W)`` on the model device. + extrinsics: Optional camera extrinsics with shape ``(B, N, 4, 4)``. + intrinsics: Optional camera intrinsics with shape ``(B, N, 3, 3)``. + export_feat_layers: Layer indices to return intermediate features for. 
+ + Returns: + Dictionary containing model predictions + """ + # Determine optimal autocast dtype + autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + with torch.no_grad(): + with torch.autocast(device_type=image.device.type, dtype=autocast_dtype): + return self.model(image, extrinsics, intrinsics, export_feat_layers, infer_gs) + + def inference( + self, + image: list[np.ndarray | Image.Image | str], + extrinsics: np.ndarray | None = None, + intrinsics: np.ndarray | None = None, + align_to_input_ext_scale: bool = True, + infer_gs: bool = False, + render_exts: np.ndarray | None = None, + render_ixts: np.ndarray | None = None, + render_hw: tuple[int, int] | None = None, + process_res: int | None = None, + process_res_method: str = "keep", + export_dir: str | None = None, + export_format: str = "mini_npz", + export_feat_layers: Sequence[int] | None = None, + # GLB export parameters + conf_thresh_percentile: float = 40.0, + num_max_points: int = 1_000_000, + show_cameras: bool = True, + # Feat_vis export parameters + feat_vis_fps: int = 15, + # Other export parameters, e.g., gs_ply, gs_video + export_kwargs: Optional[dict] = {}, + ) -> Prediction: + """ + Run inference on input images. + + Args: + image: List of input images (numpy arrays, PIL Images, or file paths) + extrinsics: Camera extrinsics (N, 4, 4) + intrinsics: Camera intrinsics (N, 3, 3) + align_to_input_ext_scale: whether to align the input pose scale to the prediction + infer_gs: Enable the 3D Gaussian branch (needed for `gs_ply`/`gs_video` exports) + render_exts: Optional render extrinsics for Gaussian video export + render_ixts: Optional render intrinsics for Gaussian video export + render_hw: Optional render resolution for Gaussian video export + process_res: Processing resolution + process_res_method: Resize method for processing + export_dir: Directory to export results + export_format: Export format (mini_npz, npz, glb, ply, gs, gs_video) + export_feat_layers: Layer indices to export intermediate features from + conf_thresh_percentile: [GLB] Lower percentile for adaptive confidence threshold (default: 40.0) # noqa: E501 + num_max_points: [GLB] Maximum number of points in the point cloud (default: 1,000,000) + show_cameras: [GLB] Show camera wireframes in the exported scene (default: True) + feat_vis_fps: [FEAT_VIS] Frame rate for output video (default: 15) + export_kwargs: additional arguments to export functions. + + Returns: + Prediction object containing depth maps and camera parameters + """ + if "gs" in export_format: + assert infer_gs, "must set `infer_gs=True` to perform gs-related export." + + if "colmap" in export_format: + assert isinstance(image[0], str), "`image` must be image paths for COLMAP export." 
+ + # Preprocess images + imgs_cpu, extrinsics, intrinsics, pad_meta = self._preprocess_inputs( + image, extrinsics, intrinsics, process_res, process_res_method + ) + + # Prepare tensors for model + imgs, ex_t, in_t = self._prepare_model_inputs(imgs_cpu, extrinsics, intrinsics) + + # Normalize extrinsics + ex_t_norm = self._normalize_extrinsics(ex_t.clone() if ex_t is not None else None) + + # Run model forward pass + export_feat_layers = list(export_feat_layers) if export_feat_layers is not None else [] + + raw_output = self._run_model_forward(imgs, ex_t_norm, in_t, export_feat_layers, infer_gs) + + # Convert raw output to prediction + prediction = self._convert_to_prediction(raw_output) + + # Crop padded regions back to original sizes if needed + prediction = self._crop_to_original(prediction, pad_meta) + + # Align prediction to extrinsincs + prediction = self._align_to_input_extrinsics_intrinsics( + extrinsics, intrinsics, prediction, align_to_input_ext_scale + ) + + # Add processed images for visualization + prediction = self._add_processed_images(prediction, imgs_cpu, pad_meta) + + # Export if requested + if export_dir is not None: + + if "gs" in export_format: + if infer_gs and "gs_video" not in export_format: + export_format = f"{export_format}-gs_video" + if "gs_video" in export_format: + if "gs_video" not in export_kwargs: + export_kwargs["gs_video"] = {} + export_kwargs["gs_video"].update( + { + "extrinsics": render_exts, + "intrinsics": render_ixts, + "out_image_hw": render_hw, + } + ) + # Add GLB export parameters + if "glb" in export_format: + if "glb" not in export_kwargs: + export_kwargs["glb"] = {} + export_kwargs["glb"].update( + { + "conf_thresh_percentile": conf_thresh_percentile, + "num_max_points": num_max_points, + "show_cameras": show_cameras, + } + ) + # Add Feat_vis export parameters + if "feat_vis" in export_format: + if "feat_vis" not in export_kwargs: + export_kwargs["feat_vis"] = {} + export_kwargs["feat_vis"].update( + { + "fps": feat_vis_fps, + } + ) + # Add COLMAP export parameters + if "colmap" in export_format: + if "colmap" not in export_kwargs: + export_kwargs["colmap"] = {} + export_kwargs["colmap"].update( + { + "image_paths": image, + "conf_thresh_percentile": conf_thresh_percentile, + "process_res_method": process_res_method, + } + ) + self._export_results(prediction, export_format, export_dir, **export_kwargs) + + return prediction + + def _preprocess_inputs( + self, + image: list[np.ndarray | Image.Image | str], + extrinsics: np.ndarray | None = None, + intrinsics: np.ndarray | None = None, + process_res: int | None = None, + process_res_method: str = "keep", + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, list[dict]]: + """Preprocess input images using input processor.""" + start_time = time.time() + imgs_cpu, extrinsics, intrinsics, pad_meta = self.input_processor( + image, + extrinsics.copy() if extrinsics is not None else None, + intrinsics.copy() if intrinsics is not None else None, + process_res, + process_res_method, + ) + end_time = time.time() + logger.info( + "Processed Images Done taking", + end_time - start_time, + "seconds. 
Shape: ", + imgs_cpu.shape, + ) + return imgs_cpu, extrinsics, intrinsics, pad_meta + + def _prepare_model_inputs( + self, + imgs_cpu: torch.Tensor, + extrinsics: torch.Tensor | None, + intrinsics: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """Prepare tensors for model input.""" + device = self._get_model_device() + + # Move images to model device + imgs = imgs_cpu.to(device, non_blocking=True)[None].float() + + # Convert camera parameters to tensors + ex_t = ( + extrinsics.to(device, non_blocking=True)[None].float() + if extrinsics is not None + else None + ) + in_t = ( + intrinsics.to(device, non_blocking=True)[None].float() + if intrinsics is not None + else None + ) + + return imgs, ex_t, in_t + + def _normalize_extrinsics(self, ex_t: torch.Tensor | None) -> torch.Tensor | None: + """Normalize extrinsics""" + if ex_t is None: + return None + transform = affine_inverse(ex_t[:, :1]) + ex_t_norm = ex_t @ transform + c2ws = affine_inverse(ex_t_norm) + translations = c2ws[..., :3, 3] + dists = translations.norm(dim=-1) + median_dist = torch.median(dists) + median_dist = torch.clamp(median_dist, min=1e-1) + ex_t_norm[..., :3, 3] = ex_t_norm[..., :3, 3] / median_dist + return ex_t_norm + + def _align_to_input_extrinsics_intrinsics( + self, + extrinsics: torch.Tensor | None, + intrinsics: torch.Tensor | None, + prediction: Prediction, + align_to_input_ext_scale: bool = True, + ransac_view_thresh: int = 10, + ) -> Prediction: + """Align depth map to input extrinsics""" + if extrinsics is None: + return prediction + prediction.intrinsics = intrinsics.numpy() + _, _, scale, aligned_extrinsics = align_poses_umeyama( + prediction.extrinsics, + extrinsics.numpy(), + ransac=len(extrinsics) >= ransac_view_thresh, + return_aligned=True, + random_state=42, + ) + if align_to_input_ext_scale: + prediction.extrinsics = extrinsics[..., :3, :].numpy() + prediction.depth /= scale + else: + prediction.extrinsics = aligned_extrinsics + return prediction + + def _run_model_forward( + self, + imgs: torch.Tensor, + ex_t: torch.Tensor | None, + in_t: torch.Tensor | None, + export_feat_layers: Sequence[int] | None = None, + infer_gs: bool = False, + ) -> dict[str, torch.Tensor]: + """Run model forward pass.""" + device = imgs.device + need_sync = device.type == "cuda" + if need_sync: + torch.cuda.synchronize(device) + start_time = time.time() + feat_layers = list(export_feat_layers) if export_feat_layers is not None else None + output = self.forward(imgs, ex_t, in_t, feat_layers, infer_gs) + if need_sync: + torch.cuda.synchronize(device) + end_time = time.time() + logger.info(f"Model Forward Pass Done. Time: {end_time - start_time} seconds") + return output + + def _convert_to_prediction(self, raw_output: dict[str, torch.Tensor]) -> Prediction: + """Convert raw model output to Prediction object.""" + start_time = time.time() + output = self.output_processor(raw_output) + end_time = time.time() + logger.info(f"Conversion to Prediction Done. 
Time: {end_time - start_time} seconds") + return output + + def _add_processed_images( + self, prediction: Prediction, imgs_cpu: torch.Tensor, pad_meta: list[dict] + ) -> Prediction: + """Add processed images to prediction for visualization.""" + # Convert from (N, 3, H, W) to (N, H, W, 3) and denormalize + processed_imgs = imgs_cpu.permute(0, 2, 3, 1).cpu().numpy() # (N, H, W, 3) + + # Denormalize from ImageNet normalization + mean = np.array([0.485, 0.456, 0.406]) + std = np.array([0.229, 0.224, 0.225]) + processed_imgs = processed_imgs * std + mean + processed_imgs = np.clip(processed_imgs, 0, 1) + processed_imgs = (processed_imgs * 255).astype(np.uint8) + + # Crop to original size if padding was applied + if pad_meta: + cropped_imgs = [] + for i, meta in enumerate(pad_meta): + img = processed_imgs[i] + pt, pb, pl, pr = meta.get("pad", (0, 0, 0, 0)) + if any((pt, pb, pl, pr)): + img = img[pt : img.shape[0] - pb if pb > 0 else img.shape[0], pl : img.shape[1] - pr if pr > 0 else img.shape[1]] + cropped_imgs.append(img) + processed_imgs = np.stack(cropped_imgs, axis=0) + + prediction.processed_images = processed_imgs + return prediction + + def _export_results( + self, prediction: Prediction, export_format: str, export_dir: str, **kwargs + ) -> None: + """Export results to specified format and directory.""" + start_time = time.time() + export(prediction, export_format, export_dir, **kwargs) + end_time = time.time() + logger.info(f"Export Results Done. Time: {end_time - start_time} seconds") + + def _get_model_device(self) -> torch.device: + """ + Get the device where the model is located. + + Returns: + Device where the model parameters are located + + Raises: + ValueError: If no tensors are found in the model + """ + if self.device is not None: + return self.device + + # Find device from parameters + for param in self.parameters(): + self.device = param.device + return param.device + + # Find device from buffers + for buffer in self.buffers(): + self.device = buffer.device + return buffer.device + + raise ValueError("No tensor found in model") + + def _crop_to_original(self, prediction: Prediction, pad_meta: list[dict]) -> Prediction: + """ + Remove padding added for patch divisibility to restore original HxW. 
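+
+        Also shifts the principal points in ``prediction.intrinsics`` by the removed
+        top/left padding so they stay consistent with the cropped maps.
+
+        Args:
+            prediction: Prediction whose depth/conf/sky maps may include padding.
+            pad_meta: Per-view dicts with a ``pad`` tuple of (top, bottom, left, right).
+
+        Returns:
+            Prediction with per-view maps cropped back to their original height and width.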
+ """ + if not pad_meta: + return prediction + depth_list = [] + conf_list = [] if prediction.conf is not None else None + sky_list = [] if prediction.sky is not None else None + + for idx, meta in enumerate(pad_meta): + pt, pb, pl, pr = meta.get("pad", (0, 0, 0, 0)) + + def crop(arr: np.ndarray | None) -> np.ndarray | None: + if arr is None: + return None + h, w = arr.shape[-2], arr.shape[-1] + return arr[pt : h - pb if pb > 0 else h, pl : w - pr if pr > 0 else w] + + depth_list.append(crop(prediction.depth[idx]) if prediction.depth is not None else None) + if conf_list is not None: + conf_list.append(crop(prediction.conf[idx])) + if sky_list is not None: + sky_list.append(crop(prediction.sky[idx])) + + if prediction.intrinsics is not None: + prediction.intrinsics[idx, 0, 2] -= pl + prediction.intrinsics[idx, 1, 2] -= pt + + if depth_list: + prediction.depth = np.stack(depth_list, axis=0) + if conf_list is not None: + prediction.conf = np.stack(conf_list, axis=0) + if sky_list is not None: + prediction.sky = np.stack(sky_list, axis=0) + + return prediction diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/app/css_and_html.py b/Depth-Anything-3-anysize/src/depth_anything_3/app/css_and_html.py new file mode 100644 index 0000000000000000000000000000000000000000..d414df9db72a4a395bc15d87ff4edbeb213f18bb --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/app/css_and_html.py @@ -0,0 +1,594 @@ +# flake8: noqa: E501 + +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CSS and HTML content for the Depth Anything 3 Gradio application. +This module contains all the CSS styles and HTML content blocks +used in the Gradio interface. 
+""" + +# CSS Styles for the Gradio interface +GRADIO_CSS = """ +/* Add Font Awesome CDN with all styles including brands and colors */ +@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css'); + +/* Add custom styles for colored icons */ +.fa-color-blue { + color: #3b82f6; +} + +.fa-color-purple { + color: #8b5cf6; +} + +.fa-color-cyan { + color: #06b6d4; +} + +.fa-color-green { + color: #10b981; +} + +.fa-color-yellow { + color: #f59e0b; +} + +.fa-color-red { + color: #ef4444; +} + +.link-btn { + display: inline-flex; + align-items: center; + gap: 8px; + text-decoration: none; + padding: 12px 24px; + border-radius: 50px; + font-weight: 500; + transition: all 0.3s ease; +} + +/* Dark mode tech theme */ +@media (prefers-color-scheme: dark) { + html, body { + background: #1e293b; + color: #ffffff; + } + + .gradio-container { + background: #1e293b; + color: #ffffff; + } + + .link-btn { + background: rgba(255, 255, 255, 0.2); + color: white; + backdrop-filter: blur(10px); + border: 1px solid rgba(255, 255, 255, 0.3); + } + + .link-btn:hover { + background: rgba(255, 255, 255, 0.3); + transform: translateY(-2px); + box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2); + } + + .tech-bg { + background: linear-gradient(135deg, #0f172a, #1e293b); /* Darker colors */ + position: relative; + overflow: hidden; + } + + .tech-bg::before { + content: ''; + position: absolute; + top: 0; + left: 0; + right: 0; + bottom: 0; + background: + radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */ + radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */ + radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.1) 0%, transparent 50%); /* Reduced opacity */ + animation: techPulse 8s ease-in-out infinite; + } + + .gradio-container .panel, + .gradio-container .block, + .gradio-container .form { + background: rgba(0, 0, 0, 0.3); + border: 1px solid rgba(59, 130, 246, 0.2); + border-radius: 10px; + } + + .gradio-container * { + color: #ffffff; + } + + .gradio-container label { + color: #e0e0e0; + } + + .gradio-container .markdown { + color: #e0e0e0; + } +} + +/* Light mode tech theme */ +@media (prefers-color-scheme: light) { + html, body { + background: #ffffff; + color: #1e293b; + } + + .gradio-container { + background: #ffffff; + color: #1e293b; + } + + .tech-bg { + background: linear-gradient(135deg, #ffffff, #f1f5f9); + position: relative; + overflow: hidden; + } + + .link-btn { + background: rgba(59, 130, 246, 0.15); + color: var(--body-text-color); + border: 1px solid rgba(59, 130, 246, 0.3); + } + + .link-btn:hover { + background: rgba(59, 130, 246, 0.25); + transform: translateY(-2px); + box-shadow: 0 8px 25px rgba(59, 130, 246, 0.2); + } + + .tech-bg::before { + content: ''; + position: absolute; + top: 0; + left: 0; + right: 0; + bottom: 0; + background: + radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.1) 0%, transparent 50%), + radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.1) 0%, transparent 50%), + radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.08) 0%, transparent 50%); + animation: techPulse 8s ease-in-out infinite; + } + + .gradio-container .panel, + .gradio-container .block, + .gradio-container .form { + background: rgba(255, 255, 255, 0.8); + border: 1px solid rgba(59, 130, 246, 0.3); + border-radius: 10px; + box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); + } + + .gradio-container * { + color: #1e293b; + } + + .gradio-container label { + color: #334155; + } 
+ + .gradio-container .markdown { + color: #334155; + } +} + + + + +@keyframes techPulse { + 0%, 100% { opacity: 0.5; } + 50% { opacity: 0.8; } +} + +/* Custom log with tech gradient */ +.custom-log * { + font-style: italic; + font-size: 22px !important; + background: linear-gradient(135deg, #3b82f6, #8b5cf6); + background-size: 400% 400%; + -webkit-background-clip: text; + background-clip: text; + font-weight: bold !important; + color: transparent !important; + text-align: center !important; + animation: techGradient 3s ease infinite; +} + +@keyframes techGradient { + 0% { background-position: 0% 50%; } + 50% { background-position: 100% 50%; } + 100% { background-position: 0% 50%; } +} + +@keyframes metricPulse { + 0%, 100% { background-position: 0% 50%; } + 50% { background-position: 100% 50%; } +} + +@keyframes pointcloudPulse { + 0%, 100% { background-position: 0% 50%; } + 50% { background-position: 100% 50%; } +} + +@keyframes camerasPulse { + 0%, 100% { background-position: 0% 50%; } + 50% { background-position: 100% 50%; } +} + +@keyframes gaussiansPulse { + 0%, 100% { background-position: 0% 50%; } + 50% { background-position: 100% 50%; } +} + +/* Special colors for key terms - Global styles */ +.metric-text { + background: linear-gradient(45deg, #ff6b6b, #ff8e53, #ff6b6b); + background-size: 200% 200%; + -webkit-background-clip: text; + background-clip: text; + color: transparent !important; + animation: metricPulse 2s ease-in-out infinite; + font-weight: 700; + text-shadow: 0 0 10px rgba(255, 107, 107, 0.5); +} + +.pointcloud-text { + background: linear-gradient(45deg, #4ecdc4, #44a08d, #4ecdc4); + background-size: 200% 200%; + -webkit-background-clip: text; + background-clip: text; + color: transparent !important; + animation: pointcloudPulse 2.5s ease-in-out infinite; + font-weight: 700; + text-shadow: 0 0 10px rgba(78, 205, 196, 0.5); +} + +.cameras-text { + background: linear-gradient(45deg, #667eea, #764ba2, #667eea); + background-size: 200% 200%; + -webkit-background-clip: text; + background-clip: text; + color: transparent !important; + animation: camerasPulse 3s ease-in-out infinite; + font-weight: 700; + text-shadow: 0 0 10px rgba(102, 126, 234, 0.5); +} + +.gaussians-text { + background: linear-gradient(45deg, #f093fb, #f5576c, #f093fb); + background-size: 200% 200%; + -webkit-background-clip: text; + background-clip: text; + color: transparent !important; + animation: gaussiansPulse 2.2s ease-in-out infinite; + font-weight: 700; + text-shadow: 0 0 10px rgba(240, 147, 251, 0.5); +} + +.example-log * { + font-style: italic; + font-size: 16px !important; + background: linear-gradient(135deg, #3b82f6, #8b5cf6); + -webkit-background-clip: text; + background-clip: text; + color: transparent !important; +} + +#my_radio .wrap { + display: flex; + flex-wrap: nowrap; + justify-content: center; + align-items: center; +} + +#my_radio .wrap label { + display: flex; + width: 50%; + justify-content: center; + align-items: center; + margin: 0; + padding: 10px 0; + box-sizing: border-box; +} + +/* Align navigation buttons with dropdown bottom */ +.navigation-row { + display: flex !important; + align-items: flex-end !important; + gap: 8px !important; +} + +.navigation-row > div:nth-child(1), +.navigation-row > div:nth-child(3) { + align-self: flex-end !important; +} + +.navigation-row > div:nth-child(2) { + flex: 1 !important; +} + +/* Make thumbnails clickable with pointer cursor */ +.clickable-thumbnail img { + cursor: pointer !important; +} + +.clickable-thumbnail:hover img { + 
cursor: pointer !important; + opacity: 0.8; + transition: opacity 0.3s ease; +} + +/* Make thumbnail containers narrower horizontally */ +.clickable-thumbnail { + padding: 5px 2px !important; + margin: 0 2px !important; +} + +.clickable-thumbnail .image-container { + margin: 0 !important; + padding: 0 !important; +} + +.scene-info { + text-align: center !important; + padding: 5px 2px !important; + margin: 0 !important; +} +""" + + +def get_header_html(logo_base64=None): + """ + Generate the main header HTML with logo and title. + + Args: + logo_base64 (str, optional): Base64 encoded logo image + + Returns: + str: HTML string for the header + """ + return """ +
+    <!-- Header banner markup lost in extraction. Recoverable content: the logo,
+         the title "Depth Anything 3", the tagline "Recovering the Visual Space
+         from Any Views", and link buttons for "Project Page", "Paper", and "Code". -->
+    """
+
+
+def get_description_html():
+    """
+    Generate the main description and getting started HTML.
+
+    Returns:
+        str: HTML string for the description
+    """
+    return """
+    <!-- Description markup lost in extraction. Recoverable content: -->
+    <!-- Heading: "What This Demo Does" -->
+    <!-- Body: "Upload images or videos → Get Metric Point Clouds, Cameras and
+         Novel Views → Explore in 3D" -->
+    <!-- Tip: "Landscape-oriented images or videos are preferred for best 3D recovery." -->
+    """
+
+
+def get_acknowledgements_html():
+    """
+    Generate the acknowledgements section HTML.
+
+    Returns:
+        str: HTML string for the acknowledgements
+    """
+    return """
+    <!-- Acknowledgements markup lost in extraction. Recoverable content: -->
+    <!-- Heading: "Research Credits & Acknowledgments" -->
+    <!-- "Original Research": link to "Depth Anything 3" -->
+    <!-- "Previous Versions": links to "Depth-Anything" and "Depth-Anything-V2" -->
+    <!-- Footer: "HF demo adapted from Map Anything" -->
+ """ + + +def get_gradio_theme(): + """ + Get the configured Gradio theme with adaptive tech colors. + + Returns: + gr.themes.Base: Configured Gradio theme + """ + import gradio as gr + + return gr.themes.Base( + primary_hue=gr.themes.Color( + c50="#eff6ff", + c100="#dbeafe", + c200="#bfdbfe", + c300="#93c5fd", + c400="#60a5fa", + c500="#3b82f6", + c600="#2563eb", + c700="#1d4ed8", + c800="#1e40af", + c900="#1e3a8a", + c950="#172554", + ), + secondary_hue=gr.themes.Color( + c50="#f5f3ff", + c100="#ede9fe", + c200="#ddd6fe", + c300="#c4b5fd", + c400="#a78bfa", + c500="#8b5cf6", + c600="#7c3aed", + c700="#6d28d9", + c800="#5b21b6", + c900="#4c1d95", + c950="#2e1065", + ), + neutral_hue=gr.themes.Color( + c50="#f8fafc", + c100="#f1f5f9", + c200="#e2e8f0", + c300="#cbd5e1", + c400="#94a3b8", + c500="#64748b", + c600="#475569", + c700="#334155", + c800="#1e293b", + c900="#0f172a", + c950="#020617", + ), + ) + + +# Measure tab instructions HTML +MEASURE_INSTRUCTIONS_HTML = """ +### Click points on the image to compute distance. +> Metric scale estimation is difficult on aerial/drone images. +""" diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/app/gradio_app.py b/Depth-Anything-3-anysize/src/depth_anything_3/app/gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..daf95237c6d48939834d2e6b7b3a258d0ffe4b5a --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/app/gradio_app.py @@ -0,0 +1,747 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Refactored Gradio App for Depth Anything 3. + +This is the main application file that orchestrates all components. +The original functionality has been split into modular components for better maintainability. +""" + +import argparse +import os +from typing import Any, Dict, List +import gradio as gr + +from depth_anything_3.app.css_and_html import GRADIO_CSS, get_gradio_theme +from depth_anything_3.app.modules.event_handlers import EventHandlers +from depth_anything_3.app.modules.ui_components import UIComponents + +# Set environment variables +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + + +class DepthAnything3App: + """ + Main application class for Depth Anything 3 Gradio app. + """ + + def __init__(self, model_dir: str = None, workspace_dir: str = None, gallery_dir: str = None): + """ + Initialize the application. 
+ + Args: + model_dir: Path to the model directory + workspace_dir: Path to the workspace directory + gallery_dir: Path to the gallery directory + """ + self.model_dir = model_dir + self.workspace_dir = workspace_dir + self.gallery_dir = gallery_dir + + # Set environment variables for directories + if self.model_dir: + os.environ["DA3_MODEL_DIR"] = self.model_dir + if self.workspace_dir: + os.environ["DA3_WORKSPACE_DIR"] = self.workspace_dir + if self.gallery_dir: + os.environ["DA3_GALLERY_DIR"] = self.gallery_dir + + self.event_handlers = EventHandlers() + self.ui_components = UIComponents() + + def cache_examples( + self, + show_cam: bool = True, + filter_black_bg: bool = False, + filter_white_bg: bool = False, + save_percentage: float = 20.0, + num_max_points: int = 1000, + cache_gs_tag: str = "", + gs_trj_mode: str = "smooth", + gs_video_quality: str = "low", + ) -> None: + """ + Pre-cache all example scenes at startup. + + Args: + show_cam: Whether to show camera in visualization + filter_black_bg: Whether to filter black background + filter_white_bg: Whether to filter white background + save_percentage: Filter percentage for point cloud + num_max_points: Maximum number of points + cache_gs_tag: Tag to match scene names for high-res+3DGS caching (e.g., "dl3dv") + gs_trj_mode: Trajectory mode for 3DGS + gs_video_quality: Video quality for 3DGS + """ + from depth_anything_3.app.modules.utils import get_scene_info + + examples_dir = os.path.join(self.workspace_dir, "examples") + if not os.path.exists(examples_dir): + print(f"Examples directory not found: {examples_dir}") + return + + scenes = get_scene_info(examples_dir) + if not scenes: + print("No example scenes found to cache.") + return + + print(f"\n{'='*60}") + print(f"Caching {len(scenes)} example scenes...") + print(f"{'='*60}\n") + + for i, scene in enumerate(scenes, 1): + scene_name = scene["name"] + + # Check if scene name matches the gs tag for high-res+3DGS caching + use_high_res_gs = cache_gs_tag and cache_gs_tag.lower() in scene_name.lower() + + if use_high_res_gs: + print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (HIGH-RES + 3DGS)") + print(f" - Number of images: {scene['num_images']}") + print(f" - Matched tag: '{cache_gs_tag}' - using high_res + 3DGS") + else: + print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (LOW-RES)") + print(f" - Number of images: {scene['num_images']}") + + try: + # Load example scene + _, target_dir, _, _, _, _, _, _, _ = self.event_handlers.load_example_scene( + scene_name + ) + + if target_dir and target_dir != "None": + # Run reconstruction with appropriate settings + print(" - Running reconstruction...") + result = self.event_handlers.gradio_demo( + target_dir=target_dir, + show_cam=show_cam, + filter_black_bg=filter_black_bg, + filter_white_bg=filter_white_bg, + process_res_method="high_res" if use_high_res_gs else "low_res", + selected_first_frame="", + save_percentage=save_percentage, + num_max_points=num_max_points, + infer_gs=use_high_res_gs, + gs_trj_mode=gs_trj_mode, + gs_video_quality=gs_video_quality, + ) + + # Check if successful + if result[0] is not None: # reconstruction_output + print(f" βœ“ Scene '{scene_name}' cached successfully") + else: + print(f" βœ— Scene '{scene_name}' caching failed: {result[1]}") + else: + print(f" βœ— Scene '{scene_name}' loading failed") + + except Exception as e: + print(f" βœ— Error caching scene '{scene_name}': {str(e)}") + + print() + + print("=" * 60) + print("Example scene caching completed!") + print("=" * 60 + "\n") + + def 
create_app(self) -> gr.Blocks: + """ + Create and configure the Gradio application. + + Returns: + Configured Gradio Blocks interface + """ + + # Initialize theme + def get_theme(): + return get_gradio_theme() + + with gr.Blocks(theme=get_theme(), css=GRADIO_CSS) as demo: + # State variables for the tabbed interface + is_example = gr.Textbox(label="is_example", visible=False, value="None") + processed_data_state = gr.State(value=None) + measure_points_state = gr.State(value=[]) + selected_first_frame_state = gr.State(value="") + selected_image_index_state = gr.State(value=0) # Track selected image index + # current_view_index = gr.State(value=0) # noqa: F841 Track current view index + + # Header and description + self.ui_components.create_header_section() + self.ui_components.create_description_section() + + target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") + + # Main content area + with gr.Row(): + with gr.Column(scale=2): + # Upload section + ( + input_video, + s_time_interval, + input_images, + image_gallery, + select_first_frame_btn, + ) = self.ui_components.create_upload_section() + + with gr.Column(scale=4): + with gr.Column(): + # gr.Markdown("**Metric 3D Reconstruction (Point Cloud and Camera Poses)**") + # Reconstruction control section (buttons) - moved below tabs + + log_output = gr.Markdown( + "Please upload a video or images, then click Reconstruct.", + elem_classes=["custom-log"], + ) + + # Tabbed interface + with gr.Tabs(): + with gr.Tab("Point Cloud & Cameras"): + reconstruction_output = ( + self.ui_components.create_3d_viewer_section() + ) + + with gr.Tab("Metric Depth"): + ( + prev_measure_btn, + measure_view_selector, + next_measure_btn, + measure_image, + measure_depth_image, + measure_text, + ) = self.ui_components.create_measure_section() + + with gr.Tab("3DGS Rendered Novel Views"): + gs_video, gs_info = self.ui_components.create_nvs_video() + + # Inference control section (before inference) + (process_res_method_dropdown, infer_gs) = ( + self.ui_components.create_inference_control_section() + ) + + # Display control section - includes 3DGS options, buttons, and Visualization Options # noqa: E501 + ( + show_cam, + filter_black_bg, + filter_white_bg, + save_percentage, + num_max_points, + gs_trj_mode, + gs_video_quality, + submit_btn, + clear_btn, + ) = self.ui_components.create_display_control_section() + + # bind visibility of gs_trj_mode to infer_gs + infer_gs.change( + fn=lambda checked: ( + gr.update(visible=checked), + gr.update(visible=checked), + gr.update(visible=checked), + gr.update(visible=(not checked)), + ), + inputs=infer_gs, + outputs=[gs_trj_mode, gs_video_quality, gs_video, gs_info], + ) + + # Example scenes section + gr.Markdown("## Example Scenes") + + scenes = self.ui_components.create_example_scenes_section() + scene_components = self.ui_components.create_example_scene_grid(scenes) + + # Set up event handlers + self._setup_event_handlers( + demo, + is_example, + processed_data_state, + measure_points_state, + target_dir_output, + input_video, + input_images, + s_time_interval, + image_gallery, + reconstruction_output, + log_output, + show_cam, + filter_black_bg, + filter_white_bg, + process_res_method_dropdown, + save_percentage, + submit_btn, + clear_btn, + num_max_points, + infer_gs, + select_first_frame_btn, + selected_first_frame_state, + selected_image_index_state, + measure_view_selector, + measure_image, + measure_depth_image, + measure_text, + prev_measure_btn, + next_measure_btn, + scenes, + 
scene_components, + gs_video, + gs_info, + gs_trj_mode, + gs_video_quality, + ) + + # Acknowledgements + self.ui_components.create_acknowledgements_section() + + return demo + + def _setup_event_handlers( + self, + demo: gr.Blocks, + is_example: gr.Textbox, + processed_data_state: gr.State, + measure_points_state: gr.State, + target_dir_output: gr.Textbox, + input_video: gr.Video, + input_images: gr.File, + s_time_interval: gr.Slider, + image_gallery: gr.Gallery, + reconstruction_output: gr.Model3D, + log_output: gr.Markdown, + show_cam: gr.Checkbox, + filter_black_bg: gr.Checkbox, + filter_white_bg: gr.Checkbox, + process_res_method_dropdown: gr.Dropdown, + save_percentage: gr.Slider, + submit_btn: gr.Button, + clear_btn: gr.ClearButton, + num_max_points: gr.Slider, + infer_gs: gr.Checkbox, + select_first_frame_btn: gr.Button, + selected_first_frame_state: gr.State, + selected_image_index_state: gr.State, + measure_view_selector: gr.Dropdown, + measure_image: gr.Image, + measure_depth_image: gr.Image, + measure_text: gr.Markdown, + prev_measure_btn: gr.Button, + next_measure_btn: gr.Button, + scenes: List[Dict[str, Any]], + scene_components: List[gr.Image], + gs_video: gr.Video, + gs_info: gr.Markdown, + gs_trj_mode: gr.Dropdown, + gs_video_quality: gr.Dropdown, + ) -> None: + """ + Set up all event handlers for the application. + + Args: + demo: Gradio Blocks interface + All other arguments: Gradio components to connect + """ + # Configure clear button + clear_btn.add( + [ + input_video, + input_images, + reconstruction_output, + log_output, + target_dir_output, + image_gallery, + gs_video, + ] + ) + + # Main reconstruction button + submit_btn.click( + fn=self.event_handlers.clear_fields, inputs=[], outputs=[reconstruction_output] + ).then(fn=self.event_handlers.update_log, inputs=[], outputs=[log_output]).then( + fn=self.event_handlers.gradio_demo, + inputs=[ + target_dir_output, + show_cam, + filter_black_bg, + filter_white_bg, + process_res_method_dropdown, + selected_first_frame_state, + save_percentage, + # pass num_max_points + num_max_points, + infer_gs, + gs_trj_mode, + gs_video_quality, + ], + outputs=[ + reconstruction_output, + log_output, + processed_data_state, + measure_image, + measure_depth_image, + measure_text, + measure_view_selector, + gs_video, + gs_video, # gs_video visibility + gs_info, # gs_info visibility + ], + ).then( + fn=lambda: "False", + inputs=[], + outputs=[is_example], # set is_example to "False" + ) + + # Real-time visualization updates + self._setup_visualization_handlers( + show_cam, + filter_black_bg, + filter_white_bg, + process_res_method_dropdown, + target_dir_output, + is_example, + reconstruction_output, + log_output, + ) + + # File upload handlers + input_video.change( + fn=self.event_handlers.handle_uploads, + inputs=[input_video, input_images, s_time_interval], + outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], + ) + input_images.change( + fn=self.event_handlers.handle_uploads, + inputs=[input_video, input_images, s_time_interval], + outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], + ) + + # Image gallery click handler (for selecting first frame) + def handle_image_selection(evt: gr.SelectData): + if evt is None or evt.index is None: + return "No image selected", 0 + selected_index = evt.index + return f"Selected image {selected_index} as potential first frame", selected_index + + image_gallery.select( + fn=handle_image_selection, + outputs=[log_output, selected_image_index_state], + 
) + + # Select first frame handler + select_first_frame_btn.click( + fn=self.event_handlers.select_first_frame, + inputs=[image_gallery, selected_image_index_state], + outputs=[image_gallery, log_output, selected_first_frame_state], + ) + + # Navigation handlers + self._setup_navigation_handlers( + prev_measure_btn, + next_measure_btn, + measure_view_selector, + measure_image, + measure_depth_image, + measure_points_state, + processed_data_state, + ) + + # Measurement handler + measure_image.select( + fn=self.event_handlers.measure, + inputs=[processed_data_state, measure_points_state, measure_view_selector], + outputs=[measure_image, measure_depth_image, measure_points_state, measure_text], + ) + + # Example scene handlers + self._setup_example_scene_handlers( + scenes, + scene_components, + reconstruction_output, + target_dir_output, + image_gallery, + log_output, + is_example, + processed_data_state, + measure_view_selector, + measure_image, + measure_depth_image, + gs_video, + gs_info, + ) + + def _setup_visualization_handlers( + self, + show_cam: gr.Checkbox, + filter_black_bg: gr.Checkbox, + filter_white_bg: gr.Checkbox, + process_res_method_dropdown: gr.Dropdown, + target_dir_output: gr.Textbox, + is_example: gr.Textbox, + reconstruction_output: gr.Model3D, + log_output: gr.Markdown, + ) -> None: + """Set up visualization update handlers.""" + # Common inputs for visualization updates + viz_inputs = [ + target_dir_output, + show_cam, + is_example, + filter_black_bg, + filter_white_bg, + process_res_method_dropdown, + ] + + # Set up change handlers for all visualization controls + for component in [show_cam, filter_black_bg, filter_white_bg]: + component.change( + fn=self.event_handlers.update_visualization, + inputs=viz_inputs, + outputs=[reconstruction_output, log_output], + ) + + def _setup_navigation_handlers( + self, + prev_measure_btn: gr.Button, + next_measure_btn: gr.Button, + measure_view_selector: gr.Dropdown, + measure_image: gr.Image, + measure_depth_image: gr.Image, + measure_points_state: gr.State, + processed_data_state: gr.State, + ) -> None: + """Set up navigation handlers for measure tab.""" + # Measure tab navigation + prev_measure_btn.click( + fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view( + processed_data, current_selector, -1 + ), + inputs=[processed_data_state, measure_view_selector], + outputs=[ + measure_view_selector, + measure_image, + measure_depth_image, + measure_points_state, + ], + ) + + next_measure_btn.click( + fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view( + processed_data, current_selector, 1 + ), + inputs=[processed_data_state, measure_view_selector], + outputs=[ + measure_view_selector, + measure_image, + measure_depth_image, + measure_points_state, + ], + ) + + measure_view_selector.change( + fn=lambda processed_data, selector_value: ( + self.event_handlers.update_measure_view( + processed_data, int(selector_value.split()[1]) - 1 + ) + if selector_value + else (None, None, []) + ), + inputs=[processed_data_state, measure_view_selector], + outputs=[measure_image, measure_depth_image, measure_points_state], + ) + + def _setup_example_scene_handlers( + self, + scenes: List[Dict[str, Any]], + scene_components: List[gr.Image], + reconstruction_output: gr.Model3D, + target_dir_output: gr.Textbox, + image_gallery: gr.Gallery, + log_output: gr.Markdown, + is_example: gr.Textbox, + processed_data_state: gr.State, + measure_view_selector: gr.Dropdown, + measure_image: 
gr.Image, + measure_depth_image: gr.Image, + gs_video: gr.Video, + gs_info: gr.Markdown, + ) -> None: + """Set up example scene handlers.""" + + def load_and_update_measure(name): + result = self.event_handlers.load_example_scene(name) + # result = (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501 + + # Update measure view if processed_data is available + measure_img = None + measure_depth = None + if result[4] is not None: # processed_data exists + measure_img, measure_depth, _ = ( + self.event_handlers.visualization_handler.update_measure_view(result[4], 0) + ) + + return result + ("True", measure_img, measure_depth) + + for i, scene in enumerate(scenes): + if i < len(scene_components): + scene_components[i].select( + fn=lambda name=scene["name"]: load_and_update_measure(name), + outputs=[ + reconstruction_output, + target_dir_output, + image_gallery, + log_output, + processed_data_state, + measure_view_selector, + gs_video, + gs_video, # gs_video_visibility + gs_info, # gs_info_visibility + is_example, + measure_image, + measure_depth_image, + ], + ) + + def launch(self, host: str = "127.0.0.1", port: int = 7860, **kwargs) -> None: + """ + Launch the application. + + Args: + host: Host address to bind to + port: Port number to bind to + **kwargs: Additional arguments for demo.launch() + """ + demo = self.create_app() + demo.queue(max_size=20).launch( + show_error=True, ssr_mode=False, server_name=host, server_port=port, **kwargs + ) + + +def main(): + """Main function to run the application.""" + parser = argparse.ArgumentParser( + description="Depth Anything 3 Gradio Application", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic usage + python gradio_app.py --help + python gradio_app.py --host 0.0.0.0 --port 8080 + python gradio_app.py --model-dir /path/to/model --workspace-dir /path/to/workspace + + # Cache examples at startup (all low-res) + python gradio_app.py --cache-examples + + # Cache with selective high-res+3DGS for scenes matching tag + python gradio_app.py --cache-examples --cache-gs-tag dl3dv + # This will use high-res + 3DGS for scenes containing "dl3dv" in their name, + # and low-res only for other scenes + """, + ) + + # Server configuration + parser.add_argument( + "--host", default="127.0.0.1", help="Host address to bind to (default: 127.0.0.1)" + ) + parser.add_argument( + "--port", type=int, default=7860, help="Port number to bind to (default: 7860)" + ) + + # Directory configuration + parser.add_argument( + "--model-dir", + default="depth-anything/DA3NESTED-GIANT-LARGE", + help="Path to the model directory (default: depth-anything/DA3NESTED-GIANT-LARGE)", + ) + parser.add_argument( + "--workspace-dir", + default="workspace/gradio", # noqa: E501 + help="Path to the workspace directory (default: workspace/gradio)", # noqa: E501 + ) + parser.add_argument( + "--gallery-dir", + default="workspace/gallery", + help="Path to the gallery directory (default: workspace/gallery)", # noqa: E501 + ) + + # Additional Gradio options + parser.add_argument("--share", action="store_true", help="Create a public link for the app") + parser.add_argument("--debug", action="store_true", help="Enable debug mode") + + # Example caching options + parser.add_argument( + "--cache-examples", + action="store_true", + help="Pre-cache all example scenes at startup for faster loading", + ) + parser.add_argument( + "--cache-gs-tag", + type=str, + default="", + 
help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.", # noqa: E501 + ) + + args = parser.parse_args() + + # Create directories if they don't exist + os.makedirs(args.workspace_dir, exist_ok=True) + os.makedirs(args.gallery_dir, exist_ok=True) + + # Initialize and launch the application + app = DepthAnything3App( + model_dir=args.model_dir, workspace_dir=args.workspace_dir, gallery_dir=args.gallery_dir + ) + + # Prepare launch arguments + launch_kwargs = {"share": args.share, "debug": args.debug} + + print("Starting Depth Anything 3 Gradio App...") + print(f"Host: {args.host}") + print(f"Port: {args.port}") + print(f"Model Directory: {args.model_dir}") + print(f"Workspace Directory: {args.workspace_dir}") + print(f"Gallery Directory: {args.gallery_dir}") + print(f"Share: {args.share}") + print(f"Debug: {args.debug}") + print(f"Cache Examples: {args.cache_examples}") + if args.cache_examples: + if args.cache_gs_tag: + print( + f"Cache GS Tag: '{args.cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)" # noqa: E501 + ) # noqa: E501 + else: + print("Cache GS Tag: None (all scenes will use low-res only)") + + # Pre-cache examples if requested + if args.cache_examples: + print("\n" + "=" * 60) + print("Pre-caching mode enabled") + if args.cache_gs_tag: + print(f"Scenes containing '{args.cache_gs_tag}' will use HIGH-RES + 3DGS") + print("Other scenes will use LOW-RES only") + else: + print("All scenes will use LOW-RES only") + print("=" * 60) + app.cache_examples( + show_cam=True, + filter_black_bg=False, + filter_white_bg=False, + save_percentage=5.0, + num_max_points=1000, + cache_gs_tag=args.cache_gs_tag, + gs_trj_mode="smooth", + gs_video_quality="low", + ) + + app.launch(host=args.host, port=args.port, **launch_kwargs) + + +if __name__ == "__main__": + main() diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/__init__.py b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0b71780214eeadbc4fc44f4ac070e0e5fa7795 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Modules package for Depth Anything 3 Gradio app. + +This package contains all the modular components for the Gradio application. 
+""" + +from depth_anything_3.app.modules.event_handlers import EventHandlers +from depth_anything_3.app.modules.file_handlers import FileHandler +from depth_anything_3.app.modules.model_inference import ModelInference +from depth_anything_3.app.modules.ui_components import UIComponents +from depth_anything_3.app.modules.utils import ( + create_depth_visualization, + get_logo_base64, + get_scene_info, + save_to_gallery_func, +) +from depth_anything_3.app.modules.visualization import VisualizationHandler + +__all__ = [ + "ModelInference", + "FileHandler", + "VisualizationHandler", + "EventHandlers", + "UIComponents", + "create_depth_visualization", + "save_to_gallery_func", + "get_scene_info", + "get_logo_base64", +] diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/event_handlers.py b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/event_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a11498dfe9a0c7394bbd7f2abf943ffc0b702f --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/event_handlers.py @@ -0,0 +1,629 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Event handling module for Depth Anything 3 Gradio app. + +This module handles all event callbacks and user interactions. +""" + +import os +import time +from glob import glob +from typing import Any, Dict, List, Optional, Tuple +import gradio as gr +import numpy as np +import torch + +from depth_anything_3.app.modules.file_handlers import FileHandler +from depth_anything_3.app.modules.model_inference import ModelInference +from depth_anything_3.utils.memory import cleanup_cuda_memory +from depth_anything_3.app.modules.visualization import VisualizationHandler + + +class EventHandlers: + """ + Handles all event callbacks and user interactions for the Gradio app. + """ + + def __init__(self): + """Initialize the event handlers.""" + self.model_inference = ModelInference() + self.file_handler = FileHandler() + self.visualization_handler = VisualizationHandler() + + def clear_fields(self) -> None: + """ + Clears the 3D viewer, the stored target_dir, and empties the gallery. + """ + return None + + def update_log(self) -> str: + """ + Display a quick log message while waiting. + """ + return "Loading and Reconstructing..." + + def save_current_visualization( + self, + target_dir: str, + save_percentage: float, + show_cam: bool, + filter_black_bg: bool, + filter_white_bg: bool, + processed_data: Optional[Dict], + scene_name: str = "", + ) -> str: + """ + Save current visualization results to gallery with specified save percentage. 
+ + Args: + target_dir: Directory containing results + save_percentage: Percentage of points to save (0-100) + show_cam: Whether to show cameras + filter_black_bg: Whether to filter black background + filter_white_bg: Whether to filter white background + processed_data: Processed data from reconstruction + + Returns: + Status message + """ + if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): + return "No reconstruction available. Please run 'Reconstruct' first." + + if processed_data is None: + return "No processed data available. Please run 'Reconstruct' first." + + try: + # Add debug information + print("[DEBUG] save_current_visualization called with:") + print(f" target_dir: {target_dir}") + print(f" save_percentage: {save_percentage}") + print(f" show_cam: {show_cam}") + print(f" filter_black_bg: {filter_black_bg}") + print(f" filter_white_bg: {filter_white_bg}") + print(f" processed_data: {processed_data is not None}") + + # Import the gallery save function + # Create gallery name with user input or auto-generated + import datetime + + from .utils import save_to_gallery_func + + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + if scene_name and scene_name.strip(): + gallery_name = f"{scene_name.strip()}_{timestamp}_pct{save_percentage:.0f}" + else: + gallery_name = f"save_{timestamp}_pct{save_percentage:.0f}" + + print(f"[DEBUG] Saving to gallery with name: {gallery_name}") + + # Save entire process folder to gallery + success, message = save_to_gallery_func( + target_dir=target_dir, processed_data=processed_data, gallery_name=gallery_name + ) + + if success: + print(f"[DEBUG] Gallery save completed successfully: {message}") + return ( + "Successfully saved to gallery!\n" + f"Gallery name: {gallery_name}\n" + f"Save percentage: {save_percentage}%\n" + f"Show cameras: {show_cam}\n" + f"Filter black bg: {filter_black_bg}\n" + f"Filter white bg: {filter_white_bg}\n\n" + f"{message}" + ) + else: + print(f"[DEBUG] Gallery save failed: {message}") + return f"Failed to save to gallery: {message}" + + except Exception as e: + return f"Error saving visualization: {str(e)}" + + def gradio_demo( + self, + target_dir: str, + show_cam: bool = True, + filter_black_bg: bool = False, + filter_white_bg: bool = False, + process_res_method: str = "keep", + selected_first_frame: str = "", + save_percentage: float = 30.0, + num_max_points: int = 1_000_000, + infer_gs: bool = False, + gs_trj_mode: str = "extend", + gs_video_quality: str = "high", + ) -> Tuple[ + Optional[str], + str, + Optional[Dict], + Optional[np.ndarray], + Optional[np.ndarray], + str, + gr.Dropdown, + Optional[str], # gs video path + gr.update, # gs video visibility update + gr.update, # gs info visibility update + ]: + """ + Perform reconstruction using the already-created target_dir/images. + + Args: + target_dir: Directory containing images + show_cam: Whether to show camera + filter_black_bg: Whether to filter black background + filter_white_bg: Whether to filter white background + process_res_method: Method for resizing input images + selected_first_frame: Selected first frame filename + infer_gs: Whether to infer 3D Gaussian Splatting + + Returns: + Tuple of reconstruction results + """ + if not os.path.isdir(target_dir) or target_dir == "None": + return ( + None, + "No valid target directory found. 
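# --- Illustrative sketch (not part of this diff): the gallery folder name built by
# save_current_visualization() above combines an optional scene name, a timestamp, and the
# confidence-filter percentage.
import datetime

def build_gallery_name(scene_name: str, save_percentage: float) -> str:
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    if scene_name and scene_name.strip():
        return f"{scene_name.strip()}_{timestamp}_pct{save_percentage:.0f}"
    return f"save_{timestamp}_pct{save_percentage:.0f}"

# e.g. build_gallery_name("kitchen", 30.0) -> "kitchen_20250101_120000_pct30"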
Please upload first.", + None, + None, + None, + "", + None, + None, + gr.update(visible=False), # gs_video + gr.update(visible=True), # gs_info + ) + + start_time = time.time() + cleanup_cuda_memory() + + # Get image files for logging + target_dir_images = os.path.join(target_dir, "images") + all_files = ( + sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] + ) + + print("Running DepthAnything3 model...") + print(f"Selected first frame: {selected_first_frame}") + + # Validate selected_first_frame against current image list + if selected_first_frame and target_dir_images: + current_files = ( + sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else [] + ) + if selected_first_frame not in current_files: + print( + f"Selected first frame '{selected_first_frame}' not found in " + "current images. Using default order." + ) + selected_first_frame = "" # Reset to use default order + + with torch.no_grad(): + prediction, processed_data = self.model_inference.run_inference( + target_dir, + process_res_method=process_res_method, + show_camera=show_cam, + selected_first_frame=selected_first_frame, + save_percentage=save_percentage, + num_max_points=int(num_max_points * 1000), # Convert K to actual count + infer_gs=infer_gs, + gs_trj_mode=gs_trj_mode, + gs_video_quality=gs_video_quality, + ) + + # The GLB file is already generated by the API + glbfile = os.path.join(target_dir, "scene.glb") + + # Handle 3DGS video based on infer_gs flag + gsvideo_path = None + gs_video_visible = False + gs_info_visible = True + + if infer_gs: + try: + gsvideo_path = sorted(glob(os.path.join(target_dir, "gs_video", "*.mp4")))[-1] + gs_video_visible = True + gs_info_visible = False + except IndexError: + gsvideo_path = None + print("3DGS video not found, but infer_gs was enabled") + + # Cleanup + cleanup_cuda_memory() + + end_time = time.time() + print(f"Total time: {end_time - start_time:.2f} seconds") + log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization." + + # Populate visualization tabs with processed data + depth_vis, measure_img, measure_depth_vis, measure_pts = ( + self.visualization_handler.populate_visualization_tabs(processed_data) + ) + + # Update view selectors based on available views + depth_selector, measure_selector = self.visualization_handler.update_view_selectors( + processed_data + ) + + return ( + glbfile, + log_msg, + processed_data, + measure_img, # measure_image + measure_depth_vis, # measure_depth_image + "", # measure_text (empty initially) + measure_selector, # measure_view_selector + gsvideo_path, + gr.update(visible=gs_video_visible), # gs_video visibility + gr.update(visible=gs_info_visible), # gs_info visibility + ) + + def update_visualization( + self, + target_dir: str, + show_cam: bool, + is_example: str, + filter_black_bg: bool = False, + filter_white_bg: bool = False, + process_res_method: str = "keep", + ) -> Tuple[gr.update, str]: + """ + Reload saved predictions from npz, create (or reuse) the GLB for new parameters, + and return it for the 3D viewer. 
+ + Args: + target_dir: Directory containing results + show_cam: Whether to show camera + is_example: Whether this is an example scene + filter_black_bg: Whether to filter black background + filter_white_bg: Whether to filter white background + process_res_method: Method for resizing input images + + Returns: + Tuple of (glb_file, log_message) + """ + if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): + return ( + gr.update(), + "No reconstruction available. Please click the Reconstruct button first.", + ) + + # Check if GLB exists (could be cached example or reconstructed scene) + glbfile = os.path.join(target_dir, "scene.glb") + if os.path.exists(glbfile): + return ( + glbfile, + ( + "Visualization loaded from cache." + if is_example == "True" + else "Visualization updated." + ), + ) + + # If no GLB but it's an example that hasn't been reconstructed yet + if is_example == "True": + return ( + gr.update(), + "No reconstruction available. Please click the Reconstruct button first.", + ) + + # For non-examples, check predictions.npz + predictions_path = os.path.join(target_dir, "predictions.npz") + if not os.path.exists(predictions_path): + error_message = ( + f"No reconstruction available at {predictions_path}. " + "Please run 'Reconstruct' first." + ) + return gr.update(), error_message + + loaded = np.load(predictions_path, allow_pickle=True) + predictions = {key: loaded[key] for key in loaded.keys()} # noqa: F841 + + return ( + glbfile, + "Visualization updated.", + ) + + def handle_uploads( + self, + input_video: Optional[str], + input_images: Optional[List], + s_time_interval: float = 10.0, + ) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]: + """ + Handle file uploads and update gallery. + + Args: + input_video: Path to input video file + input_images: List of input image files + s_time_interval: Sampling FPS (frames per second) for frame extraction + + Returns: + Tuple of (reconstruction_output, target_dir, image_paths, log_message) + """ + return self.file_handler.update_gallery_on_upload( + input_video, input_images, s_time_interval + ) + + def load_example_scene(self, scene_name: str, examples_dir: str = None) -> Tuple[ + Optional[str], + Optional[str], + Optional[List], + str, + Optional[Dict], + gr.Dropdown, + Optional[str], + gr.update, + gr.update, + ]: + """ + Load a scene from examples directory. 
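# --- Illustrative sketch (not part of this diff): the lookup order used by
# update_visualization(): a ready-made scene.glb wins, otherwise the cached predictions.npz
# is reloaded, and if neither exists the user is asked to run Reconstruct first.
import os
import numpy as np

def load_cached_reconstruction(target_dir: str):
    glb_path = os.path.join(target_dir, "scene.glb")
    if os.path.exists(glb_path):
        return glb_path, None                      # the viewer can show the GLB directly
    npz_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(npz_path):
        return None, None                          # nothing cached yet
    loaded = np.load(npz_path, allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}
    return None, predictions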
+ + Args: + scene_name: Name of the scene to load + examples_dir: Path to examples directory (if None, uses workspace_dir/examples) + + Returns: + Tuple of (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501 + """ + if examples_dir is None: + # Get workspace directory from environment variable + workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace") + examples_dir = os.path.join(workspace_dir, "examples") + + reconstruction_output, target_dir, image_paths, log_message = ( + self.file_handler.load_example_scene(scene_name, examples_dir) + ) + + # Try to load cached processed data if available + processed_data = None + measure_view_selector = gr.Dropdown(choices=["View 1"], value="View 1") + gs_video_path = None + gs_video_visible = False + gs_info_visible = True + + if target_dir and target_dir != "None": + predictions_path = os.path.join(target_dir, "predictions.npz") + if os.path.exists(predictions_path): + try: + # Load predictions from cache + loaded = np.load(predictions_path, allow_pickle=True) + predictions = {key: loaded[key] for key in loaded.keys()} + + # Reconstruct processed_data structure + num_images = len(predictions.get("images", [])) + processed_data = {} + + for i in range(num_images): + processed_data[i] = { + "image": predictions["images"][i] if "images" in predictions else None, + "depth": predictions["depths"][i] if "depths" in predictions else None, + "depth_image": os.path.join( + target_dir, "depth_vis", f"{i:04d}.jpg" # Fixed: use .jpg not .png + ), + "intrinsics": ( + predictions["intrinsics"][i] + if "intrinsics" in predictions + and i < len(predictions["intrinsics"]) + else None + ), + "mask": None, + } + + # Update measure view selector + choices = [f"View {i + 1}" for i in range(num_images)] + measure_view_selector = gr.Dropdown(choices=choices, value=choices[0]) + + except Exception as e: + print(f"Error loading cached data: {e}") + + # Check for cached 3DGS video + gs_video_dir = os.path.join(target_dir, "gs_video") + if os.path.exists(gs_video_dir): + try: + from glob import glob + + gs_videos = sorted(glob(os.path.join(gs_video_dir, "*.mp4"))) + if gs_videos: + gs_video_path = gs_videos[-1] + gs_video_visible = True + gs_info_visible = False + print(f"Loaded cached 3DGS video: {gs_video_path}") + except Exception as e: + print(f"Error loading cached 3DGS video: {e}") + + return ( + reconstruction_output, + target_dir, + image_paths, + log_message, + processed_data, + measure_view_selector, + gs_video_path, + gr.update(visible=gs_video_visible), + gr.update(visible=gs_info_visible), + ) + + def navigate_depth_view( + self, + processed_data: Optional[Dict[int, Dict[str, Any]]], + current_selector: str, + direction: int, + ) -> Tuple[str, Optional[str]]: + """ + Navigate depth view. + + Args: + processed_data: Processed data dictionary + current_selector: Current selector value + direction: Direction to navigate + + Returns: + Tuple of (new_selector_value, depth_vis) + """ + return self.visualization_handler.navigate_depth_view( + processed_data, current_selector, direction + ) + + def update_depth_view( + self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int + ) -> Optional[str]: + """ + Update depth view for a specific view index. 
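# --- Illustrative sketch (not part of this diff): the per-view dictionary that
# load_example_scene() rebuilds from predictions.npz. The keys mirror the structure produced
# by ModelInference._process_results(); the values below are dummies for illustration.
import numpy as np

num_images = 2
processed_data = {
    i: {
        "image": np.zeros((480, 640, 3), dtype=np.uint8),        # processed RGB frame
        "depth": np.ones((480, 640), dtype=np.float32),          # per-pixel depth
        "depth_image": f"example_scene/depth_vis/{i:04d}.jpg",   # cached visualization path
        "intrinsics": np.eye(3),
        "mask": None,
    }
    for i in range(num_images)
}
view_choices = [f"View {i + 1}" for i in range(num_images)]       # -> ["View 1", "View 2"]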
+ + Args: + processed_data: Processed data dictionary + view_index: Index of the view to update + + Returns: + Path to depth visualization image or None + """ + return self.visualization_handler.update_depth_view(processed_data, view_index) + + def navigate_measure_view( + self, + processed_data: Optional[Dict[int, Dict[str, Any]]], + current_selector: str, + direction: int, + ) -> Tuple[str, Optional[np.ndarray], Optional[np.ndarray], List]: + """ + Navigate measure view. + + Args: + processed_data: Processed data dictionary + current_selector: Current selector value + direction: Direction to navigate + + Returns: + Tuple of (new_selector_value, measure_image, depth_right_half, measure_points) + """ + return self.visualization_handler.navigate_measure_view( + processed_data, current_selector, direction + ) + + def update_measure_view( + self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int + ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]: + """ + Update measure view for a specific view index. + + Args: + processed_data: Processed data dictionary + view_index: Index of the view to update + + Returns: + Tuple of (measure_image, depth_right_half, measure_points) + """ + return self.visualization_handler.update_measure_view(processed_data, view_index) + + def measure( + self, + processed_data: Optional[Dict[int, Dict[str, Any]]], + measure_points: List, + current_view_selector: str, + event: gr.SelectData, + ) -> List: + """ + Handle measurement on images. + + Args: + processed_data: Processed data dictionary + measure_points: List of current measure points + current_view_selector: Current view selector value + event: Gradio select event + + Returns: + List of [image, depth_right_half, measure_points, text] + """ + return self.visualization_handler.measure( + processed_data, measure_points, current_view_selector, event + ) + + def select_first_frame( + self, image_gallery: List, selected_index: int = 0 + ) -> Tuple[List, str, str]: + """ + Select the first frame from the image gallery. 
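# --- Illustrative sketch (not part of this diff): how the Previous/Next handlers wrap around
# the available views. The selector values are strings such as "View 3"; the index is parsed
# out and stepped with modulo arithmetic.
def step_view(current_selector_value: str, direction: int, num_views: int) -> str:
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except (IndexError, ValueError):
        current_view = 0
    new_view = (current_view + direction) % num_views   # wraps past both ends
    return f"View {new_view + 1}"

# step_view("View 1", -1, 5) -> "View 5";  step_view("View 5", +1, 5) -> "View 1"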
+ + Args: + image_gallery: List of images in the gallery + selected_index: Index of the selected image (default: 0) + + Returns: + Tuple of (updated_image_gallery, log_message, selected_frame_path) + """ + try: + if not image_gallery or len(image_gallery) == 0: + return image_gallery, "No images available to select as first frame.", "" + + # Handle None or invalid selected_index + if ( + selected_index is None + or selected_index < 0 + or selected_index >= len(image_gallery) + ): + selected_index = 0 + print(f"Invalid selected_index: {selected_index}, using default: 0") + + # Get the selected image based on index + selected_image = image_gallery[selected_index] + print(f"Selected image index: {selected_index}") + print(f"Total images: {len(image_gallery)}") + + # Extract the file path from the selected image + selected_frame_path = "" + print(f"Selected image type: {type(selected_image)}") + print(f"Selected image: {selected_image}") + + if isinstance(selected_image, tuple): + # Gradio Gallery returns tuple (path, None) + selected_frame_path = selected_image[0] + elif isinstance(selected_image, str): + selected_frame_path = selected_image + elif hasattr(selected_image, "name"): + selected_frame_path = selected_image.name + elif isinstance(selected_image, dict): + if "name" in selected_image: + selected_frame_path = selected_image["name"] + elif "path" in selected_image: + selected_frame_path = selected_image["path"] + elif "src" in selected_image: + selected_frame_path = selected_image["src"] + else: + # Try to convert to string + selected_frame_path = str(selected_image) + + print(f"Extracted path: {selected_frame_path}") + + # Extract filename from the path for matching + import os + + selected_filename = os.path.basename(selected_frame_path) + print(f"Selected filename: {selected_filename}") + + # Move the selected image to the front + updated_gallery = [selected_image] + [ + img for img in image_gallery if img != selected_image + ] + + log_message = ( + f"Selected frame: {selected_filename}. " + f"Moved to first position. Total frames: {len(updated_gallery)}" + ) + return updated_gallery, log_message, selected_filename + + except Exception as e: + print(f"Error selecting first frame: {e}") + return image_gallery, f"Error selecting first frame: {e}", "" diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/file_handlers.py b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/file_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..d738bfeb1e70ab2af752e022899d1bda91eb397d --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/file_handlers.py @@ -0,0 +1,304 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +File handling module for Depth Anything 3 Gradio app. + +This module handles file uploads, video processing, and file operations. 
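# --- Illustrative sketch (not part of this diff): select_first_frame() has to cope with the
# several shapes a gr.Gallery item can take; this condenses that branching into one helper.
import os

def gallery_item_to_path(item) -> str:
    if isinstance(item, tuple):          # Gallery often yields (path, caption) tuples
        return item[0]
    if isinstance(item, str):
        return item
    if hasattr(item, "name"):            # uploaded file objects expose .name
        return item.name
    if isinstance(item, dict):           # dict payloads from some Gradio versions
        for key in ("name", "path", "src"):
            if key in item:
                return item[key]
    return str(item)

# os.path.basename(gallery_item_to_path(("/tmp/images/000003.png", None))) -> "000003.png"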
+""" + +import os +import shutil +import time +from datetime import datetime +from typing import List, Optional, Tuple +import cv2 +from PIL import Image +from pillow_heif import register_heif_opener + +register_heif_opener() + + +class FileHandler: + """ + Handles file uploads and processing for the Gradio app. + """ + + def __init__(self): + """Initialize the file handler.""" + + def handle_uploads( + self, + input_video: Optional[str], + input_images: Optional[List], + s_time_interval: float = 10.0, + ) -> Tuple[str, List[str]]: + """ + Create a new 'target_dir' + 'images' subfolder, and place user-uploaded + images or extracted frames from video into it. + + Args: + input_video: Path to input video file + input_images: List of input image files + s_time_interval: Sampling FPS (frames per second) for frame extraction + + Returns: + Tuple of (target_dir, image_paths) + """ + start_time = time.time() + + # Get workspace directory from environment variable or use default + workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace") + if not os.path.exists(workspace_dir): + os.makedirs(workspace_dir) + + # Create input_images subdirectory + input_images_dir = os.path.join(workspace_dir, "input_images") + if not os.path.exists(input_images_dir): + os.makedirs(input_images_dir) + + # Create a unique folder name within input_images + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + target_dir = os.path.join(input_images_dir, f"session_{timestamp}") + target_dir_images = os.path.join(target_dir, "images") + + # Clean up if somehow that folder already exists + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + os.makedirs(target_dir) + os.makedirs(target_dir_images) + + image_paths = [] + + # Handle images + if input_images is not None: + image_paths.extend(self._process_images(input_images, target_dir_images)) + + # Handle video + if input_video is not None: + image_paths.extend( + self._process_video(input_video, target_dir_images, s_time_interval) + ) + + # Sort final images for gallery + image_paths = sorted(image_paths) + + end_time = time.time() + print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds") + return target_dir, image_paths + + def _process_images(self, input_images: List, target_dir_images: str) -> List[str]: + """ + Process uploaded images. 
+ + Args: + input_images: List of input image files + target_dir_images: Target directory for images + + Returns: + List of processed image paths + """ + image_paths = [] + + for file_data in input_images: + if isinstance(file_data, dict) and "name" in file_data: + file_path = file_data["name"] + else: + file_path = file_data + + # Check if the file is a HEIC image + file_ext = os.path.splitext(file_path)[1].lower() + if file_ext in [".heic", ".heif"]: + # Convert HEIC to JPEG for better gallery compatibility + try: + with Image.open(file_path) as img: + # Convert to RGB if necessary (HEIC can have different color modes) + if img.mode not in ("RGB", "L"): + img = img.convert("RGB") + + # Create JPEG filename + base_name = os.path.splitext(os.path.basename(file_path))[0] + dst_path = os.path.join(target_dir_images, f"{base_name}.jpg") + + # Save as JPEG with high quality + img.save(dst_path, "JPEG", quality=95) + image_paths.append(dst_path) + print( + f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> " + f"{os.path.basename(dst_path)}" + ) + except Exception as e: + print(f"Error converting HEIC file {file_path}: {e}") + # Fall back to copying as is + dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) + shutil.copy(file_path, dst_path) + image_paths.append(dst_path) + else: + # Regular image files - copy as is + dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) + shutil.copy(file_path, dst_path) + image_paths.append(dst_path) + + return image_paths + + def _process_video( + self, input_video: str, target_dir_images: str, s_time_interval: float + ) -> List[str]: + """ + Process video file and extract frames. + + Args: + input_video: Path to input video file + target_dir_images: Target directory for extracted frames + s_time_interval: Sampling FPS (frames per second) for frame extraction + + Returns: + List of extracted frame paths + """ + image_paths = [] + + if isinstance(input_video, dict) and "name" in input_video: + video_path = input_video["name"] + else: + video_path = input_video + + vs = cv2.VideoCapture(video_path) + fps = vs.get(cv2.CAP_PROP_FPS) + frame_interval = max(1, int(fps / s_time_interval)) # Convert FPS to frame interval + + count = 0 + video_frame_num = 0 + while True: + gotit, frame = vs.read() + if not gotit: + break + count += 1 + if count % frame_interval == 0: + image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png") + cv2.imwrite(image_path, frame) + image_paths.append(image_path) + video_frame_num += 1 + + return image_paths + + def update_gallery_on_upload( + self, + input_video: Optional[str], + input_images: Optional[List], + s_time_interval: float = 10.0, + ) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]: + """ + Handle file uploads and update gallery. + + Args: + input_video: Path to input video file + input_images: List of input image files + s_time_interval: Sampling FPS (frames per second) for frame extraction + + Returns: + Tuple of (reconstruction_output, target_dir, image_paths, log_message) + """ + if not input_video and not input_images: + return None, None, None, None + + target_dir, image_paths = self.handle_uploads(input_video, input_images, s_time_interval) + return ( + None, + target_dir, + image_paths, + "Upload complete. 
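# --- Illustrative sketch (not part of this diff): the frame-sampling rule used by
# _process_video(). The slider value is a target sampling FPS, converted into a
# "keep every Nth frame" interval.
from typing import List

def frames_to_keep(video_fps: float, sampling_fps: float, total_frames: int) -> List[int]:
    frame_interval = max(1, int(video_fps / sampling_fps))
    # Frames are counted from 1 and kept whenever the counter hits a multiple of the interval.
    return [i for i in range(1, total_frames + 1) if i % frame_interval == 0]

# A 30 fps clip sampled at 10 fps keeps every 3rd frame:
# frames_to_keep(30, 10, 10) -> [3, 6, 9]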
Click 'Reconstruct' to begin 3D processing.", + ) + + def load_example_scene( + self, scene_name: str, examples_dir: str = "examples" + ) -> Tuple[Optional[str], Optional[str], Optional[List], str]: + """ + Load a scene from examples directory. + + Args: + scene_name: Name of the scene to load + examples_dir: Path to examples directory + + Returns: + Tuple of (reconstruction_output, target_dir, image_paths, log_message) + """ + from depth_anything_3.app.modules.utils import get_scene_info + + scenes = get_scene_info(examples_dir) + + # Find the selected scene + selected_scene = None + for scene in scenes: + if scene["name"] == scene_name: + selected_scene = scene + break + + if selected_scene is None: + return None, None, None, "Scene not found" + + # Use fixed directory name for examples (not timestamp-based) + workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace") + input_images_dir = os.path.join(workspace_dir, "input_images") + if not os.path.exists(input_images_dir): + os.makedirs(input_images_dir) + + # Create a fixed folder name based on scene name + target_dir = os.path.join(input_images_dir, f"example_{scene_name}") + target_dir_images = os.path.join(target_dir, "images") + + # Check if already cached (GLB file exists) + glb_path = os.path.join(target_dir, "scene.glb") + is_cached = os.path.exists(glb_path) + + # Create directory if it doesn't exist + if not os.path.exists(target_dir): + os.makedirs(target_dir) + os.makedirs(target_dir_images) + + # Copy images if directory is new or empty + if not os.path.exists(target_dir_images) or len(os.listdir(target_dir_images)) == 0: + os.makedirs(target_dir_images, exist_ok=True) + image_paths = [] + for file_path in selected_scene["image_files"]: + dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) + shutil.copy(file_path, dst_path) + image_paths.append(dst_path) + else: + # Use existing images + image_paths = sorted( + [ + os.path.join(target_dir_images, f) + for f in os.listdir(target_dir_images) + if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif")) + ] + ) + + # Return cached GLB if available + if is_cached: + return ( + glb_path, # Return cached reconstruction + target_dir, # Set target directory + image_paths, # Set gallery + f"Loaded cached scene '{scene_name}' with {selected_scene['num_images']} images.", + ) + else: + return ( + None, # No cached reconstruction + target_dir, # Set target directory + image_paths, # Set gallery + ( + f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. " + "Click 'Reconstruct' to begin 3D processing." + ), + ) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/model_inference.py b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/model_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..055b31e546cc0e7a540ec0b173e4ef130426f324 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/model_inference.py @@ -0,0 +1,292 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model inference module for Depth Anything 3 Gradio app. + +This module handles all model-related operations including inference, +data processing, and result preparation. +""" + +import glob +import os +from typing import Any, Dict, Optional, Tuple +import numpy as np +import torch + +from depth_anything_3.api import DepthAnything3 +from depth_anything_3.utils.memory import cleanup_cuda_memory +from depth_anything_3.utils.export.glb import export_to_glb +from depth_anything_3.utils.export.gs import export_to_gs_video + + +class ModelInference: + """ + Handles model inference and data processing for Depth Anything 3. + """ + + def __init__(self): + """Initialize the model inference handler.""" + self.model = None + + def initialize_model(self, device: str = "cuda") -> None: + """ + Initialize the DepthAnything3 model. + + Args: + device: Device to load the model on + """ + if self.model is None: + # Get model directory from environment variable or use default + model_dir = os.environ.get( + "DA3_MODEL_DIR", "/dev/shm/da3_models/DA3HF-VITG-METRIC_VITL" + ) + self.model = DepthAnything3.from_pretrained(model_dir) + self.model = self.model.to(device) + else: + self.model = self.model.to(device) + + self.model.eval() + + def run_inference( + self, + target_dir: str, + filter_black_bg: bool = False, + filter_white_bg: bool = False, + process_res_method: str = "keep", + show_camera: bool = True, + selected_first_frame: Optional[str] = None, + save_percentage: float = 30.0, + num_max_points: int = 1_000_000, + infer_gs: bool = False, + gs_trj_mode: str = "extend", + gs_video_quality: str = "high", + ) -> Tuple[Any, Dict[int, Dict[str, Any]]]: + """ + Run DepthAnything3 model inference on images. 
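# --- Illustrative sketch (not part of this diff): the lazy, load-once pattern that
# ModelInference.initialize_model() follows. The model directory comes from the DA3_MODEL_DIR
# environment variable; the hub id used as a fallback here is the app's --model-dir default,
# not necessarily the module's own default.
import os
import torch
from depth_anything_3.api import DepthAnything3

_model = None

def get_model(device: str = "cuda") -> DepthAnything3:
    global _model
    if _model is None:                                   # only the first call pays the load cost
        model_dir = os.environ.get("DA3_MODEL_DIR", "depth-anything/DA3NESTED-GIANT-LARGE")
        _model = DepthAnything3.from_pretrained(model_dir)
    _model = _model.to(torch.device(device))             # later calls only move devices
    _model.eval()
    return _model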
+ + Args: + target_dir: Directory containing images + apply_mask: Whether to apply mask for ambiguous depth classes + mask_edges: Whether to mask edges + filter_black_bg: Whether to filter black background + filter_white_bg: Whether to filter white background + process_res_method: Method for resizing input images + show_camera: Whether to show camera in 3D view + selected_first_frame: Selected first frame filename + save_percentage: Percentage of points to save (0-100) + infer_gs: Whether to infer 3D Gaussian Splatting + + Returns: + Tuple of (prediction, processed_data) + """ + print(f"Processing images from {target_dir}") + + # Device check + device = "cuda" if torch.cuda.is_available() else "cpu" + device = torch.device(device) + + # Initialize model if needed + self.initialize_model(device) + + # Get image paths + print("Loading images...") + image_folder_path = os.path.join(target_dir, "images") + all_image_paths = sorted(glob.glob(os.path.join(image_folder_path, "*"))) + + # Filter for image files + image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"] + all_image_paths = [ + path + for path in all_image_paths + if any(path.lower().endswith(ext) for ext in image_extensions) + ] + + print(f"Found {len(all_image_paths)} images") + print(f"All image paths: {all_image_paths}") + + # Apply first frame selection logic + if selected_first_frame: + # Find the image with matching filename + selected_path = None + for path in all_image_paths: + if os.path.basename(path) == selected_first_frame: + selected_path = path + break + + if selected_path: + # Move selected frame to the front + image_paths = [selected_path] + [ + path for path in all_image_paths if path != selected_path + ] + print(f"User selected first frame: {selected_first_frame} -> {selected_path}") + print(f"Reordered image paths: {image_paths}") + else: + # Use default order if no match found + image_paths = all_image_paths + print( + f"Selected frame '{selected_first_frame}' not found in image paths. " + "Using default order." + ) + first_frame_display = image_paths[0] if image_paths else "No images" + print(f"Using default order (first frame): {first_frame_display}") + else: + # Use default order (sorted) + image_paths = all_image_paths + first_frame_display = image_paths[0] if image_paths else "No images" + print(f"Using default order (first frame): {first_frame_display}") + + if len(image_paths) == 0: + raise ValueError("No images found. 
Check your upload.") + + # Map UI options to actual method names + method_mapping = { + "high_res": "lower_bound_resize", + "low_res": "upper_bound_resize", + "keep": "keep", + "original": "original", + } + actual_method = method_mapping.get(process_res_method, process_res_method) + process_res_value = None if actual_method in ("keep", "original") else 504 + + # Run model inference + print(f"Running inference with method: {actual_method}") + with torch.no_grad(): + prediction = self.model.inference( + image_paths, + export_dir=None, + process_res=process_res_value, + process_res_method=actual_method, + infer_gs=infer_gs, + ) + # num_max_points: int = 1_000_000, + export_to_glb( + prediction, + filter_black_bg=filter_black_bg, + filter_white_bg=filter_white_bg, + export_dir=target_dir, + show_cameras=show_camera, + conf_thresh_percentile=save_percentage, + num_max_points=int(num_max_points), + ) + + # export to gs video if needed + if infer_gs: + mode_mapping = {"extend": "extend", "smooth": "interpolate_smooth"} + print(f"GS mode: {gs_trj_mode}; Backend mode: {mode_mapping[gs_trj_mode]}") + export_to_gs_video( + prediction, + export_dir=target_dir, + chunk_size=4, + trj_mode=mode_mapping.get(gs_trj_mode, "extend"), + enable_tqdm=True, + vis_depth="hcat", + video_quality=gs_video_quality, + ) + + # Save predictions.npz for caching metric depth data + self._save_predictions_cache(target_dir, prediction) + + # Process results + processed_data = self._process_results(target_dir, prediction, image_paths) + + # Clean up using centralized memory utilities for consistency with backend + cleanup_cuda_memory() + + return prediction, processed_data + + def _save_predictions_cache(self, target_dir: str, prediction: Any) -> None: + """ + Save predictions data to predictions.npz for caching. + + Args: + target_dir: Directory to save the cache + prediction: Model prediction object + """ + try: + output_file = os.path.join(target_dir, "predictions.npz") + + # Build save dict with prediction data + save_dict = {} + + # Save processed images if available + if prediction.processed_images is not None: + save_dict["images"] = prediction.processed_images + + # Save depth data + if prediction.depth is not None: + save_dict["depths"] = np.round(prediction.depth, 6) + + # Save confidence if available + if prediction.conf is not None: + save_dict["conf"] = np.round(prediction.conf, 2) + + # Save camera parameters + if prediction.extrinsics is not None: + save_dict["extrinsics"] = prediction.extrinsics + if prediction.intrinsics is not None: + save_dict["intrinsics"] = prediction.intrinsics + + # Save to file + np.savez_compressed(output_file, **save_dict) + print(f"Saved predictions cache to: {output_file}") + + except Exception as e: + print(f"Warning: Failed to save predictions cache: {e}") + + def _process_results( + self, target_dir: str, prediction: Any, image_paths: list + ) -> Dict[int, Dict[str, Any]]: + """ + Process model results into structured data. 
+ + Args: + target_dir: Directory containing results + prediction: Model prediction object + image_paths: List of input image paths + + Returns: + Dictionary containing processed data for each view + """ + processed_data = {} + + # Read generated depth visualization files + depth_vis_dir = os.path.join(target_dir, "depth_vis") + + if os.path.exists(depth_vis_dir): + depth_files = sorted(glob.glob(os.path.join(depth_vis_dir, "*.jpg"))) + for i, depth_file in enumerate(depth_files): + # Use processed images directly from API + processed_image = None + if prediction.processed_images is not None and i < len( + prediction.processed_images + ): + processed_image = prediction.processed_images[i] + + processed_data[i] = { + "depth_image": depth_file, + "image": processed_image, + "original_image_path": image_paths[i] if i < len(image_paths) else None, + "depth": prediction.depth[i] if i < len(prediction.depth) else None, + "intrinsics": ( + prediction.intrinsics[i] + if prediction.intrinsics is not None and i < len(prediction.intrinsics) + else None + ), + "mask": None, # No mask information available + } + + return processed_data + + # cleanup() removed: call cleanup_cuda_memory() directly where needed. diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/ui_components.py b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/ui_components.py new file mode 100644 index 0000000000000000000000000000000000000000..1f4b25cbe7bc024b3af0cd835f1140315d8f7389 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/ui_components.py @@ -0,0 +1,474 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +UI components module for Depth Anything 3 Gradio app. + +This module contains UI component definitions and layout functions. +""" + +import os +from typing import Any, Dict, List, Tuple +import gradio as gr + +from depth_anything_3.app.modules.utils import get_logo_base64, get_scene_info + + +class UIComponents: + """ + Handles UI component creation and layout for the Gradio app. + """ + + def __init__(self): + """Initialize the UI components handler.""" + + def create_upload_section(self) -> Tuple[gr.Video, gr.Slider, gr.File, gr.Gallery, gr.Button]: + """ + Create the upload section with video, images, and gallery components. + + Returns: + A tuple of Gradio components: (input_video, s_time_interval, input_images, + image_gallery, select_first_frame_btn). 
+ """ + input_video = gr.Video(label="Upload Video", interactive=True) + s_time_interval = gr.Slider( + minimum=0.1, + maximum=60, + value=10, + step=0.1, + label="Sampling FPS (Frames Per Second)", + interactive=True, + visible=True, + ) + input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True) + image_gallery = gr.Gallery( + label="Preview", + columns=4, + height="300px", + show_download_button=True, + object_fit="contain", + preview=True, + interactive=False, + ) + + # Select first frame button (moved below image gallery) + select_first_frame_btn = gr.Button("Select First Frame", scale=1) + + return input_video, s_time_interval, input_images, image_gallery, select_first_frame_btn + + def create_3d_viewer_section(self) -> gr.Model3D: + """ + Create the 3D viewer component. + + Returns: + 3D model viewer component + """ + return gr.Model3D( + height=520, + zoom_speed=0.5, + pan_speed=0.5, + clear_color=[0.0, 0.0, 0.0, 0.0], + key="persistent_3d_viewer", + elem_id="reconstruction_3d_viewer", + ) + + def create_nvs_video(self) -> Tuple[gr.Video, gr.Markdown]: + """ + Create the 3DGS rendered video display component and info message. + + Returns: + Tuple of (video component, info message component) + """ + with gr.Column(): + gs_info = gr.Markdown( + ( + "‼️ **3D Gaussian Splatting rendering is currently DISABLED.**


" + "To render novel views from 3DGS, " + "enable **Infer 3D Gaussian Splatting** below.
" + "Next, in **Visualization Options**, " + "*optionally* configure the **rendering trajectory** (default: smooth) " + "and **video quality** (default: low), " + "then click **Reconstruct**." + ), + visible=True, + height=520, + ) + gs_video = gr.Video( + height=520, + label="3DGS Rendered NVS Video (depth shown for reference only)", + interactive=False, + visible=False, + ) + return gs_video, gs_info + + def create_depth_section(self) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image]: + """ + Create the depth visualization section. + + Returns: + A tuple of (prev_depth_btn, depth_view_selector, next_depth_btn, depth_map) + """ + with gr.Row(elem_classes=["navigation-row"]): + prev_depth_btn = gr.Button("β—€ Previous", size="sm", scale=1) + depth_view_selector = gr.Dropdown( + choices=["View 1"], + value="View 1", + label="Select View", + scale=2, + interactive=True, + allow_custom_value=True, + ) + next_depth_btn = gr.Button("Next β–Ά", size="sm", scale=1) + depth_map = gr.Image( + type="numpy", + label="Colorized Depth Map", + format="png", + interactive=False, + ) + + return prev_depth_btn, depth_view_selector, next_depth_btn, depth_map + + def create_measure_section( + self, + ) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image, gr.Image, gr.Markdown]: + """ + Create the measurement section. + + Returns: + A tuple of (prev_measure_btn, measure_view_selector, next_measure_btn, measure_image, + measure_depth_image, measure_text) + """ + from depth_anything_3.app.css_and_html import MEASURE_INSTRUCTIONS_HTML + + gr.Markdown(MEASURE_INSTRUCTIONS_HTML) + with gr.Row(elem_classes=["navigation-row"]): + prev_measure_btn = gr.Button("β—€ Previous", size="sm", scale=1) + measure_view_selector = gr.Dropdown( + choices=["View 1"], + value="View 1", + label="Select View", + scale=2, + interactive=True, + allow_custom_value=True, + ) + next_measure_btn = gr.Button("Next β–Ά", size="sm", scale=1) + with gr.Row(): + measure_image = gr.Image( + type="numpy", + show_label=False, + format="webp", + interactive=False, + sources=[], + label="RGB Image", + scale=1, + height=275, + ) + measure_depth_image = gr.Image( + type="numpy", + show_label=False, + format="webp", + interactive=False, + sources=[], + label="Depth Visualization (Right Half)", + scale=1, + height=275, + ) + gr.Markdown( + "**Note:** Images have been adjusted to model processing size. " + "Click two points on the RGB image to measure distance." + ) + measure_text = gr.Markdown("") + + return ( + prev_measure_btn, + measure_view_selector, + next_measure_btn, + measure_image, + measure_depth_image, + measure_text, + ) + + def create_inference_control_section(self) -> Tuple[gr.Dropdown, gr.Checkbox]: + """ + Create the inference control section (before inference). 
+ + Returns: + Tuple of (process_res_method_dropdown, infer_gs) + """ + with gr.Row(): + process_res_method_dropdown = gr.Dropdown( + choices=["high_res", "low_res"], + value="low_res", + label="Image Processing Method", + info="low_res for much more images", + scale=1, + ) + # Modify line 220, add color class + infer_gs = gr.Checkbox( + label="Infer 3D Gaussian Splatting", + value=False, + info=( + 'Enable novel view rendering from 3DGS ( requires extra processing time)' + ), + scale=1, + ) + + return (process_res_method_dropdown, infer_gs) + + def create_display_control_section( + self, + ) -> Tuple[ + gr.Checkbox, + gr.Checkbox, + gr.Checkbox, + gr.Slider, + gr.Slider, + gr.Dropdown, + gr.Dropdown, + gr.Button, + gr.ClearButton, + ]: + """ + Create the display control section (options for visualization). + + Returns: + Tuple of display control components including buttons + """ + with gr.Column(): + # 3DGS options at the top + with gr.Row(): + gs_trj_mode = gr.Dropdown( + choices=["smooth", "extend"], + value="smooth", + label=("Rendering trajectory for 3DGS viewpoints (requires n_views β‰₯ 2)"), + info=("'smooth' for view interpolation; 'extend' for longer trajectory"), + visible=False, # initially hidden + ) + gs_video_quality = gr.Dropdown( + choices=["low", "medium", "high"], + value="low", + label=("Video quality for 3DGS rendered outputs"), + info=("'low' for faster loading speed; 'high' for better visual quality"), + visible=False, # initially hidden + ) + + # Reconstruct and Clear buttons (before Visualization Options) + with gr.Row(): + submit_btn = gr.Button("Reconstruct", scale=1, variant="primary") + clear_btn = gr.ClearButton(scale=1) + + gr.Markdown("### Visualization Options: (Click Reconstruct to update)") + show_cam = gr.Checkbox(label="Show Camera", value=True) + filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False) + filter_white_bg = gr.Checkbox(label="Filter White Background", value=False) + save_percentage = gr.Slider( + minimum=0, + maximum=100, + value=10, + step=1, + label="Filter Percentage", + info="Confidence Threshold (%): Higher values filter more points.", + ) + num_max_points = gr.Slider( + minimum=1000, + maximum=100000, + value=1000, + step=1000, + label="Max Points (K points)", + info="Maximum number of points to export to GLB (in thousands)", + ) + + return ( + show_cam, + filter_black_bg, + filter_white_bg, + save_percentage, + num_max_points, + gs_trj_mode, + gs_video_quality, + submit_btn, + clear_btn, + ) + + def create_control_section( + self, + ) -> Tuple[ + gr.Button, + gr.ClearButton, + gr.Dropdown, + gr.Checkbox, + gr.Checkbox, + gr.Checkbox, + gr.Checkbox, + gr.Checkbox, + gr.Dropdown, + gr.Checkbox, + gr.Textbox, + ]: + """ + Create the control section with buttons and options. 
+ + Returns: + Tuple of control components + """ + with gr.Row(): + submit_btn = gr.Button("Reconstruct", scale=1, variant="primary") + clear_btn = gr.ClearButton( + scale=1, + ) + + with gr.Row(): + frame_filter = gr.Dropdown( + choices=["All"], value="All", label="Show Points from Frame" + ) + with gr.Column(): + gr.Markdown("### Visualization Option: (Click Reconstruct to update)") + show_cam = gr.Checkbox(label="Show Camera", value=True) + show_mesh = gr.Checkbox(label="Show Mesh", value=True) + filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False) + filter_white_bg = gr.Checkbox(label="Filter White Background", value=False) + gr.Markdown("### Reconstruction Options: (updated on next run)") + apply_mask_checkbox = gr.Checkbox( + label="Apply mask for predicted ambiguous depth classes & edges", + value=True, + ) + process_res_method_dropdown = gr.Dropdown( + choices=[ + "upper_bound_resize", + "upper_bound_crop", + "lower_bound_resize", + "lower_bound_crop", + ], + value="upper_bound_resize", + label="Image Processing Method", + info="Method for resizing input images", + ) + save_to_gallery_checkbox = gr.Checkbox( + label="Save to Gallery", + value=False, + info="Save current reconstruction results to gallery directory", + ) + gallery_name_input = gr.Textbox( + label="Gallery Name", + placeholder="Enter a name for the gallery folder", + value="", + info="Leave empty for auto-generated name with timestamp", + ) + + return ( + submit_btn, + clear_btn, + frame_filter, + show_cam, + show_mesh, + filter_black_bg, + filter_white_bg, + apply_mask_checkbox, + process_res_method_dropdown, + save_to_gallery_checkbox, + gallery_name_input, + ) + + def create_example_scenes_section(self) -> List[Dict[str, Any]]: + """ + Create the example scenes section. + + Returns: + List of scene information dictionaries + """ + # Get workspace directory from environment variable + workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace") + examples_dir = os.path.join(workspace_dir, "examples") + + # Get scene information + scenes = get_scene_info(examples_dir) + + return scenes + + def create_example_scene_grid(self, scenes: List[Dict[str, Any]]) -> List[gr.Image]: + """ + Create the example scene grid. + + Args: + scenes: List of scene information dictionaries + + Returns: + List of scene image components + """ + scene_components = [] + + if scenes: + for i in range(0, len(scenes), 4): # Process 4 scenes per row + with gr.Row(): + for j in range(4): + scene_idx = i + j + if scene_idx < len(scenes): + scene = scenes[scene_idx] + with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]): + # Clickable thumbnail + scene_img = gr.Image( + value=scene["thumbnail"], + height=150, + interactive=False, + show_label=False, + elem_id=f"scene_thumb_{scene['name']}", + sources=[], + ) + scene_components.append(scene_img) + + # Scene name and image count as text below thumbnail + gr.Markdown( + f"**{scene['name']}** \n {scene['num_images']} images", + elem_classes=["scene-info"], + ) + else: + # Empty column to maintain grid structure + with gr.Column(scale=1): + pass + + return scene_components + + def create_header_section(self) -> gr.HTML: + """ + Create the header section with logo and title. + + Returns: + Header HTML component + """ + from depth_anything_3.app.css_and_html import get_header_html + + return gr.HTML(get_header_html(get_logo_base64())) + + def create_description_section(self) -> gr.HTML: + """ + Create the description section. 
+ + Returns: + Description HTML component + """ + from depth_anything_3.app.css_and_html import get_description_html + + return gr.HTML(get_description_html()) + + def create_acknowledgements_section(self) -> gr.HTML: + """ + Create the acknowledgements section. + + Returns: + Acknowledgements HTML component + """ + from depth_anything_3.app.css_and_html import get_acknowledgements_html + + return gr.HTML(get_acknowledgements_html()) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/utils.py b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..985abc145753c77143cd87d762ded76ac1dd0755 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/utils.py @@ -0,0 +1,207 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utility functions for Depth Anything 3 Gradio app. + +This module contains helper functions for data processing, visualization, +and file operations. +""" + + +import json +import os +import shutil +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple +import numpy as np + +def create_depth_visualization(depth: np.ndarray) -> Optional[np.ndarray]: + """ + Create a colored depth visualization. + + Args: + depth: Depth array + + Returns: + Colored depth visualization or None + """ + if depth is None: + return None + + # Normalize depth to 0-1 range + depth_min = depth[depth > 0].min() if (depth > 0).any() else 0 + depth_max = depth.max() + + if depth_max <= depth_min: + return None + + # Normalize depth + depth_norm = (depth - depth_min) / (depth_max - depth_min) + depth_norm = np.clip(depth_norm, 0, 1) + + # Apply colormap (using matplotlib's viridis colormap) + import matplotlib.cm as cm + + # Convert to colored image + depth_colored = cm.viridis(depth_norm)[:, :, :3] # Remove alpha channel + depth_colored = (depth_colored * 255).astype(np.uint8) + + return depth_colored + + +def save_to_gallery_func( + target_dir: str, processed_data: Dict[int, Dict[str, Any]], gallery_name: Optional[str] = None +) -> Tuple[bool, str]: + """ + Save the current reconstruction results to the gallery directory. 
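# --- Illustrative sketch (not part of this diff): the normalization performed by
# create_depth_visualization() before applying the viridis colormap. Zero (invalid) depths are
# excluded when picking the minimum so they do not stretch the value range.
import numpy as np
import matplotlib.cm as cm

depth = np.array([[0.0, 2.0], [4.0, 6.0]], dtype=np.float32)
depth_min = depth[depth > 0].min()                                # 2.0, ignoring the invalid zero
depth_max = depth.max()                                           # 6.0
depth_norm = np.clip((depth - depth_min) / (depth_max - depth_min), 0, 1)
depth_colored = (cm.viridis(depth_norm)[:, :, :3] * 255).astype(np.uint8)  # drop alpha channel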
+ + Args: + target_dir: Source directory containing reconstruction results + processed_data: Processed data dictionary + gallery_name: Name for the gallery folder + + Returns: + Tuple of (success, message) + """ + try: + # Get gallery directory from environment variable or use default + gallery_dir = os.environ.get( + "DA3_GALLERY_DIR", + "workspace/gallery", + ) + if not os.path.exists(gallery_dir): + os.makedirs(gallery_dir) + + # Use provided name or create a unique name + if gallery_name is None or gallery_name.strip() == "": + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + gallery_name = f"reconstruction_{timestamp}" + + gallery_path = os.path.join(gallery_dir, gallery_name) + + # Check if directory already exists + if os.path.exists(gallery_path): + return False, f"Save failed: folder '{gallery_name}' already exists" + + # Create the gallery directory + os.makedirs(gallery_path, exist_ok=True) + + # Copy GLB file + glb_source = os.path.join(target_dir, "scene.glb") + glb_dest = os.path.join(gallery_path, "scene.glb") + if os.path.exists(glb_source): + shutil.copy2(glb_source, glb_dest) + + # Copy depth visualization images + depth_vis_dir = os.path.join(target_dir, "depth_vis") + if os.path.exists(depth_vis_dir): + gallery_depth_vis = os.path.join(gallery_path, "depth_vis") + shutil.copytree(depth_vis_dir, gallery_depth_vis) + + # Copy original images + images_source = os.path.join(target_dir, "images") + if os.path.exists(images_source): + gallery_images = os.path.join(gallery_path, "images") + shutil.copytree(images_source, gallery_images) + + scene_preview_source = os.path.join(target_dir, "scene.jpg") + scene_preview_dest = os.path.join(gallery_path, "scene.jpg") + shutil.copy2(scene_preview_source, scene_preview_dest) + + # Save metadata + metadata = { + "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"), + "num_images": len(processed_data) if processed_data else 0, + "gallery_name": gallery_name, + } + + with open(os.path.join(gallery_path, "metadata.json"), "w") as f: + json.dump(metadata, f, indent=2) + + print(f"Saved reconstruction to gallery: {gallery_path}") + return True, f"Save successful: saved to {gallery_path}" + + except Exception as e: + print(f"Error saving to gallery: {e}") + return False, f"Save failed: {str(e)}" + + +def get_scene_info(examples_dir: str) -> List[Dict[str, Any]]: + """ + Get information about scenes in the examples directory. + + Args: + examples_dir: Path to examples directory + + Returns: + List of scene information dictionaries + """ + import glob + + scenes = [] + if not os.path.exists(examples_dir): + return scenes + + for scene_folder in sorted(os.listdir(examples_dir)): + scene_path = os.path.join(examples_dir, scene_folder) + if os.path.isdir(scene_path): + # Find all image files in the scene folder + image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"] + image_files = [] + for ext in image_extensions: + image_files.extend(glob.glob(os.path.join(scene_path, ext))) + image_files.extend(glob.glob(os.path.join(scene_path, ext.upper()))) + + if image_files: + # Sort images and get the first one for thumbnail + image_files = sorted(image_files) + first_image = image_files[0] + num_images = len(image_files) + + scenes.append( + { + "name": scene_folder, + "path": scene_path, + "thumbnail": first_image, + "num_images": num_images, + "image_files": image_files, + } + ) + + return scenes + + +# NOTE: cleanup was moved to a single canonical helper in +# `depth_anything_3.utils.memory.cleanup_cuda_memory`. 
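# --- Illustrative sketch (not part of this diff): the on-disk layout produced by
# save_to_gallery_func() above, and the metadata.json it writes alongside the copied files.
#
#   <DA3_GALLERY_DIR>/<gallery_name>/
#     scene.glb        exported point cloud / cameras
#     scene.jpg        preview image
#     depth_vis/       per-view depth visualizations
#     images/          original input images
#     metadata.json    see below
import json
from datetime import datetime

metadata = {
    "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
    "num_images": 4,                                    # len(processed_data)
    "gallery_name": "kitchen_20250101_120000_pct30",    # hypothetical example name
}
print(json.dumps(metadata, indent=2))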
+# Callers should import and call that directly instead of using this module. + + +def get_logo_base64() -> Optional[str]: + """ + Convert WAI logo to base64 for embedding in HTML. + + Returns: + Base64 encoded logo string or None + """ + import base64 + + logo_path = "examples/WAI-Logo/wai_logo.png" + try: + with open(logo_path, "rb") as img_file: + img_data = img_file.read() + base64_str = base64.b64encode(img_data).decode() + return f"data:image/png;base64,{base64_str}" + except FileNotFoundError: + return None diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/visualization.py b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..ada49b25532a6ddbeb0e7f99856498e4ce3fb2ad --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/app/modules/visualization.py @@ -0,0 +1,434 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Visualization module for Depth Anything 3 Gradio app. + +This module handles visualization updates, navigation, and measurement functionality. +""" + +import os +from typing import Any, Dict, List, Optional, Tuple +import cv2 +import gradio as gr +import numpy as np + + +class VisualizationHandler: + """ + Handles visualization updates and navigation for the Gradio app. + """ + + def __init__(self): + """Initialize the visualization handler.""" + + def update_view_selectors( + self, processed_data: Optional[Dict[int, Dict[str, Any]]] + ) -> Tuple[gr.Dropdown, gr.Dropdown]: + """ + Update view selector dropdowns based on available views. + + Args: + processed_data: Processed data dictionary + + Returns: + Tuple of (depth_view_selector, measure_view_selector) + """ + if processed_data is None or len(processed_data) == 0: + choices = ["View 1"] + else: + num_views = len(processed_data) + choices = [f"View {i + 1}" for i in range(num_views)] + + return ( + gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector + gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector + ) + + def get_view_data_by_index( + self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int + ) -> Optional[Dict[str, Any]]: + """ + Get view data by index, handling bounds. + + Args: + processed_data: Processed data dictionary + view_index: Index of the view to get + + Returns: + View data dictionary or None + """ + if processed_data is None or len(processed_data) == 0: + return None + + view_keys = list(processed_data.keys()) + if view_index < 0 or view_index >= len(view_keys): + view_index = 0 + + return processed_data[view_keys[view_index]] + + def update_depth_view( + self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int + ) -> Optional[str]: + """ + Update depth view for a specific view index. 
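+
+        Example (illustrative; `processed_data` is assumed to come from the
+        app's reconstruction step):
+
+            >>> handler = VisualizationHandler()
+            >>> depth_path = handler.update_depth_view(processed_data, view_index=0)
+            >>> # depth_path is that view's "depth_image" file path, or None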
+ + Args: + processed_data: Processed data dictionary + view_index: Index of the view to update + + Returns: + Path to depth visualization image or None + """ + view_data = self.get_view_data_by_index(processed_data, view_index) + if view_data is None or view_data.get("depth_image") is None: + return None + + # Return the depth visualization image directly + return view_data["depth_image"] + + def navigate_depth_view( + self, + processed_data: Optional[Dict[int, Dict[str, Any]]], + current_selector_value: str, + direction: int, + ) -> Tuple[str, Optional[str]]: + """ + Navigate depth view (direction: -1 for previous, +1 for next). + + Args: + processed_data: Processed data dictionary + current_selector_value: Current selector value + direction: Direction to navigate (-1 for previous, +1 for next) + + Returns: + Tuple of (new_selector_value, depth_vis) + """ + if processed_data is None or len(processed_data) == 0: + return "View 1", None + + # Parse current view number + try: + current_view = int(current_selector_value.split()[1]) - 1 + except: # noqa + current_view = 0 + + num_views = len(processed_data) + new_view = (current_view + direction) % num_views + + new_selector_value = f"View {new_view + 1}" + depth_vis = self.update_depth_view(processed_data, new_view) + + return new_selector_value, depth_vis + + def update_measure_view( + self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int + ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]: + """ + Update measure view for a specific view index. + + Args: + processed_data: Processed data dictionary + view_index: Index of the view to update + + Returns: + Tuple of (measure_image, depth_right_half, measure_points) + """ + view_data = self.get_view_data_by_index(processed_data, view_index) + if view_data is None: + return None, None, [] # image, depth_right_half, measure_points + + # Get the processed (resized) image + if "image" in view_data and view_data["image"] is not None: + image = view_data["image"].copy() + else: + return None, None, [] + + # Ensure image is in uint8 format + if image.dtype != np.uint8: + if image.max() <= 1.0: + image = (image * 255).astype(np.uint8) + else: + image = image.astype(np.uint8) + + # Extract right half of the depth visualization (pure depth part) + depth_image_path = view_data.get("depth_image", None) + depth_right_half = None + + if depth_image_path and os.path.exists(depth_image_path): + try: + # Load the combined depth visualization image + depth_combined = cv2.imread(depth_image_path) + depth_combined = cv2.cvtColor(depth_combined, cv2.COLOR_BGR2RGB) + if depth_combined is not None: + height, width = depth_combined.shape[:2] + # Extract right half (depth visualization part) + depth_right_half = depth_combined[:, width // 2 :] + except Exception as e: + print(f"Error extracting depth right half: {e}") + + return image, depth_right_half, [] + + def navigate_measure_view( + self, + processed_data: Optional[Dict[int, Dict[str, Any]]], + current_selector_value: str, + direction: int, + ) -> Tuple[str, Optional[np.ndarray], Optional[str], List]: + """ + Navigate measure view (direction: -1 for previous, +1 for next). 
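+
+        Example (illustrative call; "View 2" is a placeholder selector value):
+
+            >>> label, img, depth_half, pts = handler.navigate_measure_view(
+            ...     processed_data, current_selector_value="View 2", direction=+1
+            ... )
+            >>> # label becomes "View 3", wrapping back to "View 1" past the end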
+ + Args: + processed_data: Processed data dictionary + current_selector_value: Current selector value + direction: Direction to navigate (-1 for previous, +1 for next) + + Returns: + Tuple of (new_selector_value, measure_image, depth_image_path, measure_points) + """ + if processed_data is None or len(processed_data) == 0: + return "View 1", None, None, [] + + # Parse current view number + try: + current_view = int(current_selector_value.split()[1]) - 1 + except: # noqa + current_view = 0 + + num_views = len(processed_data) + new_view = (current_view + direction) % num_views + + new_selector_value = f"View {new_view + 1}" + measure_image, depth_right_half, measure_points = self.update_measure_view( + processed_data, new_view + ) + + return new_selector_value, measure_image, depth_right_half, measure_points + + def populate_visualization_tabs( + self, processed_data: Optional[Dict[int, Dict[str, Any]]] + ) -> Tuple[Optional[str], Optional[np.ndarray], Optional[str], List]: + """ + Populate the depth and measure tabs with processed data. + + Args: + processed_data: Processed data dictionary + + Returns: + Tuple of (depth_vis, measure_img, depth_image_path, measure_points) + """ + if processed_data is None or len(processed_data) == 0: + return None, None, None, [] + + # Use update function to get depth visualization + depth_vis = self.update_depth_view(processed_data, 0) + measure_img, depth_right_half, _ = self.update_measure_view(processed_data, 0) + + return depth_vis, measure_img, depth_right_half, [] + + def reset_measure( + self, processed_data: Optional[Dict[int, Dict[str, Any]]] + ) -> Tuple[Optional[np.ndarray], List, str]: + """ + Reset measure points. + + Args: + processed_data: Processed data dictionary + + Returns: + Tuple of (image, measure_points, text) + """ + if processed_data is None or len(processed_data) == 0: + return None, [], "" + + # Return the first view image + first_view = list(processed_data.values())[0] + return first_view["image"], [], "" + + def measure( + self, + processed_data: Optional[Dict[int, Dict[str, Any]]], + measure_points: List, + current_view_selector: str, + event: gr.SelectData, + ) -> List: + """ + Handle measurement on images. 
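+
+        Example (illustrative Gradio wiring; the component and state names are
+        assumptions, not part of this module):
+
+            >>> measure_image.select(
+            ...     handler.measure,
+            ...     inputs=[processed_data_state, measure_points_state, measure_view_selector],
+            ...     outputs=[measure_image, depth_half_image, measure_points_state, measure_text],
+            ... )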
+ + Args: + processed_data: Processed data dictionary + measure_points: List of current measure points + current_view_selector: Current view selector value + event: Gradio select event + + Returns: + List of [image, depth_right_half, measure_points, text] + """ + try: + print(f"Measure function called with selector: {current_view_selector}") + + if processed_data is None or len(processed_data) == 0: + return [None, [], "No data available"] + + # Use the currently selected view instead of always using the first view + try: + current_view_index = int(current_view_selector.split()[1]) - 1 + except: # noqa + current_view_index = 0 + + print(f"Using view index: {current_view_index}") + + # Get view data safely + if current_view_index < 0 or current_view_index >= len(processed_data): + current_view_index = 0 + + view_keys = list(processed_data.keys()) + current_view = processed_data[view_keys[current_view_index]] + + if current_view is None: + return [None, [], "No view data available"] + + point2d = event.index[0], event.index[1] + print(f"Clicked point: {point2d}") + + measure_points.append(point2d) + + # Get image and depth visualization + image, depth_right_half, _ = self.update_measure_view( + processed_data, current_view_index + ) + if image is None: + return [None, [], "No image available"] + + image = image.copy() + + # Ensure image is in uint8 format for proper cv2 operations + try: + if image.dtype != np.uint8: + if image.max() <= 1.0: + # Image is in [0, 1] range, convert to [0, 255] + image = (image * 255).astype(np.uint8) + else: + # Image is already in [0, 255] range + image = image.astype(np.uint8) + except Exception as e: + print(f"Image conversion error: {e}") + return [None, [], f"Image conversion error: {e}"] + + # Draw circles for points + try: + for p in measure_points: + if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]: + image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2) + except Exception as e: + print(f"Drawing error: {e}") + return [None, [], f"Drawing error: {e}"] + + # Get depth information from processed_data + depth_text = "" + try: + for i, p in enumerate(measure_points): + if ( + current_view["depth"] is not None + and 0 <= p[1] < current_view["depth"].shape[0] + and 0 <= p[0] < current_view["depth"].shape[1] + ): + d = current_view["depth"][p[1], p[0]] + depth_text += f"- **P{i + 1} depth: {d:.2f}m**\n" + else: + depth_text += f"- **P{i + 1}: Click position ({p[0]}, {p[1]}) - No depth information**\n" # noqa: E501 + except Exception as e: + print(f"Depth text error: {e}") + depth_text = f"Error computing depth: {e}\n" + + if len(measure_points) == 2: + try: + point1, point2 = measure_points + # Draw line + if ( + 0 <= point1[0] < image.shape[1] + and 0 <= point1[1] < image.shape[0] + and 0 <= point2[0] < image.shape[1] + and 0 <= point2[1] < image.shape[0] + ): + image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2) + + # Compute 3D distance using depth information and camera intrinsics + distance_text = "- **Distance: Unable to calculate 3D distance**" + if ( + current_view["depth"] is not None + and 0 <= point1[1] < current_view["depth"].shape[0] + and 0 <= point1[0] < current_view["depth"].shape[1] + and 0 <= point2[1] < current_view["depth"].shape[0] + and 0 <= point2[0] < current_view["depth"].shape[1] + ): + try: + # Get depth values at the two points + d1 = current_view["depth"][point1[1], point1[0]] + d2 = current_view["depth"][point2[1], point2[0]] + + # Convert 2D pixel coordinates to 3D world 
coordinates + if current_view["intrinsics"] is not None: + # Get camera intrinsics + K = current_view["intrinsics"] # 3x3 intrinsic matrix + fx, fy = K[0, 0], K[1, 1] # focal lengths + cx, cy = K[0, 2], K[1, 2] # principal point + + # Convert pixel coordinates to normalized camera coordinates + # Point 1: (u1, v1) -> (x1, y1, z1) + u1, v1 = point1[0], point1[1] + x1 = (u1 - cx) * d1 / fx + y1 = (v1 - cy) * d1 / fy + z1 = d1 + + # Point 2: (u2, v2) -> (x2, y2, z2) + u2, v2 = point2[0], point2[1] + x2 = (u2 - cx) * d2 / fx + y2 = (v2 - cy) * d2 / fy + z2 = d2 + + # Calculate 3D Euclidean distance + p1_3d = np.array([x1, y1, z1]) + p2_3d = np.array([x2, y2, z2]) + distance_3d = np.linalg.norm(p1_3d - p2_3d) + + distance_text = f"- **Distance: {distance_3d:.2f}m**" + else: + # Fallback to simplified calculation if no intrinsics + pixel_distance = np.sqrt( + (point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2 + ) + avg_depth = (d1 + d2) / 2 + scale_factor = avg_depth / 1000 # Rough scaling factor + estimated_3d_distance = pixel_distance * scale_factor + distance_text = f"- **Distance: {estimated_3d_distance:.2f}m (estimated, no intrinsics)**" # noqa: E501 + + except Exception as e: + print(f"Distance computation error: {e}") + distance_text = f"- **Distance computation error: {e}**" + + measure_points = [] + text = depth_text + distance_text + print(f"Measurement complete: {text}") + return [image, depth_right_half, measure_points, text] + except Exception as e: + print(f"Final measurement error: {e}") + return [None, [], f"Measurement error: {e}"] + else: + print(f"Single point measurement: {depth_text}") + return [image, depth_right_half, measure_points, depth_text] + + except Exception as e: + print(f"Overall measure function error: {e}") + return [None, [], f"Measure function error: {e}"] diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/cfg.py b/Depth-Anything-3-anysize/src/depth_anything_3/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..4607ff8c0983b67e6ccd2d80b78a163c4e487250 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/cfg.py @@ -0,0 +1,144 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Configuration utility functions +""" + +import importlib +from pathlib import Path +from typing import Any, Callable, List, Union +from omegaconf import DictConfig, ListConfig, OmegaConf + +try: + OmegaConf.register_new_resolver("eval", eval) +except Exception as e: + # if eval is not available, we can just pass + print(f"Error registering eval resolver: {e}") + + +def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]: + """ + Load a configuration. Will resolve inheritance. + Supports both file paths and module paths (e.g., depth_anything_3.configs.giant). + """ + # Check if path is a module path (contains dots but no slashes and doesn't end with .yaml) + if "." 
in path and "/" not in path and not path.endswith(".yaml"): + # It's a module path, load from package resources + path_parts = path.split(".")[1:] + config_path = Path(__file__).resolve().parent + for part in path_parts: + config_path = config_path.joinpath(part) + config_path = config_path.with_suffix(".yaml") + config = OmegaConf.load(str(config_path)) + else: + # It's a file path (absolute, relative, or with .yaml extension) + config = OmegaConf.load(path) + + if argv is not None: + config_argv = OmegaConf.from_dotlist(argv) + config = OmegaConf.merge(config, config_argv) + config = resolve_recursive(config, resolve_inheritance) + return config + + +def resolve_recursive( + config: Any, + resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]], +) -> Any: + config = resolver(config) + if isinstance(config, DictConfig): + for k in config.keys(): + v = config.get(k) + if isinstance(v, (DictConfig, ListConfig)): + config[k] = resolve_recursive(v, resolver) + if isinstance(config, ListConfig): + for i in range(len(config)): + v = config.get(i) + if isinstance(v, (DictConfig, ListConfig)): + config[i] = resolve_recursive(v, resolver) + return config + + +def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any: + """ + Recursively resolve inheritance if the config contains: + __inherit__: path/to/parent.yaml or a ListConfig of such paths. + """ + if isinstance(config, DictConfig): + inherit = config.pop("__inherit__", None) + + if inherit: + inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit] + + parent_config = None + for parent_path in inherit_list: + assert isinstance(parent_path, str) + parent_config = ( + load_config(parent_path) + if parent_config is None + else OmegaConf.merge(parent_config, load_config(parent_path)) + ) + + if len(config.keys()) > 0: + config = OmegaConf.merge(parent_config, config) + else: + config = parent_config + return config + + +def import_item(path: str, name: str) -> Any: + """ + Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass + """ + return getattr(importlib.import_module(path), name) + + +def create_object(config: DictConfig) -> Any: + """ + Create an object from config. + The config is expected to contains the following: + __object__: + path: path.to.module + name: MyClass + args: as_config | as_params (default to as_config) + """ + config = DictConfig(config) + item = import_item( + path=config.__object__.path, + name=config.__object__.name, + ) + args = config.__object__.get("args", "as_config") + if args == "as_config": + return item(config) + if args == "as_params": + config = OmegaConf.to_object(config) + config.pop("__object__") + return item(**config) + raise NotImplementedError(f"Unknown args type: {args}") + + +def create_dataset(path: str, *args, **kwargs) -> Any: + """ + Create a dataset. Requires the file to contain a "create_dataset" function. 
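+
+    Example (illustrative; the module path and keyword argument are placeholders):
+
+        >>> ds = create_dataset("my_project.datasets.scannet", split="val")
+        >>> # equivalent to:
+        >>> # import_item("my_project.datasets.scannet", "create_dataset")(split="val")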
+ """ + return import_item(path, "create_dataset")(*args, **kwargs) + + +def to_dict_recursive(config_obj): + if isinstance(config_obj, DictConfig): + return {k: to_dict_recursive(v) for k, v in config_obj.items()} + elif isinstance(config_obj, ListConfig): + return [to_dict_recursive(item) for item in config_obj] + return config_obj diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/cli.py b/Depth-Anything-3-anysize/src/depth_anything_3/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..1a7c00ebb6a1381fe7383fa588a2669dd1b99b95 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/cli.py @@ -0,0 +1,748 @@ +# flake8: noqa: E402 +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Refactored Depth Anything 3 CLI +Clean, modular command-line interface +""" + +from __future__ import annotations + +import os +from typing import Optional +import typer + +from depth_anything_3.services import start_server +from depth_anything_3.services.gallery import gallery as gallery_main +from depth_anything_3.services.inference_service import run_inference +from depth_anything_3.services.input_handlers import ( + ColmapHandler, + ImageHandler, + ImagesHandler, + InputHandler, + VideoHandler, + parse_export_feat, +) +from depth_anything_3.utils.constants import ( + DEFAULT_EXPORT_DIR, + DEFAULT_GALLERY_DIR, + DEFAULT_GRADIO_DIR, + DEFAULT_MODEL, +) + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" + +app = typer.Typer(help="Depth Anything 3 - Video depth estimation CLI", add_completion=False) + + +# ============================================================================ +# Input type detection utilities +# ============================================================================ + +# Supported file extensions +IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff", ".tif"} +VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"} + + +def detect_input_type(input_path: str) -> str: + """ + Detect input type from path. 
+ + Returns: + - "image": Single image file + - "images": Directory containing images + - "video": Video file + - "colmap": COLMAP directory structure + - "unknown": Cannot determine type + """ + if not os.path.exists(input_path): + return "unknown" + + # Check if it's a file + if os.path.isfile(input_path): + ext = os.path.splitext(input_path)[1].lower() + if ext in IMAGE_EXTENSIONS: + return "image" + elif ext in VIDEO_EXTENSIONS: + return "video" + return "unknown" + + # Check if it's a directory + if os.path.isdir(input_path): + # Check for COLMAP structure + images_dir = os.path.join(input_path, "images") + sparse_dir = os.path.join(input_path, "sparse") + + if os.path.isdir(images_dir) and os.path.isdir(sparse_dir): + return "colmap" + + # Check if directory contains image files + for item in os.listdir(input_path): + item_path = os.path.join(input_path, item) + if os.path.isfile(item_path): + ext = os.path.splitext(item)[1].lower() + if ext in IMAGE_EXTENSIONS: + return "images" + + return "unknown" + + return "unknown" + + +# ============================================================================ +# Common parameters and configuration +# ============================================================================ + +# ============================================================================ +# Inference commands +# ============================================================================ + + +@app.command() +def auto( + input_path: str = typer.Argument( + ..., help="Path to input (image, directory, video, or COLMAP)" + ), + model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"), + export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"), + export_format: str = typer.Option("glb", help="Export format"), + device: str = typer.Option("cuda", help="Device to use"), + use_backend: bool = typer.Option(False, help="Use backend service for inference"), + backend_url: str = typer.Option( + "http://localhost:8008", help="Backend URL (default: http://localhost:8008)" + ), + process_res: Optional[int] = typer.Option( + None, help="Processing resolution; None keeps original size" + ), + process_res_method: str = typer.Option("keep", help="Processing resolution method"), + export_feat: str = typer.Option( + "", + help="[FEAT_VIS]Export features from specified layers using comma-separated indices (e.g., '0,1,2').", + ), + auto_cleanup: bool = typer.Option( + False, help="Automatically clean export directory if it exists (no prompt)" + ), + # Video-specific options + fps: float = typer.Option(1.0, help="[Video] Sampling FPS for frame extraction"), + # COLMAP-specific options + sparse_subdir: str = typer.Option( + "", help="[COLMAP] Sparse reconstruction subdirectory (e.g., '0' for sparse/0/)" + ), + align_to_input_ext_scale: bool = typer.Option( + True, help="[COLMAP] Align prediction to input extrinsics scale" + ), + # GLB export options + conf_thresh_percentile: float = typer.Option( + 40.0, help="[GLB] Lower percentile for adaptive confidence threshold" + ), + num_max_points: int = typer.Option( + 1_000_000, help="[GLB] Maximum number of points in the point cloud" + ), + show_cameras: bool = typer.Option( + True, help="[GLB] Show camera wireframes in the exported scene" + ), + # Feat_vis export options + feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"), +): + """ + Automatically detect input type and run appropriate processing. + + Supports: + - Single image file (.jpg, .png, etc.) 
+ - Directory of images + - Video file (.mp4, .avi, etc.) + - COLMAP directory (with 'images' and 'sparse' subdirectories) + """ + # Detect input type + input_type = detect_input_type(input_path) + + if input_type == "unknown": + typer.echo(f"❌ Error: Cannot determine input type for: {input_path}", err=True) + typer.echo("Supported inputs:", err=True) + typer.echo(" - Single image file (.jpg, .png, etc.)", err=True) + typer.echo(" - Directory containing images", err=True) + typer.echo(" - Video file (.mp4, .avi, etc.)", err=True) + typer.echo(" - COLMAP directory (with 'images/' and 'sparse/' subdirectories)", err=True) + raise typer.Exit(1) + + # Display detected type + typer.echo(f"πŸ” Detected input type: {input_type.upper()}") + typer.echo(f"πŸ“ Input path: {input_path}") + typer.echo() + + # Determine backend URL based on use_backend flag + final_backend_url = backend_url if use_backend else None + + # Parse export_feat parameter + export_feat_layers = parse_export_feat(export_feat) + + # Route to appropriate handler + if input_type == "image": + typer.echo("Processing single image...") + # Process input + image_files = ImageHandler.process(input_path) + + # Handle export directory + export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup) + + # Run inference + run_inference( + image_paths=image_files, + export_dir=export_dir, + model_dir=model_dir, + device=device, + backend_url=final_backend_url, + export_format=export_format, + process_res=process_res, + process_res_method=process_res_method, + export_feat_layers=export_feat_layers, + conf_thresh_percentile=conf_thresh_percentile, + num_max_points=num_max_points, + show_cameras=show_cameras, + feat_vis_fps=feat_vis_fps, + ) + + elif input_type == "images": + typer.echo("Processing directory of images...") + # Process input - use default extensions + image_files = ImagesHandler.process(input_path, "png,jpg,jpeg") + + # Handle export directory + export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup) + + # Run inference + run_inference( + image_paths=image_files, + export_dir=export_dir, + model_dir=model_dir, + device=device, + backend_url=final_backend_url, + export_format=export_format, + process_res=process_res, + process_res_method=process_res_method, + export_feat_layers=export_feat_layers, + conf_thresh_percentile=conf_thresh_percentile, + num_max_points=num_max_points, + show_cameras=show_cameras, + feat_vis_fps=feat_vis_fps, + ) + + elif input_type == "video": + typer.echo(f"Processing video with FPS={fps}...") + # Handle export directory + export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup) + + # Process input + image_files = VideoHandler.process(input_path, export_dir, fps) + + # Run inference + run_inference( + image_paths=image_files, + export_dir=export_dir, + model_dir=model_dir, + device=device, + backend_url=final_backend_url, + export_format=export_format, + process_res=process_res, + process_res_method=process_res_method, + export_feat_layers=export_feat_layers, + conf_thresh_percentile=conf_thresh_percentile, + num_max_points=num_max_points, + show_cameras=show_cameras, + feat_vis_fps=feat_vis_fps, + ) + + elif input_type == "colmap": + typer.echo( + f"Processing COLMAP directory (sparse subdirectory: '{sparse_subdir or 'default'}')..." 
+ ) + # Process input + image_files, extrinsics, intrinsics = ColmapHandler.process(input_path, sparse_subdir) + + # Handle export directory + export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup) + + # Run inference + run_inference( + image_paths=image_files, + export_dir=export_dir, + model_dir=model_dir, + device=device, + backend_url=final_backend_url, + export_format=export_format, + process_res=process_res, + process_res_method=process_res_method, + export_feat_layers=export_feat_layers, + extrinsics=extrinsics, + intrinsics=intrinsics, + align_to_input_ext_scale=align_to_input_ext_scale, + conf_thresh_percentile=conf_thresh_percentile, + num_max_points=num_max_points, + show_cameras=show_cameras, + feat_vis_fps=feat_vis_fps, + ) + + typer.echo() + typer.echo("βœ… Processing completed successfully!") + + +@app.command() +def image( + image_path: str = typer.Argument(..., help="Path to input image file"), + model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"), + export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"), + export_format: str = typer.Option("glb", help="Export format"), + device: str = typer.Option("cuda", help="Device to use"), + use_backend: bool = typer.Option(False, help="Use backend service for inference"), + backend_url: str = typer.Option( + "http://localhost:8008", help="Backend URL (default: http://localhost:8008)" + ), + process_res: Optional[int] = typer.Option( + None, help="Processing resolution; None keeps original size" + ), + process_res_method: str = typer.Option("keep", help="Processing resolution method"), + export_feat: str = typer.Option( + "", + help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').", + ), + auto_cleanup: bool = typer.Option( + False, help="Automatically clean export directory if it exists (no prompt)" + ), + # GLB export options + conf_thresh_percentile: float = typer.Option( + 40.0, help="[GLB] Lower percentile for adaptive confidence threshold" + ), + num_max_points: int = typer.Option( + 1_000_000, help="[GLB] Maximum number of points in the point cloud" + ), + show_cameras: bool = typer.Option( + True, help="[GLB] Show camera wireframes in the exported scene" + ), + # Feat_vis export options + feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"), +): + """Run camera pose and depth estimation on a single image.""" + # Process input + image_files = ImageHandler.process(image_path) + + # Handle export directory + export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup) + + # Parse export_feat parameter + export_feat_layers = parse_export_feat(export_feat) + + # Determine backend URL based on use_backend flag + final_backend_url = backend_url if use_backend else None + + # Run inference + run_inference( + image_paths=image_files, + export_dir=export_dir, + model_dir=model_dir, + device=device, + backend_url=final_backend_url, + export_format=export_format, + process_res=process_res, + process_res_method=process_res_method, + export_feat_layers=export_feat_layers, + conf_thresh_percentile=conf_thresh_percentile, + num_max_points=num_max_points, + show_cameras=show_cameras, + feat_vis_fps=feat_vis_fps, + ) + + +@app.command() +def images( + images_dir: str = typer.Argument(..., help="Path to directory containing input images"), + image_extensions: str = typer.Option( + "png,jpg,jpeg", help="Comma-separated image file extensions to process" + ), + model_dir: str = 
typer.Option(DEFAULT_MODEL, help="Model directory path"), + export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"), + export_format: str = typer.Option("glb", help="Export format"), + device: str = typer.Option("cuda", help="Device to use"), + use_backend: bool = typer.Option(False, help="Use backend service for inference"), + backend_url: str = typer.Option( + "http://localhost:8008", help="Backend URL (default: http://localhost:8008)" + ), + process_res: Optional[int] = typer.Option( + None, help="Processing resolution; None keeps original size" + ), + process_res_method: str = typer.Option("keep", help="Processing resolution method"), + export_feat: str = typer.Option( + "", + help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').", + ), + auto_cleanup: bool = typer.Option( + False, help="Automatically clean export directory if it exists (no prompt)" + ), + # GLB export options + conf_thresh_percentile: float = typer.Option( + 40.0, help="[GLB] Lower percentile for adaptive confidence threshold" + ), + num_max_points: int = typer.Option( + 1_000_000, help="[GLB] Maximum number of points in the point cloud" + ), + show_cameras: bool = typer.Option( + True, help="[GLB] Show camera wireframes in the exported scene" + ), + # Feat_vis export options + feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"), +): + """Run camera pose and depth estimation on a directory of images.""" + # Process input + image_files = ImagesHandler.process(images_dir, image_extensions) + + # Handle export directory + export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup) + + # Parse export_feat parameter + export_feat_layers = parse_export_feat(export_feat) + + # Determine backend URL based on use_backend flag + final_backend_url = backend_url if use_backend else None + + # Run inference + run_inference( + image_paths=image_files, + export_dir=export_dir, + model_dir=model_dir, + device=device, + backend_url=final_backend_url, + export_format=export_format, + process_res=process_res, + process_res_method=process_res_method, + export_feat_layers=export_feat_layers, + conf_thresh_percentile=conf_thresh_percentile, + num_max_points=num_max_points, + show_cameras=show_cameras, + feat_vis_fps=feat_vis_fps, + ) + + +@app.command() +def colmap( + colmap_dir: str = typer.Argument( + ..., help="Path to COLMAP directory containing 'images' and 'sparse' subdirectories" + ), + sparse_subdir: str = typer.Option( + "", help="Sparse reconstruction subdirectory (e.g., '0' for sparse/0/, empty for sparse/)" + ), + align_to_input_ext_scale: bool = typer.Option( + True, help="Align prediction to input extrinsics scale" + ), + model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"), + export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"), + export_format: str = typer.Option("glb", help="Export format"), + device: str = typer.Option("cuda", help="Device to use"), + use_backend: bool = typer.Option(False, help="Use backend service for inference"), + backend_url: str = typer.Option( + "http://localhost:8008", help="Backend URL (default: http://localhost:8008)" + ), + process_res: Optional[int] = typer.Option( + None, help="Processing resolution; None keeps original size" + ), + process_res_method: str = typer.Option("keep", help="Processing resolution method"), + export_feat: str = typer.Option( + "", + help="Export features from specified layers using comma-separated indices (e.g., 
'0,1,2').", + ), + auto_cleanup: bool = typer.Option( + False, help="Automatically clean export directory if it exists (no prompt)" + ), + # GLB export options + conf_thresh_percentile: float = typer.Option( + 40.0, help="[GLB] Lower percentile for adaptive confidence threshold" + ), + num_max_points: int = typer.Option( + 1_000_000, help="[GLB] Maximum number of points in the point cloud" + ), + show_cameras: bool = typer.Option( + True, help="[GLB] Show camera wireframes in the exported scene" + ), + # Feat_vis export options + feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"), +): + """Run pose conditioned depth estimation on COLMAP data.""" + # Process input + image_files, extrinsics, intrinsics = ColmapHandler.process(colmap_dir, sparse_subdir) + + # Handle export directory + export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup) + + # Parse export_feat parameter + export_feat_layers = parse_export_feat(export_feat) + + # Determine backend URL based on use_backend flag + final_backend_url = backend_url if use_backend else None + + # Run inference + run_inference( + image_paths=image_files, + export_dir=export_dir, + model_dir=model_dir, + device=device, + backend_url=final_backend_url, + export_format=export_format, + process_res=process_res, + process_res_method=process_res_method, + export_feat_layers=export_feat_layers, + extrinsics=extrinsics, + intrinsics=intrinsics, + align_to_input_ext_scale=align_to_input_ext_scale, + conf_thresh_percentile=conf_thresh_percentile, + num_max_points=num_max_points, + show_cameras=show_cameras, + feat_vis_fps=feat_vis_fps, + ) + + +@app.command() +def video( + video_path: str = typer.Argument(..., help="Path to input video file"), + fps: float = typer.Option(1.0, help="Sampling FPS for frame extraction"), + model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"), + export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"), + export_format: str = typer.Option("glb", help="Export format"), + device: str = typer.Option("cuda", help="Device to use"), + use_backend: bool = typer.Option(False, help="Use backend service for inference"), + backend_url: str = typer.Option( + "http://localhost:8008", help="Backend URL (default: http://localhost:8008)" + ), + process_res: Optional[int] = typer.Option( + None, help="Processing resolution; None keeps original size" + ), + process_res_method: str = typer.Option("keep", help="Processing resolution method"), + export_feat: str = typer.Option( + "", + help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').", + ), + auto_cleanup: bool = typer.Option( + False, help="Automatically clean export directory if it exists (no prompt)" + ), + # GLB export options + conf_thresh_percentile: float = typer.Option( + 40.0, help="[GLB] Lower percentile for adaptive confidence threshold" + ), + num_max_points: int = typer.Option( + 1_000_000, help="[GLB] Maximum number of points in the point cloud" + ), + show_cameras: bool = typer.Option( + True, help="[GLB] Show camera wireframes in the exported scene" + ), + # Feat_vis export options + feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"), +): + """Run depth estimation on video by extracting frames and processing them.""" + # Handle export directory + export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup) + + # Process input + image_files = VideoHandler.process(video_path, export_dir, fps) + + # Parse 
export_feat parameter + export_feat_layers = parse_export_feat(export_feat) + + # Determine backend URL based on use_backend flag + final_backend_url = backend_url if use_backend else None + + # Run inference + run_inference( + image_paths=image_files, + export_dir=export_dir, + model_dir=model_dir, + device=device, + backend_url=final_backend_url, + export_format=export_format, + process_res=process_res, + process_res_method=process_res_method, + export_feat_layers=export_feat_layers, + conf_thresh_percentile=conf_thresh_percentile, + num_max_points=num_max_points, + show_cameras=show_cameras, + feat_vis_fps=feat_vis_fps, + ) + + +# ============================================================================ +# Service management commands +# ============================================================================ + + +@app.command() +def backend( + model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"), + device: str = typer.Option("cuda", help="Device to use"), + host: str = typer.Option("127.0.0.1", help="Host to bind to"), + port: int = typer.Option(8008, help="Port to bind to"), + gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path (optional)"), +): + """Start model backend service with integrated gallery.""" + typer.echo("=" * 60) + typer.echo("πŸš€ Starting Depth Anything 3 Backend Server") + typer.echo("=" * 60) + typer.echo(f"Model directory: {model_dir}") + typer.echo(f"Device: {device}") + + # Check if gallery directory exists + if gallery_dir and os.path.exists(gallery_dir): + typer.echo(f"Gallery directory: {gallery_dir}") + else: + gallery_dir = None # Disable gallery if directory doesn't exist + + typer.echo() + typer.echo("πŸ“‘ Server URLs (Ctrl/CMD+Click to open):") + typer.echo(f" 🏠 Home: http://{host}:{port}") + typer.echo(f" πŸ“Š Dashboard: http://{host}:{port}/dashboard") + typer.echo(f" πŸ“ˆ API Status: http://{host}:{port}/status") + + if gallery_dir: + typer.echo(f" 🎨 Gallery: http://{host}:{port}/gallery/") + + typer.echo("=" * 60) + + try: + start_server(model_dir, device, host, port, gallery_dir) + except KeyboardInterrupt: + typer.echo("\nπŸ‘‹ Backend server stopped.") + except Exception as e: + typer.echo(f"❌ Failed to start backend: {e}") + raise typer.Exit(1) + + +# ============================================================================ +# Application launch commands +# ============================================================================ + + +@app.command() +def gradio( + model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"), + workspace_dir: str = typer.Option(DEFAULT_GRADIO_DIR, help="Workspace directory path"), + gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path"), + host: str = typer.Option("127.0.0.1", help="Host address to bind to"), + port: int = typer.Option(7860, help="Port number to bind to"), + share: bool = typer.Option(False, help="Create a public link for the app"), + debug: bool = typer.Option(False, help="Enable debug mode"), + cache_examples: bool = typer.Option( + False, help="Pre-cache all example scenes at startup for faster loading" + ), + cache_gs_tag: str = typer.Option( + "", + help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). 
Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.", + ), +): + """Launch Depth Anything 3 Gradio interactive web application""" + from depth_anything_3.app.gradio_app import DepthAnything3App + + # Create necessary directories + os.makedirs(workspace_dir, exist_ok=True) + os.makedirs(gallery_dir, exist_ok=True) + + typer.echo("Launching Depth Anything 3 Gradio application...") + typer.echo(f"Model directory: {model_dir}") + typer.echo(f"Workspace directory: {workspace_dir}") + typer.echo(f"Gallery directory: {gallery_dir}") + typer.echo(f"Host: {host}") + typer.echo(f"Port: {port}") + typer.echo(f"Share: {share}") + typer.echo(f"Debug mode: {debug}") + typer.echo(f"Cache examples: {cache_examples}") + if cache_examples: + if cache_gs_tag: + typer.echo( + f"Cache GS Tag: '{cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)" + ) + else: + typer.echo(f"Cache GS Tag: None (all scenes will use low-res only)") + + try: + # Initialize and launch application + app = DepthAnything3App( + model_dir=model_dir, workspace_dir=workspace_dir, gallery_dir=gallery_dir + ) + + # Pre-cache examples if requested + if cache_examples: + typer.echo("\n" + "=" * 60) + typer.echo("Pre-caching mode enabled") + if cache_gs_tag: + typer.echo(f"Scenes containing '{cache_gs_tag}' will use HIGH-RES + 3DGS") + typer.echo(f"Other scenes will use LOW-RES only") + else: + typer.echo(f"All scenes will use LOW-RES only") + typer.echo("=" * 60) + app.cache_examples( + show_cam=True, + filter_black_bg=False, + filter_white_bg=False, + save_percentage=20.0, + num_max_points=1000, + cache_gs_tag=cache_gs_tag, + gs_trj_mode="smooth", + gs_video_quality="low", + ) + + # Prepare launch arguments + launch_kwargs = {"share": share, "debug": debug} + + app.launch(host=host, port=port, **launch_kwargs) + + except KeyboardInterrupt: + typer.echo("\nGradio application stopped.") + except Exception as e: + typer.echo(f"Failed to launch Gradio application: {e}") + raise typer.Exit(1) + + +@app.command() +def gallery( + gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery root directory"), + host: str = typer.Option("127.0.0.1", help="Host address to bind to"), + port: int = typer.Option(8007, help="Port number to bind to"), + open_browser: bool = typer.Option(False, help="Open browser after launch"), +): + """Launch Depth Anything 3 Gallery server""" + + # Validate gallery directory + if not os.path.exists(gallery_dir): + raise typer.BadParameter(f"Gallery directory not found: {gallery_dir}") + + typer.echo("Launching Depth Anything 3 Gallery server...") + typer.echo(f"Gallery directory: {gallery_dir}") + typer.echo(f"Host: {host}") + typer.echo(f"Port: {port}") + typer.echo(f"Auto-open browser: {open_browser}") + + try: + # Set command line arguments + import sys + + sys.argv = ["gallery", "--dir", gallery_dir, "--host", host, "--port", str(port)] + if open_browser: + sys.argv.append("--open") + + # Launch gallery server + gallery_main() + + except KeyboardInterrupt: + typer.echo("\nGallery server stopped.") + except Exception as e: + typer.echo(f"Failed to launch Gallery server: {e}") + raise typer.Exit(1) + + +if __name__ == "__main__": + app() diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-base.yaml b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c52a7e5018388a174841469f9a94dc995e14f220 --- /dev/null +++ 
b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-base.yaml @@ -0,0 +1,45 @@ +__object__: + path: depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vitb + out_layers: [5, 7, 9, 11] + alt_start: 4 + qknorm_start: 4 + rope_start: 4 + cat_token: True + +head: + __object__: + path: depth_anything_3.model.dualdpt + name: DualDPT + args: as_params + + dim_in: &head_dim_in 1536 + output_dim: 2 + features: &head_features 128 + out_channels: &head_out_channels [96, 192, 384, 768] + + +cam_enc: + __object__: + path: depth_anything_3.model.cam_enc + name: CameraEnc + args: as_params + + dim_out: 768 + +cam_dec: + __object__: + path: depth_anything_3.model.cam_dec + name: CameraDec + args: as_params + + dim_in: 1536 diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-giant.yaml b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-giant.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5a75c043353aa3e5b7c5368e4b26416c2b0b8b0 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-giant.yaml @@ -0,0 +1,71 @@ +__object__: + path: depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vitg + out_layers: [19, 27, 33, 39] + alt_start: 13 + qknorm_start: 13 + rope_start: 13 + cat_token: True + +head: + __object__: + path: depth_anything_3.model.dualdpt + name: DualDPT + args: as_params + + dim_in: &head_dim_in 3072 + output_dim: 2 + features: &head_features 256 + out_channels: &head_out_channels [256, 512, 1024, 1024] + + +cam_enc: + __object__: + path: depth_anything_3.model.cam_enc + name: CameraEnc + args: as_params + + dim_out: 1536 + +cam_dec: + __object__: + path: depth_anything_3.model.cam_dec + name: CameraDec + args: as_params + + dim_in: 3072 + + +gs_head: + __object__: + path: depth_anything_3.model.gsdpt + name: GSDPT + args: as_params + + dim_in: *head_dim_in + output_dim: 38 # should align with gs_adapter's setting, for gs params + features: *head_features + out_channels: *head_out_channels + + +gs_adapter: + __object__: + path: depth_anything_3.model.gs_adapter + name: GaussianAdapter + args: as_params + + sh_degree: 2 + pred_color: false # predict SH coefficient if false + pred_offset_depth: true + pred_offset_xy: true + gaussian_scale_min: 1e-5 + gaussian_scale_max: 30.0 diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-large.yaml b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4fa367c9d9b46eb9a62aef7041f68c709eb4c6e3 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-large.yaml @@ -0,0 +1,45 @@ +__object__: + path: depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vitl + out_layers: [11, 15, 19, 23] + alt_start: 8 + qknorm_start: 8 + rope_start: 8 + cat_token: True + +head: + __object__: + path: depth_anything_3.model.dualdpt + name: DualDPT + args: as_params + + dim_in: &head_dim_in 2048 + output_dim: 2 + features: &head_features 256 + out_channels: &head_out_channels [256, 512, 1024, 1024] + + +cam_enc: + __object__: + path: depth_anything_3.model.cam_enc + name: CameraEnc + args: 
as_params + + dim_out: 1024 + +cam_dec: + __object__: + path: depth_anything_3.model.cam_dec + name: CameraDec + args: as_params + + dim_in: 2048 diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-small.yaml b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..10887437697fc9f2614c03add73fe9858b309d91 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3-small.yaml @@ -0,0 +1,45 @@ +__object__: + path: depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vits + out_layers: [5, 7, 9, 11] + alt_start: 4 + qknorm_start: 4 + rope_start: 4 + cat_token: True + +head: + __object__: + path: depth_anything_3.model.dualdpt + name: DualDPT + args: as_params + + dim_in: &head_dim_in 768 + output_dim: 2 + features: &head_features 64 + out_channels: &head_out_channels [48, 96, 192, 384] + + +cam_enc: + __object__: + path: depth_anything_3.model.cam_enc + name: CameraEnc + args: as_params + + dim_out: 384 + +cam_dec: + __object__: + path: depth_anything_3.model.cam_dec + name: CameraDec + args: as_params + + dim_in: 768 diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3metric-large.yaml b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3metric-large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..124635cfd952c25c8857ee3da63ce0444c4377f2 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3metric-large.yaml @@ -0,0 +1,28 @@ +__object__: + path: depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vitl + out_layers: [4, 11, 17, 23] + alt_start: -1 # -1 means disable + qknorm_start: -1 + rope_start: -1 + cat_token: False + +head: + __object__: + path: depth_anything_3.model.dpt + name: DPT + args: as_params + + dim_in: 1024 + output_dim: 1 + features: 256 + out_channels: [256, 512, 1024, 1024] diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3mono-large.yaml b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3mono-large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..124635cfd952c25c8857ee3da63ce0444c4377f2 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3mono-large.yaml @@ -0,0 +1,28 @@ +__object__: + path: depth_anything_3.model.da3 + name: DepthAnything3Net + args: as_params + +net: + __object__: + path: depth_anything_3.model.dinov2.dinov2 + name: DinoV2 + args: as_params + + name: vitl + out_layers: [4, 11, 17, 23] + alt_start: -1 # -1 means disable + qknorm_start: -1 + rope_start: -1 + cat_token: False + +head: + __object__: + path: depth_anything_3.model.dpt + name: DPT + args: as_params + + dim_in: 1024 + output_dim: 1 + features: 256 + out_channels: [256, 512, 1024, 1024] diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3nested-giant-large.yaml b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3nested-giant-large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..595c122b1dc976ecfec58b133b2b60d8f724618c --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/configs/da3nested-giant-large.yaml @@ -0,0 +1,10 @@ +__object__: + path: depth_anything_3.model.da3 + name: NestedDepthAnything3Net + args: 
as_params + +anyview: + __inherit__: depth_anything_3.configs.da3-giant + +metric: + __inherit__: depth_anything_3.configs.da3metric-large diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/__init__.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..57a2a45132eeae8d58a11a26036c54feef9cfe16 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from depth_anything_3.model.da3 import DepthAnything3Net, NestedDepthAnything3Net + +__export__ = [ + NestedDepthAnything3Net, + DepthAnything3Net, +] diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/cam_dec.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/cam_dec.py new file mode 100644 index 0000000000000000000000000000000000000000..3353b403683bf556b3823081863573dc7f5f719e --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/cam_dec.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
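+
+# Layout sketch of the 9-D pose encoding assembled at the end of
+# CameraDec.forward (translation, then quaternion, then field of view; the
+# quaternion ordering convention is not fixed by this file):
+#
+#     pose_enc[..., 0:3] -> translation t
+#     pose_enc[..., 3:7] -> rotation quaternion qvec
+#     pose_enc[..., 7:9] -> field of view (fov_x, fov_y)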
+ +import torch +import torch.nn as nn + + +class CameraDec(nn.Module): + def __init__(self, dim_in=1536): + super().__init__() + output_dim = dim_in + self.backbone = nn.Sequential( + nn.Linear(output_dim, output_dim), + nn.ReLU(), + nn.Linear(output_dim, output_dim), + nn.ReLU(), + ) + self.fc_t = nn.Linear(output_dim, 3) + self.fc_qvec = nn.Linear(output_dim, 4) + self.fc_fov = nn.Sequential(nn.Linear(output_dim, 2), nn.ReLU()) + + def forward(self, feat, camera_encoding=None, *args, **kwargs): + B, N = feat.shape[:2] + feat = feat.reshape(B * N, -1) + feat = self.backbone(feat) + out_t = self.fc_t(feat.float()).reshape(B, N, 3) + if camera_encoding is None: + out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4) + out_fov = self.fc_fov(feat.float()).reshape(B, N, 2) + else: + out_qvec = camera_encoding[..., 3:7] + out_fov = camera_encoding[..., -2:] + pose_enc = torch.cat([out_t, out_qvec, out_fov], dim=-1) + return pose_enc diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/cam_enc.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/cam_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..bf28e701442fa73d89e54b409800908c138a93d8 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/cam_enc.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + +from depth_anything_3.model.utils.attention import Mlp +from depth_anything_3.model.utils.block import Block +from depth_anything_3.model.utils.transform import extri_intri_to_pose_encoding +from depth_anything_3.utils.geometry import affine_inverse + + +class CameraEnc(nn.Module): + """ + CameraHead predicts camera parameters from token representations using iterative refinement. + + It applies a series of transformer blocks (the "trunk") to dedicated camera tokens. 
+ """ + + def __init__( + self, + dim_out: int = 1024, + dim_in: int = 9, + trunk_depth: int = 4, + target_dim: int = 9, + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + **kwargs, + ): + super().__init__() + self.target_dim = target_dim + self.trunk_depth = trunk_depth + self.trunk = nn.Sequential( + *[ + Block( + dim=dim_out, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + init_values=init_values, + ) + for _ in range(trunk_depth) + ] + ) + self.token_norm = nn.LayerNorm(dim_out) + self.trunk_norm = nn.LayerNorm(dim_out) + self.pose_branch = Mlp( + in_features=dim_in, + hidden_features=dim_out // 2, + out_features=dim_out, + drop=0, + ) + + def forward( + self, + ext, + ixt, + image_size, + ) -> tuple: + c2ws = affine_inverse(ext) + pose_encoding = extri_intri_to_pose_encoding( + c2ws, + ixt, + image_size, + ) + pose_tokens = self.pose_branch(pose_encoding) + pose_tokens = self.token_norm(pose_tokens) + pose_tokens = self.trunk(pose_tokens) + pose_tokens = self.trunk_norm(pose_tokens) + return pose_tokens diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/da3.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/da3.py new file mode 100644 index 0000000000000000000000000000000000000000..d0934d8cd84fa22bf1607fcc59e8d03e6aa22b61 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/da3.py @@ -0,0 +1,377 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.nn as nn +from addict import Dict +from omegaconf import DictConfig, OmegaConf + +from depth_anything_3.cfg import create_object +from depth_anything_3.model.utils.transform import pose_encoding_to_extri_intri +from depth_anything_3.utils.alignment import ( + apply_metric_scaling, + compute_alignment_mask, + compute_sky_mask, + least_squares_scale_scalar, + sample_tensor_for_quantile, + set_sky_regions_to_max_depth, +) +from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, map_pdf_to_opacity + + +def _wrap_cfg(cfg_obj): + return OmegaConf.create(cfg_obj) + + +class DepthAnything3Net(nn.Module): + """ + Depth Anything 3 network for depth estimation and camera pose estimation. 
+ + This network consists of: + - Backbone: DinoV2 feature extractor + - Head: DPT or DualDPT for depth prediction + - Optional camera decoders for pose estimation + - Optional GSDPT for 3DGS prediction + + Args: + preset: Configuration preset containing network dimensions and settings + + Returns: + Dictionary containing: + - depth: Predicted depth map (B, H, W) + - depth_conf: Depth confidence map (B, H, W) + - extrinsics: Camera extrinsics (B, N, 4, 4) + - intrinsics: Camera intrinsics (B, N, 3, 3) + - gaussians: 3D Gaussian Splats (world space), type: model.gs_adapter.Gaussians + - aux: Auxiliary features for specified layers + """ + + # Patch size for feature extraction + PATCH_SIZE = 14 + + def __init__(self, net, head, cam_dec=None, cam_enc=None, gs_head=None, gs_adapter=None): + """ + Initialize DepthAnything3Net with given yaml-initialized configuration. + """ + super().__init__() + self.backbone = net if isinstance(net, nn.Module) else create_object(_wrap_cfg(net)) + self.head = head if isinstance(head, nn.Module) else create_object(_wrap_cfg(head)) + self.cam_dec, self.cam_enc = None, None + if cam_dec is not None: + self.cam_dec = ( + cam_dec if isinstance(cam_dec, nn.Module) else create_object(_wrap_cfg(cam_dec)) + ) + self.cam_enc = ( + cam_enc if isinstance(cam_enc, nn.Module) else create_object(_wrap_cfg(cam_enc)) + ) + self.gs_adapter, self.gs_head = None, None + if gs_head is not None and gs_adapter is not None: + self.gs_adapter = ( + gs_adapter + if isinstance(gs_adapter, nn.Module) + else create_object(_wrap_cfg(gs_adapter)) + ) + gs_out_dim = self.gs_adapter.d_in + 1 + if isinstance(gs_head, nn.Module): + assert ( + gs_head.out_dim == gs_out_dim + ), f"gs_head.out_dim should be {gs_out_dim}, got {gs_head.out_dim}" + self.gs_head = gs_head + else: + assert ( + gs_head["output_dim"] == gs_out_dim + ), f"gs_head output_dim should set to {gs_out_dim}, got {gs_head['output_dim']}" + self.gs_head = create_object(_wrap_cfg(gs_head)) + + def forward( + self, + x: torch.Tensor, + extrinsics: torch.Tensor | None = None, + intrinsics: torch.Tensor | None = None, + export_feat_layers: list[int] | None = [], + infer_gs: bool = False, + ) -> Dict[str, torch.Tensor]: + """ + Forward pass through the network. 
+ + Args: + x: Input images (B, N, 3, H, W) + extrinsics: Camera extrinsics (B, N, 4, 4) - unused + intrinsics: Camera intrinsics (B, N, 3, 3) - unused + feat_layers: List of layer indices to extract features from + + Returns: + Dictionary containing predictions and auxiliary features + """ + # Extract features using backbone + if extrinsics is not None: + with torch.autocast(device_type=x.device.type, enabled=False): + cam_token = self.cam_enc(extrinsics, intrinsics, x.shape[-2:]) + else: + cam_token = None + + feats, aux_feats = self.backbone( + x, cam_token=cam_token, export_feat_layers=export_feat_layers + ) + # feats = [[item for item in feat] for feat in feats] + H, W = x.shape[-2], x.shape[-1] + + # Process features through depth head + with torch.autocast(device_type=x.device.type, enabled=False): + output = self._process_depth_head(feats, H, W) + output = self._process_camera_estimation(feats, H, W, output) + if infer_gs: + output = self._process_gs_head(feats, H, W, output, x, extrinsics, intrinsics) + + # Extract auxiliary features if requested + output.aux = self._extract_auxiliary_features(aux_feats, export_feat_layers, H, W) + + return output + + def _process_depth_head( + self, feats: list[torch.Tensor], H: int, W: int + ) -> Dict[str, torch.Tensor]: + """Process features through the depth prediction head.""" + return self.head(feats, H, W, patch_start_idx=0) + + def _process_camera_estimation( + self, feats: list[torch.Tensor], H: int, W: int, output: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Process camera pose estimation if camera decoder is available.""" + if self.cam_dec is not None: + pose_enc = self.cam_dec(feats[-1][1]) + # Remove ray information as it's not needed for pose estimation + if "ray" in output: + del output.ray + if "ray_conf" in output: + del output.ray_conf + + # Convert pose encoding to extrinsics and intrinsics + c2w, ixt = pose_encoding_to_extri_intri(pose_enc, (H, W)) + output.extrinsics = affine_inverse(c2w) + output.intrinsics = ixt + + return output + + def _process_gs_head( + self, + feats: list[torch.Tensor], + H: int, + W: int, + output: Dict[str, torch.Tensor], + in_images: torch.Tensor, + extrinsics: torch.Tensor | None = None, + intrinsics: torch.Tensor | None = None, + ) -> Dict[str, torch.Tensor]: + """Process 3DGS parameters estimation if 3DGS head is available.""" + if self.gs_head is None or self.gs_adapter is None: + return output + assert output.get("depth", None) is not None, "must provide MV depth for the GS head." + + # The depth is defined in the DA3 model's camera space, + # so even with provided GT camera poses, + # we instead use the predicted camera poses for better alignment. 
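+        # Rough data flow of the block below: the predicted extrinsics/intrinsics and
+        # the per-view depth feed gs_head (camera-space raw Gaussians plus densities),
+        # then gs_adapter lifts them to world-space Gaussians, using GT extrinsics only
+        # to align the pose scale when they are available.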
+ ctx_extr = output.get("extrinsics", None) + ctx_intr = output.get("intrinsics", None) + assert ( + ctx_extr is not None and ctx_intr is not None + ), "must process camera info first if GT is not available" + + gt_extr = extrinsics + # homo the extr if needed + ctx_extr = as_homogeneous(ctx_extr) + if gt_extr is not None: + gt_extr = as_homogeneous(gt_extr) + + # forward through the gs_dpt head to get 'camera space' parameters + gs_outs = self.gs_head( + feats=feats, + H=H, + W=W, + patch_start_idx=0, + images=in_images, + ) + raw_gaussians = gs_outs.raw_gs + densities = gs_outs.raw_gs_conf + + # convert to 'world space' 3DGS parameters; ready to export and render + # gt_extr could be None, and will be used to align the pose scale if available + gs_world = self.gs_adapter( + extrinsics=ctx_extr, + intrinsics=ctx_intr, + depths=output.depth, + opacities=map_pdf_to_opacity(densities), + raw_gaussians=raw_gaussians, + image_shape=(H, W), + gt_extrinsics=gt_extr, + ) + output.gaussians = gs_world + + return output + + def _extract_auxiliary_features( + self, feats: list[torch.Tensor], feat_layers: list[int], H: int, W: int + ) -> Dict[str, torch.Tensor]: + """Extract auxiliary features from specified layers.""" + aux_features = Dict() + assert len(feats) == len(feat_layers) + for feat, feat_layer in zip(feats, feat_layers): + # Reshape features to spatial dimensions + feat_reshaped = feat.reshape( + [ + feat.shape[0], + feat.shape[1], + H // self.PATCH_SIZE, + W // self.PATCH_SIZE, + feat.shape[-1], + ] + ) + aux_features[f"feat_layer_{feat_layer}"] = feat_reshaped + + return aux_features + + +class NestedDepthAnything3Net(nn.Module): + """ + Nested Depth Anything 3 network with metric scaling capabilities. + + This network combines two DepthAnything3Net branches: + - Main branch: Standard depth estimation + - Metric branch: Metric depth estimation for scaling alignment + + The network performs depth alignment using least squares scaling + and handles sky region masking for improved depth estimation. + + Args: + preset: Configuration for the main depth estimation branch + second_preset: Configuration for the metric depth branch + """ + + def __init__(self, anyview: DictConfig, metric: DictConfig): + """ + Initialize NestedDepthAnything3Net with two branches. + + Args: + preset: Configuration for main depth estimation branch + second_preset: Configuration for metric depth branch + """ + super().__init__() + self.da3 = create_object(anyview) + self.da3_metric = create_object(metric) + + def forward( + self, + x: torch.Tensor, + extrinsics: torch.Tensor | None = None, + intrinsics: torch.Tensor | None = None, + export_feat_layers: list[int] | None = [], + infer_gs: bool = False, + ) -> Dict[str, torch.Tensor]: + """ + Forward pass through both branches with metric scaling alignment. 
+ + Args: + x: Input images (B, N, 3, H, W) + extrinsics: Camera extrinsics (B, N, 4, 4) - unused + intrinsics: Camera intrinsics (B, N, 3, 3) - unused + feat_layers: List of layer indices to extract features from + metric_feat: Whether to use metric features (unused) + + Returns: + Dictionary containing aligned depth predictions and camera parameters + """ + # Get predictions from both branches + output = self.da3( + x, extrinsics, intrinsics, export_feat_layers=export_feat_layers, infer_gs=infer_gs + ) + metric_output = self.da3_metric(x, infer_gs=infer_gs) + + # Apply metric scaling and alignment + output = self._apply_metric_scaling(output, metric_output) + output = self._apply_depth_alignment(output, metric_output) + output = self._handle_sky_regions(output, metric_output) + + return output + + def _apply_metric_scaling( + self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Apply metric scaling to the metric depth output.""" + # Scale metric depth based on camera intrinsics + metric_output.depth = apply_metric_scaling( + metric_output.depth, + output.intrinsics, + ) + return output + + def _apply_depth_alignment( + self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Apply depth alignment using least squares scaling.""" + # Compute non-sky mask + non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3) + + # Ensure we have enough non-sky pixels + assert non_sky_mask.sum() > 10, "Insufficient non-sky pixels for alignment" + + # Sample depth confidence for quantile computation + depth_conf_ns = output.depth_conf[non_sky_mask] + depth_conf_sampled = sample_tensor_for_quantile(depth_conf_ns, max_samples=100000) + median_conf = torch.quantile(depth_conf_sampled, 0.5) + + # Compute alignment mask + align_mask = compute_alignment_mask( + output.depth_conf, non_sky_mask, output.depth, metric_output.depth, median_conf + ) + + # Compute scale factor using least squares + valid_depth = output.depth[align_mask] + valid_metric_depth = metric_output.depth[align_mask] + scale_factor = least_squares_scale_scalar(valid_metric_depth, valid_depth) + + # Apply scaling to depth and extrinsics + output.depth *= scale_factor + output.extrinsics[:, :, :3, 3] *= scale_factor + output.is_metric = 1 + output.scale_factor = scale_factor.item() + + return output + + def _handle_sky_regions( + self, + output: Dict[str, torch.Tensor], + metric_output: Dict[str, torch.Tensor], + sky_depth_def: float = 200.0, + ) -> Dict[str, torch.Tensor]: + """Handle sky regions by setting them to maximum depth.""" + non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3) + + # Compute maximum depth for non-sky regions + # Use sampling to safely compute quantile on large tensors + non_sky_depth = output.depth[non_sky_mask] + if non_sky_depth.numel() > 100000: + idx = torch.randint(0, non_sky_depth.numel(), (100000,), device=non_sky_depth.device) + sampled_depth = non_sky_depth[idx] + else: + sampled_depth = non_sky_depth + non_sky_max = min(torch.quantile(sampled_depth, 0.99), sky_depth_def) + + # Set sky regions to maximum depth and high confidence + output.depth, output.depth_conf = set_sky_regions_to_max_depth( + output.depth, output.depth_conf, non_sky_mask, max_depth=non_sky_max + ) + + return output diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/dinov2.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/dinov2.py new file mode 100644 index 
0000000000000000000000000000000000000000..ee0d88bdd6d33edda6c0fb237ee0c77ffc3f6034 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/dinov2.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + + +from typing import List +import torch.nn as nn + +from depth_anything_3.model.dinov2.vision_transformer import ( + vit_base, + vit_giant2, + vit_large, + vit_small, +) + + +class DinoV2(nn.Module): + def __init__( + self, + name: str, + out_layers: List[int], + alt_start: int = -1, + qknorm_start: int = -1, + rope_start: int = -1, + cat_token: bool = True, + **kwargs, + ): + super().__init__() + assert name in {"vits", "vitb", "vitl", "vitg"} + self.name = name + self.out_layers = out_layers + self.alt_start = alt_start + self.qknorm_start = qknorm_start + self.rope_start = rope_start + self.cat_token = cat_token + encoder_map = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2, + } + encoder_fn = encoder_map[self.name] + ffn_layer = "swiglufused" if self.name == "vitg" else "mlp" + self.pretrained = encoder_fn( + img_size=518, + patch_size=14, + ffn_layer=ffn_layer, + alt_start=alt_start, + qknorm_start=qknorm_start, + rope_start=rope_start, + cat_token=cat_token, + ) + + def forward(self, x, **kwargs): + return self.pretrained.get_intermediate_layers( + x, + self.out_layers, + **kwargs, + ) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/__init__.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97dfba90c12b2ffc5f0f1f823b6384c9b4cd6fa2 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# from .attention import MemEffAttention +from .block import Block +from .layer_scale import LayerScale +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .rope import PositionGetter, RotaryPositionEmbedding2D +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused + +__all__ = [ + Mlp, + PatchEmbed, + SwiGLUFFN, + SwiGLUFFNFused, + Block, + # MemEffAttention, + LayerScale, + PositionGetter, + RotaryPositionEmbedding2D, +] diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/attention.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..096b9d41ddc95b9c4652597b18f53aee31a573b6 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/attention.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
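+# Shape sketch for the Attention module below (sizes are illustrative): tokens are
+# projected to a fused qkv, optionally passed through per-head QK LayerNorm and 2D
+# RoPE, then attended with F.scaled_dot_product_attention when fused_attn=True.
+#
+#   attn = Attention(dim=768, num_heads=12, qk_norm=True)
+#   x = torch.randn(2, 196, 768)          # (B, N, C)
+#   y = attn(x)                           # (2, 196, 768); no rope/positions used here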
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import torch.nn.functional as F +from torch import Tensor, nn + +logger = logging.getLogger("dinov2") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] + q, k = self.q_norm(q), self.k_norm(k) + if self.rope is not None and pos is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=( + (attn_mask)[:, None].repeat(1, self.num_heads, 1, 1) + if attn_mask is not None + else None + ), + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def _forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/block.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..731519b68b16c8936b765dd1620fbb9c81087f96 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/block.py @@ -0,0 +1,143 @@ +# flake8: noqa: F821 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
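+# Shape sketch for the transformer Block below (sizes are illustrative): a pre-norm
+# attention residual and a pre-norm MLP residual, each optionally scaled by LayerScale
+# and dropped via stochastic depth during training.
+#
+#   blk = Block(dim=768, num_heads=12, init_values=1e-5)
+#   y = blk(torch.randn(2, 196, 768))     # same (B, N, C) shape out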
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, Optional +import torch +from torch import Tensor, nn + +from .attention import Attention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + +logger = logging.getLogger("dinov2") +XFORMERS_AVAILABLE = True + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + rope=None, + ln_eps: float = 1e-6, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim, eps=ln_eps) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + rope=rope, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim, eps=ln_eps) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None, attn_mask=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + pos=pos, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, pos=pos, attn_mask=attn_mask)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, + pos: Optional[Tensor] = None, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + if pos is not None: + # if necessary, apply rope to the subset + pos = pos[brange] + residual = residual_func(x_subset, pos=pos) + else: + 
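+        # No RoPE positions were provided, so the residual branch is applied to the
+        # sampled subset directly.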
residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/drop_path.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1c2cc94e969711f1eb9f62093b79a0139b9bfb1e --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/layer_scale.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..898ee12d8b4b65d30d8c041588a8277a8f13d4f2 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/layer_scale.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
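+# Minimal sketch of LayerScale below (sizes are illustrative): a learnable per-channel
+# gain, initialised to a small constant so each residual branch starts close to zero.
+#
+#   ls = LayerScale(dim=768, init_values=1e-5)
+#   y = ls(torch.randn(2, 196, 768))      # elementwise x * gamma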
+ +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 # noqa: E501 + +from typing import Union +import torch +from torch import Tensor, nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.dim = dim + self.inplace = inplace + self.init_values = init_values + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + def extra_repr(self) -> str: + return f"{self.dim}, init_values={self.init_values}, inplace={self.inplace}" diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/mlp.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..78ad0d8897ddb77e95d2e188a579f6e1d21e3fb5 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/patch_embed.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..64bf6be8994fe52b2fdf1753fdb9f7a691e14980 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/patch_embed.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union +import torch.nn as nn +from torch import Tensor + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. 
+ embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert ( + H % patch_H == 0 + ), f"Input image height {H} is not a multiple of patch height {patch_H}" + assert ( + W % patch_W == 0 + ), f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/rope.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..f75ba37c160cd806a32576e2a1704b3352dec7af --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/rope.py @@ -0,0 +1,200 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +from typing import Dict, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. + + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__( + self, batch_size: int, height: int, width: int, device: torch.device + ) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. 
+ height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. + + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. + """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. + frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. + """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, + tokens: torch.Tensor, + positions: torch.Tensor, + cos_comp: torch.Tensor, + sin_comp: torch.Tensor, + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. 
+ cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. + """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. + """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert ( + positions.ndim == 3 and positions.shape[-1] == 2 + ), "Positions must have shape (batch_size, n_tokens, 2)" + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components( + feature_dim, max_position, tokens.device, tokens.dtype + ) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope( + vertical_features, positions[..., 0], cos_comp, sin_comp + ) + horizontal_features = self._apply_1d_rope( + horizontal_features, positions[..., 1], cos_comp, sin_comp + ) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/swiglu_ffn.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..c8f58e5b265f74c82a4c40adc3ee33a965e503cf --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
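+# Minimal sketch of the SwiGLU FFN below (sizes are illustrative): w12 produces a gate
+# and a value; the output is w3(silu(gate) * value). The fused variant shrinks
+# hidden_features to roughly 2/3 of the given value, rounded up to a multiple of 8.
+#
+#   ffn = SwiGLUFFN(in_features=768, hidden_features=3072)
+#   y = ffn(torch.randn(2, 196, 768))     # (2, 196, 768)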
+ +from typing import Callable, Optional +import torch.nn.functional as F +from torch import Tensor, nn + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/vision_transformer.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..7fdc76521ac042ddb485e79f4fe64ac93bbfa1c8 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dinov2/vision_transformer.py @@ -0,0 +1,437 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
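+# Shape sketch for the ViT below (the layer indices and sizes are illustrative):
+# inputs are (B, S, 3, H, W) with H and W divisible by the patch size; positional
+# embeddings are interpolated to the actual grid, and get_intermediate_layers returns,
+# per requested layer, a (patch_tokens, camera_token) pair. With cat_token=True the
+# channel width is doubled (unnormalised and layer-normalised features concatenated).
+#
+#   vit = vit_large(img_size=518, patch_size=14)
+#   x = torch.randn(1, 2, 3, 238, 322)                    # H, W multiples of 14
+#   feats, aux = vit.get_intermediate_layers(x, n=[4, 11, 17, 23])
+#   patch_tokens, cam_token = feats[-1]                   # (1, 2, 17*23, 2048), (1, 2, 2048)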
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import math +from typing import Callable, List, Sequence, Tuple, Union +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint +from einops import rearrange + +from depth_anything_3.utils.logger import logger + +from .layers import LayerScale # noqa: F401 +from .layers import Mlp # noqa: F401 +from .layers import ( # noqa: F401 + Block, + PatchEmbed, + PositionGetter, + RotaryPositionEmbedding2D, + SwiGLUFFNFused, +) + +# logger = logging.getLogger("dinov2") + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=1.0, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + alt_start=-1, + qknorm_start=-1, + rope_start=-1, + rope_freq=100, + plus_cam_token=False, + cat_token=True, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating + positional embeddings + 
interpolate_offset: (float) work-around offset to apply when interpolating + positional embeddings + block_prompt: (bool) whether to add ray embeddings to the block input + """ + super().__init__() + self.patch_start_idx = 1 + norm_layer = nn.LayerNorm + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.alt_start = alt_start + self.qknorm_start = qknorm_start + self.rope_start = rope_start + self.cat_token = cat_token + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim + ) + num_patches = self.patch_embed.num_patches + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + if self.alt_start != -1: + self.camera_token = nn.Parameter(torch.randn(1, 2, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) + if num_register_tokens + else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + if self.rope_start != -1: + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + self.position_getter = PositionGetter() if self.rope is not None else None + else: + self.rope = None + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + qk_norm=i >= qknorm_start if qknorm_start != -1 else False, + rope=self.rope if i >= rope_start and rope_start != -1 else None, + ) + for i in range(depth) + ] + self.blocks = nn.ModuleList(blocks_list) + self.norm = norm_layer(embed_dim) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the + # interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using + # both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + 
else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_cls_token(self, B, S): + cls_token = self.cls_token.expand(B, S, -1) + cls_token = cls_token.reshape(B * S, -1, self.embed_dim) + return cls_token + + def prepare_tokens_with_masks(self, x, masks=None, cls_token=None, **kwargs): + B, S, nc, w, h = x.shape + x = rearrange(x, "b s c h w -> (b s) c h w") + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + cls_token = self.prepare_cls_token(B, S) + x = torch.cat((cls_token, x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S) + return x + + def _prepare_rope(self, B, S, H, W, device): + pos = None + pos_nodiff = None + if self.rope is not None: + pos = self.position_getter( + B * S, H // self.patch_size, W // self.patch_size, device=device + ) + pos = rearrange(pos, "(b s) n c -> b s n c", b=B) + pos_nodiff = torch.zeros_like(pos).to(pos.dtype) + if self.patch_start_idx > 0: + pos = pos + 1 + pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(device).to(pos.dtype) + pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B) + pos = torch.cat([pos_special, pos], dim=2) + pos_nodiff = pos_nodiff + 1 + pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2) + return pos, pos_nodiff + + def _get_intermediate_layers_not_chunked(self, x, n=1, export_feat_layers=[], **kwargs): + B, S, _, H, W = x.shape + x = self.prepare_tokens_with_masks(x) + output, total_block_len, aux_output = [], len(self.blocks), [] + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device) + + for i, blk in enumerate(self.blocks): + if i < self.rope_start or self.rope is None: + g_pos, l_pos = None, None + else: + g_pos = pos_nodiff + l_pos = pos + if self.alt_start != -1 and i == self.alt_start: + if kwargs.get("cam_token", None) is not None: + logger.info("Using camera conditions provided by the user") + cam_token = kwargs.get("cam_token") + else: + ref_token = self.camera_token[:, :1].expand(B, -1, -1) + src_token = self.camera_token[:, 1:].expand(B, S - 1, -1) + cam_token = torch.cat([ref_token, src_token], dim=1) + x[:, :, 0] = cam_token + + if self.alt_start != -1 and i >= self.alt_start and i % 2 == 1: + x = self.process_attention( + x, blk, "global", pos=g_pos, attn_mask=kwargs.get("attn_mask", None) + ) + else: + x = self.process_attention(x, blk, "local", pos=l_pos) + local_x = x + + if i in blocks_to_take: + out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x + output.append((out_x[:, :, 0], out_x)) + if i in export_feat_layers: + aux_output.append(x) + return output, aux_output + + def process_attention(self, x, block, attn_type="global", pos=None, attn_mask=None): + b, s, n = x.shape[:3] + if attn_type == "local": + x = rearrange(x, "b s n c -> (b s) n 
c") + if pos is not None: + pos = rearrange(pos, "b s n c -> (b s) n c") + elif attn_type == "global": + x = rearrange(x, "b s n c -> b (s n) c") + if pos is not None: + pos = rearrange(pos, "b s n c -> b (s n) c") + else: + raise ValueError(f"Invalid attention type: {attn_type}") + + x = block(x, pos=pos, attn_mask=attn_mask) + + if attn_type == "local": + x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s) + elif attn_type == "global": + x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s) + return x + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + export_feat_layers: List[int] = [], + **kwargs, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + outputs, aux_outputs = self._get_intermediate_layers_not_chunked( + x, n, export_feat_layers=export_feat_layers, **kwargs + ) + camera_tokens = [out[0] for out in outputs] + if outputs[0][1].shape[-1] == self.embed_dim: + outputs = [self.norm(out[1]) for out in outputs] + elif outputs[0][1].shape[-1] == (self.embed_dim * 2): + outputs = [ + torch.cat( + [out[1][..., : self.embed_dim], self.norm(out[1][..., self.embed_dim :])], + dim=-1, + ) + for out in outputs + ] + else: + raise ValueError(f"Invalid output shape: {outputs[0][1].shape}") + aux_outputs = [self.norm(out) for out in aux_outputs] + outputs = [out[..., 1 + self.num_register_tokens :, :] for out in outputs] + aux_outputs = [out[..., 1 + self.num_register_tokens :, :] for out in aux_outputs] + return tuple(zip(outputs, camera_tokens)), aux_outputs + + +def vit_small(patch_size=16, num_register_tokens=0, depth=12, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=depth, + num_heads=6, + mlp_ratio=4, + # block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, depth=12, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=depth, + num_heads=12, + mlp_ratio=4, + # block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, depth=24, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=depth, + num_heads=16, + mlp_ratio=4, + # block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, depth=40, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=depth, + num_heads=24, + mlp_ratio=4, + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dpt.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..337a8a964a7f96fae997c8b12e3ae13e99dbaa58 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dpt.py @@ -0,0 +1,458 @@ +# flake8: noqa E501 +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict as TyDict +from typing import List, Sequence, Tuple +import torch +import torch.nn as nn +from addict import Dict +from einops import rearrange + +from depth_anything_3.model.utils.head_utils import ( + Permute, + create_uv_grid, + custom_interpolate, + position_grid_to_embed, +) + + +class DPT(nn.Module): + """ + DPT for dense prediction (main head + optional sky head, sky always 1 channel). + + Returns: + - Main head: + * If output_dim>1: { head_name, f"{head_name}_conf" } + * If output_dim==1: { head_name } + - Sky head (if use_sky_head=True): { sky_name } # [B, S, 1, H/down_ratio, W/down_ratio] + """ + + def __init__( + self, + dim_in: int, + *, + patch_size: int = 14, + output_dim: int = 1, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = False, + down_ratio: int = 1, + head_name: str = "depth", + # ---- sky head (fixed 1 channel) ---- + use_sky_head: bool = True, + sky_name: str = "sky", + sky_activation: str = "relu", # 'sigmoid' / 'relu' / 'linear' + use_ln_for_heads: bool = False, # If needed, apply LayerNorm on intermediate features of both heads + norm_type: str = "idt", # use to match legacy GS-DPT head, "idt" / "layer" + fusion_block_inplace: bool = False, + ) -> None: + super().__init__() + + # -------------------- configuration -------------------- + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + + # Names + self.head_main = head_name + self.sky_name = sky_name + + # Main head: output dimension and confidence switch + self.out_dim = output_dim + self.has_conf = output_dim > 1 + + # Sky head parameters (always 1 channel) + self.use_sky_head = use_sky_head + self.sky_activation = sky_activation + + # Fixed 4 intermediate outputs + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + + # -------------------- token pre-norm + per-stage projection -------------------- + if norm_type == "layer": + self.norm = nn.LayerNorm(dim_in) + elif norm_type == "idt": + self.norm = nn.Identity() + else: + raise Exception(f"Unknown norm_type {norm_type}, should be 'layer' or 'idt'.") + self.projects = nn.ModuleList( + [nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels] + ) + + # -------------------- Spatial re-size (align to common scale before fusion) -------------------- + # Design consistent with original: relative to patch grid (x4, x2, x1, /2) + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1), + ] + ) + + # -------------------- scratch: stage adapters + main fusion chain -------------------- + self.scratch = _make_scratch(list(out_channels), features, expand=False) + + # Main fusion chain + 
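+        # refinenet4 seeds the top of the pyramid (built with has_residual=False);
+        # refinenet3..1 then each fuse one finer stage into the running feature map.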
self.scratch.refinenet1 = _make_fusion_block(features, inplace=fusion_block_inplace) + self.scratch.refinenet2 = _make_fusion_block(features, inplace=fusion_block_inplace) + self.scratch.refinenet3 = _make_fusion_block(features, inplace=fusion_block_inplace) + self.scratch.refinenet4 = _make_fusion_block( + features, has_residual=False, inplace=fusion_block_inplace + ) + + # Heads (shared neck1; then split into two heads) + head_features_1 = features + head_features_2 = 32 + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + + ln_seq = ( + [Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))] + if use_ln_for_heads + else [] + ) + + # Main head + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + *ln_seq, + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0), + ) + + # Sky head (fixed 1 channel) + if self.use_sky_head: + self.scratch.sky_output_conv2 = nn.Sequential( + nn.Conv2d( + head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1 + ), + *ln_seq, + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + ) + + # ------------------------------------------------------------------------- + # Public forward (supports frame chunking to save memory) + # ------------------------------------------------------------------------- + def forward( + self, + feats: List[torch.Tensor], + H: int, + W: int, + patch_start_idx: int, + chunk_size: int = 8, + **kwargs, + ) -> Dict: + """ + Args: + feats: List of 4 entries, each entry is a tensor like [B, S, T, C] (or the 0th element of tuple/list is that tensor). + H, W: Original image dimensions + patch_start_idx: Starting index of patch tokens in sequence (for cropping non-patch tokens) + chunk_size: Chunk size along time dimension S + + Returns: + Dict[str, Tensor] + """ + B, S, N, C = feats[0][0].shape + feats = [feat[0].reshape(B * S, N, C) for feat in feats] + + # update image info, used by the GS-DPT head + extra_kwargs = {} + if "images" in kwargs: + extra_kwargs.update({"images": rearrange(kwargs["images"], "B S ... 
-> (B S) ...")}) + + if chunk_size is None or chunk_size >= S: + out_dict = self._forward_impl(feats, H, W, patch_start_idx, **extra_kwargs) + out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()} + return Dict(out_dict) + + out_dicts: List[TyDict[str, torch.Tensor]] = [] + for s0 in range(0, S, chunk_size): + s1 = min(s0 + chunk_size, S) + kw = {} + if "images" in extra_kwargs: + kw.update({"images": extra_kwargs["images"][s0:s1]}) + out_dicts.append( + self._forward_impl([f[s0:s1] for f in feats], H, W, patch_start_idx, **kw) + ) + out_dict = {k: torch.cat([od[k] for od in out_dicts], dim=0) for k in out_dicts[0].keys()} + out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()} + return Dict(out_dict) + + # ------------------------------------------------------------------------- + # Internal forward (single chunk) + # ------------------------------------------------------------------------- + def _forward_impl( + self, + feats: List[torch.Tensor], + H: int, + W: int, + patch_start_idx: int, + ) -> TyDict[str, torch.Tensor]: + B, _, C = feats[0].shape + ph, pw = H // self.patch_size, W // self.patch_size + resized_feats = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats[take_idx][:, patch_start_idx:] # [B*S, N_patch, C] + x = self.norm(x) + # permute -> contiguous before reshape to keep conv input contiguous + x = x.permute(0, 2, 1).contiguous().reshape(B, C, ph, pw) # [B*S, C, ph, pw] + + x = self.projects[stage_idx](x) + if self.pos_embed: + x = self._add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) # Align scale + resized_feats.append(x) + + # 2) Fusion pyramid (main branch only) + fused = self._fuse(resized_feats) + + # 3) Upsample to target resolution, optionally add position encoding again + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + fused = self.scratch.output_conv1(fused) + fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True) + if self.pos_embed: + fused = self._add_pos_embed(fused, W, H) + + # 4) Shared neck1 + feat = fused + + # 5) Main head: logits -> activation + main_logits = self.scratch.output_conv2(feat) + outs: TyDict[str, torch.Tensor] = {} + if self.has_conf: + fmap = main_logits.permute(0, 2, 3, 1) + pred = self._apply_activation_single(fmap[..., :-1], self.activation) + conf = self._apply_activation_single(fmap[..., -1], self.conf_activation) + outs[self.head_main] = pred.squeeze(1) + outs[f"{self.head_main}_conf"] = conf.squeeze(1) + else: + outs[self.head_main] = self._apply_activation_single( + main_logits, self.activation + ).squeeze(1) + + # 6) Sky head (fixed 1 channel) + if self.use_sky_head: + sky_logits = self.scratch.sky_output_conv2(feat) + outs[self.sky_name] = self._apply_sky_activation(sky_logits).squeeze(1) + + return outs + + # ------------------------------------------------------------------------- + # Subroutines + # ------------------------------------------------------------------------- + def _fuse(self, feats: List[torch.Tensor]) -> torch.Tensor: + """ + 4-layer top-down fusion, returns finest scale features (after fusion, before neck1). 
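+ Args:
+ feats: resized stage features [l1, l2, l3, l4], ordered fine-to-coarse (the output of resize_layers).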
+ """ + l1, l2, l3, l4 = feats + + l1_rn = self.scratch.layer1_rn(l1) + l2_rn = self.scratch.layer2_rn(l2) + l3_rn = self.scratch.layer3_rn(l3) + l4_rn = self.scratch.layer4_rn(l4) + + # 4 -> 3 -> 2 -> 1 + out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:]) + out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:]) + out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:]) + out = self.scratch.refinenet1(out, l1_rn) + return out + + def _apply_activation_single( + self, x: torch.Tensor, activation: str = "linear" + ) -> torch.Tensor: + """ + Apply activation to single channel output, maintaining semantic consistency with value branch in multi-channel case. + Supports: exp / relu / sigmoid / softplus / tanh / linear / expp1 + """ + act = activation.lower() if isinstance(activation, str) else activation + if act == "exp": + return torch.exp(x) + if act == "expp1": + return torch.exp(x) + 1 + if act == "expm1": + return torch.expm1(x) + if act == "relu": + return torch.relu(x) + if act == "sigmoid": + return torch.sigmoid(x) + if act == "softplus": + return torch.nn.functional.softplus(x) + if act == "tanh": + return torch.tanh(x) + # Default linear + return x + + def _apply_sky_activation(self, x: torch.Tensor) -> torch.Tensor: + """ + Sky head activation (fixed 1 channel): + * 'sigmoid' -> Sigmoid probability map + * 'relu' -> ReLU positive domain output + * 'linear' -> Original value (logits) + """ + act = ( + self.sky_activation.lower() + if isinstance(self.sky_activation, str) + else self.sky_activation + ) + if act == "sigmoid": + return torch.sigmoid(x) + if act == "relu": + return torch.relu(x) + # 'linear' + return x + + def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """Simple UV position encoding directly added to feature map.""" + pw, ph = x.shape[-1], x.shape[-2] + pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pe = position_grid_to_embed(pe, x.shape[1]) * ratio + pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pe + + +# ----------------------------------------------------------------------------- +# Building blocks (preserved, consistent with original) +# ----------------------------------------------------------------------------- +def _make_fusion_block( + features: int, + size: Tuple[int, int] = None, + has_residual: bool = True, + groups: int = 1, + inplace: bool = False, +) -> nn.Module: + return FeatureFusionBlock( + features=features, + activation=nn.ReLU(inplace=inplace), + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=size, + has_residual=has_residual, + groups=groups, + ) + + +def _make_scratch( + in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False +) -> nn.Module: + scratch = nn.Module() + # Optional expansion by stage + c1 = out_shape + c2 = out_shape * (2 if expand else 1) + c3 = out_shape * (4 if expand else 1) + c4 = out_shape * (8 if expand else 1) + + scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups) + scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups) + return scratch + + +class ResidualConvUnit(nn.Module): + """Lightweight residual convolution block for fusion""" + + def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> 
None: + super().__init__() + self.bn = bn + self.groups = groups + self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups) + self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups) + self.norm1 = None + self.norm2 = None + self.activation = activation + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override] + out = self.activation(x) + out = self.conv1(out) + if self.norm1 is not None: + out = self.norm1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.norm2 is not None: + out = self.norm2(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Top-down fusion block: (optional) residual merge + upsampling + 1x1 contraction""" + + def __init__( + self, + features: int, + activation: nn.Module, + deconv: bool = False, + bn: bool = False, + expand: bool = False, + align_corners: bool = True, + size: Tuple[int, int] = None, + has_residual: bool = True, + groups: int = 1, + ) -> None: + super().__init__() + self.align_corners = align_corners + self.size = size + self.has_residual = has_residual + + self.resConfUnit1 = ( + ResidualConvUnit(features, activation, bn, groups=groups) if has_residual else None + ) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups) + + out_features = (features // 2) if expand else features + self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor: # type: ignore[override] + """ + xs: + - xs[0]: Top branch input + - xs[1]: Lateral input (can do residual addition with top branch) + """ + y = xs[0] + if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None: + y = self.skip_add.add(y, self.resConfUnit1(xs[1])) + + y = self.resConfUnit2(y) + + # Upsampling + if (size is None) and (self.size is None): + up_kwargs = {"scale_factor": 2} + elif size is None: + up_kwargs = {"size": self.size} + else: + up_kwargs = {"size": size} + + y = custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners) + y = self.out_conv(y) + return y diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/dualdpt.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/dualdpt.py new file mode 100644 index 0000000000000000000000000000000000000000..e84c5cf55a4698b3df8bff1eda7306778dd63a52 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/dualdpt.py @@ -0,0 +1,488 @@ +# flake8: noqa E501 +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
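+ # Illustrative usage sketch (dim_in, H/W and patch_start_idx below are hypothetical
+ # values, not taken from a config in this repo):
+ #   head = DualDPT(dim_in=1024, patch_size=14, output_dim=2, head_names=("depth", "ray"))
+ #   preds = head(feats, H=476, W=476, patch_start_idx=5)  # feats: 4 entries of (tokens [B, S, N, C], ...)
+ #   depth, depth_conf = preds["depth"], preds["depth_conf"]
+ #   ray, ray_conf = preds["ray"], preds["ray_conf"]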
+ +from typing import List, Sequence, Tuple +import torch +import torch.nn as nn +from addict import Dict + +from depth_anything_3.model.dpt import _make_fusion_block, _make_scratch +from depth_anything_3.model.utils.head_utils import ( + Permute, + create_uv_grid, + custom_interpolate, + position_grid_to_embed, +) + + +class DualDPT(nn.Module): + """ + Dual-head DPT for dense prediction with an always-on auxiliary head. + + Architectural notes: + - Sky/object branches are removed. + - `intermediate_layer_idx` is fixed to (0, 1, 2, 3). + - Auxiliary head has its **own** fusion blocks (no fusion_inplace / no sharing). + - Auxiliary head is internally multi-level; **only the final level** is returned. + - Returns a **dict** with keys from `head_names`, e.g.: + { main_name, f"{main_name}_conf", aux_name, f"{aux_name}_conf" } + - `feature_only` is fixed to False. + """ + + def __init__( + self, + dim_in: int, + *, + patch_size: int = 14, + output_dim: int = 2, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = True, + down_ratio: int = 1, + aux_pyramid_levels: int = 4, + aux_out1_conv_num: int = 5, + head_names: Tuple[str, str] = ("depth", "ray"), + ) -> None: + super().__init__() + + # -------------------- configuration -------------------- + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + + self.aux_levels = aux_pyramid_levels + self.aux_out1_conv_num = aux_out1_conv_num + + # names ONLY come from config (no hard-coded strings elsewhere) + self.head_main, self.head_aux = head_names + + # Always expect 4 scales; enforce intermediate idx = (0, 1, 2, 3) + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + + # -------------------- token pre-norm + per-stage projection -------------------- + self.norm = nn.LayerNorm(dim_in) + self.projects = nn.ModuleList( + [nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels] + ) + + # -------------------- spatial re-sizers (align to common scale before fusion) -------------------- + # design: stage strides (x4, x2, x1, /2) relative to patch grid to align to a common pivot scale + self.resize_layers = nn.ModuleList( + [ + nn.ConvTranspose2d( + out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0 + ), + nn.ConvTranspose2d( + out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0 + ), + nn.Identity(), + nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1), + ] + ) + + # -------------------- scratch: stage adapters + fusion (main & aux are separate) -------------------- + self.scratch = _make_scratch(list(out_channels), features, expand=False) + + # Main fusion chain (independent) + self.scratch.refinenet1 = _make_fusion_block(features) + self.scratch.refinenet2 = _make_fusion_block(features) + self.scratch.refinenet3 = _make_fusion_block(features) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False) + + # Primary head neck + head (independent) + head_features_1 = features + head_features_2 = 32 + self.scratch.output_conv1 = nn.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1 + ) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, output_dim, 
kernel_size=1, stride=1, padding=0), + ) + + # Auxiliary fusion chain (completely separate; no sharing, i.e., "fusion_inplace=False") + self.scratch.refinenet1_aux = _make_fusion_block(features) + self.scratch.refinenet2_aux = _make_fusion_block(features) + self.scratch.refinenet3_aux = _make_fusion_block(features) + self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False) + + # Aux pre-head per level (we will only *return final level*) + self.scratch.output_conv1_aux = nn.ModuleList( + [self._make_aux_out1_block(head_features_1) for _ in range(self.aux_levels)] + ) + + # Aux final projection per level + use_ln = True + ln_seq = ( + [Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))] + if use_ln + else [] + ) + self.scratch.output_conv2_aux = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d( + head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1 + ), + *ln_seq, + nn.ReLU(inplace=True), + nn.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0), + ) + for _ in range(self.aux_levels) + ] + ) + + # ------------------------------------------------------------------------- + # Public forward (supports frame chunking for memory) + # ------------------------------------------------------------------------- + + def forward( + self, + feats: List[torch.Tensor], + H: int, + W: int, + patch_start_idx: int, + chunk_size: int = 8, + ) -> Dict[str, torch.Tensor]: + """ + Args: + aggregated_tokens_list: List of 4 tensors [B, S, T, C] from transformer. + images: [B, S, 3, H, W], in [0, 1]. + patch_start_idx: Patch-token start in the token sequence (to drop non-patch tokens). + frames_chunk_size: Optional chunking along S for memory. + + Returns: + Dict[str, Tensor] with keys based on `head_names`, e.g.: + self.head_main, f"{self.head_main}_conf", + self.head_aux, f"{self.head_aux}_conf" + Shapes: + main: [B, S, out_dim, H/down_ratio, W/down_ratio] + main_cf: [B, S, 1, H/down_ratio, W/down_ratio] + aux: [B, S, 7, H/down_ratio, W/down_ratio] + aux_cf: [B, S, 1, H/down_ratio, W/down_ratio] + """ + B, S, N, C = feats[0][0].shape + feats = [feat[0].reshape(B * S, N, C) for feat in feats] + if chunk_size is None or chunk_size >= S: + out_dict = self._forward_impl(feats, H, W, patch_start_idx) + out_dict = {k: v.reshape(B, S, *v.shape[1:]) for k, v in out_dict.items()} + return Dict(out_dict) + out_dicts = [] + for s0 in range(0, S, chunk_size): + s1 = min(s0 + chunk_size, S) + out_dict = self._forward_impl( + [feat[s0:s1] for feat in feats], + H, + W, + patch_start_idx, + ) + out_dicts.append(out_dict) + out_dict = { + k: torch.cat([out_dict[k] for out_dict in out_dicts], dim=0) + for k in out_dicts[0].keys() + } + out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()} + return Dict(out_dict) + + # ------------------------------------------------------------------------- + # Internal forward (single chunk) + # ------------------------------------------------------------------------- + + def _forward_impl( + self, + feats: List[torch.Tensor], + H: int, + W: int, + patch_start_idx: int, + ) -> Dict[str, torch.Tensor]: + B, _, C = feats[0].shape + ph, pw = H // self.patch_size, W // self.patch_size + resized_feats = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats[take_idx][:, patch_start_idx:] + x = self.norm(x) + x = x.permute(0, 2, 1).reshape(B, C, ph, pw) # [B*S, C, ph, pw] + + x = self.projects[stage_idx](x) + if self.pos_embed: + x = self._add_pos_embed(x, W, H) + x = 
self.resize_layers[stage_idx](x) # align scales + resized_feats.append(x) + + # 2) Fuse pyramid (main & aux are completely independent) + fused_main, fused_aux_pyr = self._fuse(resized_feats) + + # 3) Upsample to target resolution and (optional) add pos-embed again + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + fused_main = custom_interpolate( + fused_main, (h_out, w_out), mode="bilinear", align_corners=True + ) + if self.pos_embed: + fused_main = self._add_pos_embed(fused_main, W, H) + + # Primary head: conv1 -> conv2 -> activate + # fused_main = self.scratch.output_conv1(fused_main) + main_logits = self.scratch.output_conv2(fused_main) + fmap = main_logits.permute(0, 2, 3, 1) + main_pred = self._apply_activation_single(fmap[..., :-1], self.activation) + main_conf = self._apply_activation_single(fmap[..., -1], self.conf_activation) + + # Auxiliary head (multi-level inside) -> only last level returned (after activation) + last_aux = fused_aux_pyr[-1] + if self.pos_embed: + last_aux = self._add_pos_embed(last_aux, W, H) + # neck (per-level pre-conv) then final projection (only for last level) + # last_aux = self.scratch.output_conv1_aux[-1](last_aux) + last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux) + fmap_last = last_aux_logits.permute(0, 2, 3, 1) + aux_pred = self._apply_activation_single(fmap_last[..., :-1], "linear") + aux_conf = self._apply_activation_single(fmap_last[..., -1], self.conf_activation) + return { + self.head_main: main_pred.squeeze(-1), + f"{self.head_main}_conf": main_conf, + self.head_aux: aux_pred, + f"{self.head_aux}_conf": aux_conf, + } + + # ------------------------------------------------------------------------- + # Subroutines + # ------------------------------------------------------------------------- + + def _fuse(self, feats: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Feature pyramid fusion. 
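+ Args:
+ feats: the four resized stage features [l1 (finest), ..., l4 (coarsest)]; the main and aux chains consume the same inputs through separate refinenet weights.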
+ Returns: + fused_main: Tensor at finest scale (after refinenet1) + aux_pyr: List of aux tensors at each level (pre out_conv1_aux) + """ + l1, l2, l3, l4 = feats + + l1_rn = self.scratch.layer1_rn(l1) + l2_rn = self.scratch.layer2_rn(l2) + l3_rn = self.scratch.layer3_rn(l3) + l4_rn = self.scratch.layer4_rn(l4) + + # level 4 -> 3 + out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:]) + aux_out = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:]) + aux_list: List[torch.Tensor] = [] + if self.aux_levels >= 4: + aux_list.append(aux_out) + + # level 3 -> 2 + out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:]) + aux_out = self.scratch.refinenet3_aux(aux_out, l3_rn, size=l2_rn.shape[2:]) + if self.aux_levels >= 3: + aux_list.append(aux_out) + + # level 2 -> 1 + out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:]) + aux_out = self.scratch.refinenet2_aux(aux_out, l2_rn, size=l1_rn.shape[2:]) + if self.aux_levels >= 2: + aux_list.append(aux_out) + + # level 1 (final) + out = self.scratch.refinenet1(out, l1_rn) + aux_out = self.scratch.refinenet1_aux(aux_out, l1_rn) + aux_list.append(aux_out) + + out = self.scratch.output_conv1(out) + aux_list = [self.scratch.output_conv1_aux[i](aux) for i, aux in enumerate(aux_list)] + + return out, aux_list + + def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """Simple UV positional embedding added to feature maps.""" + pw, ph = x.shape[-1], x.shape[-2] + pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pe = position_grid_to_embed(pe, x.shape[1]) * ratio + pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1) + return x + pe + + def _make_aux_out1_block(self, in_ch: int) -> nn.Sequential: + """Factory for the aux pre-head stack before the final 1x1 projection.""" + if self.aux_out1_conv_num == 5: + return nn.Sequential( + nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1), + nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1), + nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1), + nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1), + nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1), + ) + if self.aux_out1_conv_num == 3: + return nn.Sequential( + nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1), + nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1), + nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1), + ) + if self.aux_out1_conv_num == 1: + return nn.Sequential(nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1)) + raise ValueError(f"aux_out1_conv_num {self.aux_out1_conv_num} not supported") + + def _apply_activation_single( + self, x: torch.Tensor, activation: str = "linear" + ) -> torch.Tensor: + """ + Apply activation to single channel output, maintaining semantic consistency with value branch in multi-channel case. 
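+ The confidence channels reuse this helper with `conf_activation`.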
+ Supports: exp / relu / sigmoid / softplus / tanh / linear / expp1 + """ + act = activation.lower() if isinstance(activation, str) else activation + if act == "exp": + return torch.exp(x) + if act == "expm1": + return torch.expm1(x) + if act == "expp1": + return torch.exp(x) + 1 + if act == "relu": + return torch.relu(x) + if act == "sigmoid": + return torch.sigmoid(x) + if act == "softplus": + return torch.nn.functional.softplus(x) + if act == "tanh": + return torch.tanh(x) + # Default linear + return x + + +# # ----------------------------------------------------------------------------- +# # Building blocks (tidy) +# # ----------------------------------------------------------------------------- + + +# def _make_fusion_block( +# features: int, +# size: Tuple[int, int] = None, +# has_residual: bool = True, +# groups: int = 1, +# inplace: bool = False, # <- activation uses inplace=True by default; not related to "fusion_inplace" +# ) -> nn.Module: +# return FeatureFusionBlock( +# features=features, +# activation=nn.ReLU(inplace=inplace), +# deconv=False, +# bn=False, +# expand=False, +# align_corners=True, +# size=size, +# has_residual=has_residual, +# groups=groups, +# ) + + +# def _make_scratch( +# in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False +# ) -> nn.Module: +# scratch = nn.Module() +# # optionally expand widths by stage +# c1 = out_shape +# c2 = out_shape * (2 if expand else 1) +# c3 = out_shape * (4 if expand else 1) +# c4 = out_shape * (8 if expand else 1) + +# scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups) +# scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups) +# scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups) +# scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups) +# return scratch + + +# class ResidualConvUnit(nn.Module): +# """Lightweight residual conv block used within fusion.""" + +# def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> None: +# super().__init__() +# self.bn = bn +# self.groups = groups +# self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups) +# self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups) +# self.norm1 = None +# self.norm2 = None +# self.activation = activation +# self.skip_add = nn.quantized.FloatFunctional() + +# def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override] +# out = self.activation(x) +# out = self.conv1(out) +# if self.norm1 is not None: +# out = self.norm1(out) + +# out = self.activation(out) +# out = self.conv2(out) +# if self.norm2 is not None: +# out = self.norm2(out) + +# return self.skip_add.add(out, x) + + +# class FeatureFusionBlock(nn.Module): +# """Top-down fusion block: (optional) residual merge + upsample + 1x1 shrink.""" + +# def __init__( +# self, +# features: int, +# activation: nn.Module, +# deconv: bool = False, +# bn: bool = False, +# expand: bool = False, +# align_corners: bool = True, +# size: Tuple[int, int] = None, +# has_residual: bool = True, +# groups: int = 1, +# ) -> None: +# super().__init__() +# self.align_corners = align_corners +# self.size = size +# self.has_residual = has_residual + +# self.resConfUnit1 = ( +# ResidualConvUnit(features, activation, bn, groups=groups) if has_residual else None +# ) +# self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups) + +# out_features = (features // 2) if expand else features +# 
self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups) +# self.skip_add = nn.quantized.FloatFunctional() + +# def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor: # type: ignore[override] +# """ +# xs: +# - xs[0]: top input +# - xs[1]: (optional) lateral (to be added with residual) +# """ +# y = xs[0] +# if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None: +# y = self.skip_add.add(y, self.resConfUnit1(xs[1])) + +# y = self.resConfUnit2(y) + +# # upsample +# if (size is None) and (self.size is None): +# up_kwargs = {"scale_factor": 2} +# elif size is None: +# up_kwargs = {"size": self.size} +# else: +# up_kwargs = {"size": size} + +# y = custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners) +# y = self.out_conv(y) +# return y diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/gs_adapter.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/gs_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..1335ce387476c339fdd7a8b3cb5349dfce56a7e0 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/gs_adapter.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional +import torch +from einops import einsum, rearrange, repeat +from torch import nn + +from depth_anything_3.model.utils.transform import cam_quat_xyzw_to_world_quat_wxyz +from depth_anything_3.specs import Gaussians +from depth_anything_3.utils.geometry import affine_inverse, get_world_rays, sample_image_grid +from depth_anything_3.utils.pose_align import batch_align_poses_umeyama +from depth_anything_3.utils.sh_helpers import rotate_sh + + +class GaussianAdapter(nn.Module): + + def __init__( + self, + sh_degree: int = 0, + pred_color: bool = False, + pred_offset_depth: bool = False, + pred_offset_xy: bool = True, + gaussian_scale_min: float = 1e-5, + gaussian_scale_max: float = 30.0, + ): + super().__init__() + self.sh_degree = sh_degree + self.pred_color = pred_color + self.pred_offset_depth = pred_offset_depth + self.pred_offset_xy = pred_offset_xy + self.gaussian_scale_min = gaussian_scale_min + self.gaussian_scale_max = gaussian_scale_max + + # Create a mask for the spherical harmonics coefficients. This ensures that at + # initialization, the coefficients are biased towards having a large DC + # component and small view-dependent components. 
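+ # e.g. with sh_degree=2 the DC term keeps weight 1.0 while the degree-1 and
+ # degree-2 coefficients are scaled by 0.025 and 0.00625 respectively.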
+ if not pred_color: + self.register_buffer( + "sh_mask", + torch.ones((self.d_sh,), dtype=torch.float32), + persistent=False, + ) + for degree in range(1, sh_degree + 1): + self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree + + def forward( + self, + extrinsics: torch.Tensor, # "*#batch 4 4" + intrinsics: torch.Tensor, # "*#batch 3 3" + depths: torch.Tensor, # "*#batch" + opacities: torch.Tensor, # "*#batch" | "*#batch _" + raw_gaussians: torch.Tensor, # "*#batch _" + image_shape: tuple[int, int], + eps: float = 1e-8, + gt_extrinsics: Optional[torch.Tensor] = None, # "*#batch 4 4" + **kwargs, + ) -> Gaussians: + device = extrinsics.device + dtype = raw_gaussians.dtype + H, W = image_shape + b, v = raw_gaussians.shape[:2] + + # get cam2worlds and intr_normed to adapt to 3DGS codebase + cam2worlds = affine_inverse(extrinsics) + intr_normed = intrinsics.clone().detach() + intr_normed[..., 0, :] /= W + intr_normed[..., 1, :] /= H + + # 1. compute 3DGS means + # 1.1) offset the predicted depth if needed + if self.pred_offset_depth: + gs_depths = depths + raw_gaussians[..., -1] + raw_gaussians = raw_gaussians[..., :-1] + else: + gs_depths = depths + # 1.2) align predicted poses with GT if needed + if gt_extrinsics is not None and not torch.equal(extrinsics, gt_extrinsics): + try: + _, _, pose_scales = batch_align_poses_umeyama( + gt_extrinsics.detach().float(), + extrinsics.detach().float(), + ) + except Exception: + pose_scales = torch.ones_like(extrinsics[:, 0, 0, 0]) + pose_scales = torch.clamp(pose_scales, min=1 / 3.0, max=3.0) + cam2worlds[:, :, :3, 3] = cam2worlds[:, :, :3, 3] * rearrange( + pose_scales, "b -> b () ()" + ) # [b, i, j] + gs_depths = gs_depths * rearrange(pose_scales, "b -> b () () ()") # [b, v, h, w] + # 1.3) casting xy in image space + xy_ray, _ = sample_image_grid((H, W), device) + xy_ray = xy_ray[None, None, ...].expand(b, v, -1, -1, -1) # b v h w xy + # offset xy if needed + if self.pred_offset_xy: + pixel_size = 1 / torch.tensor((W, H), dtype=xy_ray.dtype, device=device) + offset_xy = raw_gaussians[..., :2] + xy_ray = xy_ray + offset_xy * pixel_size + raw_gaussians = raw_gaussians[..., 2:] # skip the offset_xy + # 1.4) unproject depth + xy to world ray + origins, directions = get_world_rays( + xy_ray, + repeat(cam2worlds, "b v i j -> b v h w i j", h=H, w=W), + repeat(intr_normed, "b v i j -> b v h w i j", h=H, w=W), + ) + gs_means_world = origins + directions * gs_depths[..., None] + gs_means_world = rearrange(gs_means_world, "b v h w d -> b (v h w) d") + + # 2. compute other GS attributes + scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1) + + # 2.1) 3DGS scales + # make the scale invarient to resolution + scale_min = self.gaussian_scale_min + scale_max = self.gaussian_scale_max + scales = scale_min + (scale_max - scale_min) * scales.sigmoid() + pixel_size = 1 / torch.tensor((W, H), dtype=dtype, device=device) + multiplier = self.get_scale_multiplier(intr_normed, pixel_size) + gs_scales = scales * gs_depths[..., None] * multiplier[..., None, None, None] + gs_scales = rearrange(gs_scales, "b v h w d -> b (v h w) d") + + # 2.2) 3DGS quaternion (world space) + # due to historical issue, assume quaternion in order xyzw, not wxyz + # Normalize the quaternion features to yield a valid quaternion. 
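+ # (cam_quat_xyzw_to_world_quat_wxyz below handles both the xyzw -> wxyz
+ # reordering and the camera-to-world rotation.)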
+ rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps) + # rotate them to world space + cam_quat_xyzw = rearrange(rotations, "b v h w c -> b (v h w) c") + c2w_mat = repeat( + cam2worlds, + "b v i j -> b (v h w) i j", + h=H, + w=W, + ) + world_quat_wxyz = cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w_mat) + gs_rotations_world = world_quat_wxyz # b (v h w) c + + # 2.3) 3DGS color / SH coefficient (world space) + sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3) + if not self.pred_color: + sh = sh * self.sh_mask + + if self.pred_color or self.sh_degree == 0: + # predict pre-computed color or predict only DC band, no need to transform + gs_sh_world = sh + else: + gs_sh_world = rotate_sh(sh, cam2worlds[:, :, None, None, None, :3, :3]) + gs_sh_world = rearrange(gs_sh_world, "b v h w xyz d_sh -> b (v h w) xyz d_sh") + + # 2.4) 3DGS opacity + gs_opacities = rearrange(opacities, "b v h w ... -> b (v h w) ...") + + return Gaussians( + means=gs_means_world, + harmonics=gs_sh_world, + opacities=gs_opacities, + scales=gs_scales, + rotations=gs_rotations_world, + ) + + def get_scale_multiplier( + self, + intrinsics: torch.Tensor, # "*#batch 3 3" + pixel_size: torch.Tensor, # "*#batch 2" + multiplier: float = 0.1, + ) -> torch.Tensor: # " *batch" + xy_multipliers = multiplier * einsum( + intrinsics[..., :2, :2].float().inverse().to(intrinsics), + pixel_size, + "... i j, j -> ... i", + ) + return xy_multipliers.sum(dim=-1) + + @property + def d_sh(self) -> int: + return 1 if self.pred_color else (self.sh_degree + 1) ** 2 + + @property + def d_in(self) -> int: + # provided as reference to the gs_dpt output dim + raw_gs_dim = 0 + if self.pred_offset_xy: + raw_gs_dim += 2 + raw_gs_dim += 3 # scales + raw_gs_dim += 4 # quaternion + raw_gs_dim += 3 * self.d_sh # color + if self.pred_offset_depth: + raw_gs_dim += 1 + + return raw_gs_dim diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/gsdpt.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/gsdpt.py new file mode 100644 index 0000000000000000000000000000000000000000..70448b7235982a3b47c0c07f15edd4c38677e72e --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/gsdpt.py @@ -0,0 +1,133 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
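+ # Illustrative usage sketch (dim_in is hypothetical; the exact output_dim depends on
+ # the GaussianAdapter configuration, cf. gs_adapter.GaussianAdapter.d_in):
+ #   adapter = GaussianAdapter(sh_degree=0)
+ #   head = GSDPT(dim_in=1024, output_dim=adapter.d_in + 1)  # last channel -> "raw_gs_conf"
+ #   out = head(feats, H, W, patch_start_idx, images=images)  # images: [B, S, 3, H, W] in [0, 1]
+ #   raw_gs, raw_gs_conf = out["raw_gs"], out["raw_gs_conf"]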
+ +from typing import Dict as TyDict +from typing import List, Sequence +import torch +import torch.nn as nn + +from depth_anything_3.model.dpt import DPT +from depth_anything_3.model.utils.head_utils import activate_head_gs, custom_interpolate + + +class GSDPT(DPT): + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 4, + activation: str = "linear", + conf_activation: str = "sigmoid", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = True, + feature_only: bool = False, + down_ratio: int = 1, + conf_dim: int = 1, + norm_type: str = "idt", # use to match legacy GS-DPT head, "idt" / "layer" + fusion_block_inplace: bool = False, + ) -> None: + super().__init__( + dim_in=dim_in, + patch_size=patch_size, + output_dim=output_dim, + activation=activation, + conf_activation=conf_activation, + features=features, + out_channels=out_channels, + pos_embed=pos_embed, + down_ratio=down_ratio, + head_name="raw_gs", + use_sky_head=False, + norm_type=norm_type, + fusion_block_inplace=fusion_block_inplace, + ) + self.conf_dim = conf_dim + if conf_dim and conf_dim > 1: + assert ( + conf_activation == "linear" + ), "use linear prediction when using view-dependent opacity" + + merger_out_dim = features if feature_only else features // 2 + self.images_merger = nn.Sequential( + nn.Conv2d(3, merger_out_dim // 4, 3, 1, 1), # fewer channels first + nn.GELU(), + nn.Conv2d(merger_out_dim // 4, merger_out_dim // 2, 3, 1, 1), + nn.GELU(), + nn.Conv2d(merger_out_dim // 2, merger_out_dim, 3, 1, 1), + nn.GELU(), + ) + + # ------------------------------------------------------------------------- + # Internal forward (single chunk) + # ------------------------------------------------------------------------- + def _forward_impl( + self, + feats: List[torch.Tensor], + H: int, + W: int, + patch_start_idx: int, + images: torch.Tensor, + ) -> TyDict[str, torch.Tensor]: + B, _, C = feats[0].shape + ph, pw = H // self.patch_size, W // self.patch_size + resized_feats = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats[take_idx][:, patch_start_idx:] # [B*S, N_patch, C] + x = self.norm(x) + x = x.permute(0, 2, 1).reshape(B, C, ph, pw) # [B*S, C, ph, pw] + + x = self.projects[stage_idx](x) + if self.pos_embed: + x = self._add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) # Align scale + resized_feats.append(x) + + # 2) Fusion pyramid (main branch only) + fused = self._fuse(resized_feats) + fused = self.scratch.output_conv1(fused) + + # 3) Upsample to target resolution, optionally add position encoding again + h_out = int(ph * self.patch_size / self.down_ratio) + w_out = int(pw * self.patch_size / self.down_ratio) + + fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True) + + # inject the image information here + fused = fused + self.images_merger(images) + + if self.pos_embed: + fused = self._add_pos_embed(fused, W, H) + + # 4) Shared neck1 + # feat = self.scratch.output_conv1(fused) + feat = fused + + # 5) Main head: logits -> activate_head or single channel activation + main_logits = self.scratch.output_conv2(feat) + outs: TyDict[str, torch.Tensor] = {} + if self.has_conf: + pred, conf = activate_head_gs( + main_logits, + activation=self.activation, + conf_activation=self.conf_activation, + conf_dim=self.conf_dim, + ) + outs[self.head_main] = pred.squeeze(1) + outs[f"{self.head_main}_conf"] = conf.squeeze(1) + else: + outs[self.head_main] = 
self._apply_activation_single(main_logits).squeeze(1) + + return outs diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/attention.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..49c07a8c5e5c65fc8e0700aab9ae660552dba610 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/attention.py @@ -0,0 +1,109 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 # noqa + +from typing import Callable, Optional, Union +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + qk_norm: bool = False, + rope=None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor: + # Debug breakpoint removed for production + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + q = self.rope(q, pos) if self.rope is not None else q + k = self.rope(k, pos) if self.rope is not None else k + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + attn_mask=attn_mask, + ) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, 
hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/block.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/block.py new file mode 100644 index 0000000000000000000000000000000000000000..993fb4c0bdc02bd0976a6471adf1a15452d0f5c0 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/block.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable +from torch import Tensor, nn + +from .attention import Attention, LayerScale, Mlp + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + qk_norm: bool = False, + rope=None, + ) -> None: + super().__init__() + + self.norm1 = norm_layer(dim) + + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + qk_norm=qk_norm, + rope=rope, + ) + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + + self.sample_drop_ratio = 0.0 # Equivalent to always having drop_path=0 + + def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor: + def attn_residual_func(x: Tensor, pos=None, attn_mask=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + # drop_path is always 0, so always take the else branch + x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask) + x = x + ffn_residual_func(x) + return x diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/gs_renderer.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/gs_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..91414d38e84cc030af5207682d06135eec72b125 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/gs_renderer.py @@ -0,0 +1,340 @@ +# Copyright (c) 2025 ByteDance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from math import isqrt +from typing import Literal, Optional +import torch +from einops import rearrange, repeat +from tqdm import tqdm + +from depth_anything_3.specs import Gaussians +from depth_anything_3.utils.camera_trj_helpers import ( + interpolate_extrinsics, + interpolate_intrinsics, + render_dolly_zoom_path, + render_stabilization_path, + render_wander_path, + render_wobble_inter_path, +) +from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, get_fov +from depth_anything_3.utils.logger import logger + +try: + from gsplat import rasterization +except ImportError: + logger.warn( + "Dependency `gsplat` is required for rendering 3DGS. " + "Install via: pip install git+https://github.com/nerfstudio-project/" + "gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70" + ) + + +def render_3dgs( + extrinsics: torch.Tensor, # "batch_views 4 4", w2c + intrinsics: torch.Tensor, # "batch_views 3 3", normalized + image_shape: tuple[int, int], + gaussian: Gaussians, + background_color: Optional[torch.Tensor] = None, # "batch_views 3" + use_sh: bool = True, + num_view: int = 1, + color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+D", + **kwargs, +) -> tuple[ + torch.Tensor, # "batch_views 3 height width" + torch.Tensor, # "batch_views height width" +]: + # extract gaussian params + gaussian_means = gaussian.means + gaussian_scales = gaussian.scales + gaussian_quats = gaussian.rotations + gaussian_opacities = gaussian.opacities + gaussian_sh_coefficients = gaussian.harmonics + b, _, _ = extrinsics.shape + + if background_color is None: + background_color = repeat(torch.tensor([0.0, 0.0, 0.0]), "c -> b c", b=b).to( + gaussian_sh_coefficients + ) + + if use_sh: + _, _, _, n = gaussian_sh_coefficients.shape + degree = isqrt(n) - 1 + shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous() + else: # use color + shs = ( + gaussian_sh_coefficients.squeeze(-1).sigmoid().contiguous() + ) # (b, g, c), normed to (0, 1) + + h, w = image_shape + + fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1) + tan_fov_x = (0.5 * fov_x).tan() + tan_fov_y = (0.5 * fov_y).tan() + focal_length_x = w / (2 * tan_fov_x) + focal_length_y = h / (2 * tan_fov_y) + + view_matrix = extrinsics.float() + + all_images = [] + all_radii = [] + all_depths = [] + # render view in a batch based, each batch contains one scene + # assume the Gaussian parameters are originally repeated along the view dim + batch_scene = b // num_view + + def index_i_gs_attr(full_attr, idx): + # return rearrange(full_attr, "(b v) ... 
-> b v ...", v=num_view)[idx, 0] + return full_attr[idx] + + for i in range(batch_scene): + K = repeat( + torch.tensor( + [ + [0, 0, w / 2.0], + [0, 0, h / 2.0], + [0, 0, 1], + ] + ), + "i j -> v i j", + v=num_view, + ).to(gaussian_means) + K[:, 0, 0] = focal_length_x.reshape(batch_scene, num_view)[i] + K[:, 1, 1] = focal_length_y.reshape(batch_scene, num_view)[i] + + i_means = index_i_gs_attr(gaussian_means, i) # [N, 3] + i_scales = index_i_gs_attr(gaussian_scales, i) + i_quats = index_i_gs_attr(gaussian_quats, i) + i_opacities = index_i_gs_attr(gaussian_opacities, i) # [N,] + i_colors = index_i_gs_attr(shs, i) # [N, K, 3] + i_viewmats = rearrange(view_matrix, "(b v) ... -> b v ...", v=num_view)[i] # [v, 4, 4] + i_backgrounds = rearrange(background_color, "(b v) ... -> b v ...", v=num_view)[ + i + ] # [v, 3] + + render_colors, render_alphas, info = rasterization( + means=i_means, + quats=i_quats, # [N, 4] + scales=i_scales, # [N, 3] + opacities=i_opacities, + colors=i_colors, + viewmats=i_viewmats, # [v, 4, 4] + Ks=K, # [v, 3, 3] + backgrounds=i_backgrounds, + render_mode=color_mode, + width=w, + height=h, + packed=False, + sh_degree=degree if use_sh else None, + ) + depth = render_colors[..., -1].unbind(dim=0) + + image = rearrange(render_colors[..., :3], "v h w c -> v c h w").unbind(dim=0) + radii = info["radii"].unbind(dim=0) + try: + info["means2d"].retain_grad() # [1, N, 2] + except Exception: + pass + all_images.extend(image) + all_depths.extend(depth) + all_radii.extend(radii) + + return torch.stack(all_images), torch.stack(all_depths) + + +def run_renderer_in_chunk_w_trj_mode( + gaussians: Gaussians, + extrinsics: torch.Tensor, # world2cam, "batch view 4 4" | "batch view 3 4" + intrinsics: torch.Tensor, # unnormed intrinsics, "batch view 3 3" + image_shape: tuple[int, int], + chunk_size: Optional[int] = 8, + trj_mode: Literal[ + "original", + "smooth", + "interpolate", + "interpolate_smooth", + "wander", + "dolly_zoom", + "extend", + "wobble_inter", + ] = "smooth", + input_shape: Optional[tuple[int, int]] = None, + enable_tqdm: Optional[bool] = False, + **kwargs, +) -> tuple[ + torch.Tensor, # color, "batch view 3 height width" + torch.Tensor, # depth, "batch view height width" +]: + cam2world = affine_inverse(as_homogeneous(extrinsics)) + if input_shape is not None: + in_h, in_w = input_shape + else: + in_h, in_w = image_shape + intr_normed = intrinsics.clone().detach() + intr_normed[..., 0, :] /= in_w + intr_normed[..., 1, :] /= in_h + if extrinsics.shape[1] <= 1: + assert trj_mode in [ + "wander", + "dolly_zoom", + ], "Please set trj_mode to 'wander' or 'dolly_zoom' when n_views=1" + + def _smooth_trj_fn_batch(raw_c2ws, k_size=50): + try: + smooth_c2ws = torch.stack( + [render_stabilization_path(c2w_i, k_size) for c2w_i in raw_c2ws], + dim=0, + ) + except Exception as e: + print(f"[DEBUG] Path smoothing failed with error: {e}.") + smooth_c2ws = raw_c2ws + return smooth_c2ws + + # get rendered trj + if trj_mode == "original": + tgt_c2w = cam2world + tgt_intr = intr_normed + elif trj_mode == "smooth": + tgt_c2w = _smooth_trj_fn_batch(cam2world) + tgt_intr = intr_normed + elif trj_mode in ["interpolate", "interpolate_smooth", "extend"]: + inter_len = 8 + total_len = (cam2world.shape[1] - 1) * inter_len + if total_len > 24 * 18: # no more than 18s + inter_len = max(1, 24 * 10 // (cam2world.shape[1] - 1)) + if total_len < 24 * 2: # no less than 2s + inter_len = max(1, 24 * 2 // (cam2world.shape[1] - 1)) + + if inter_len > 2: + t = torch.linspace(0, 1, inter_len, dtype=torch.float32, 
device=cam2world.device) + t = (torch.cos(torch.pi * (t + 1)) + 1) / 2 + tgt_c2w_b = [] + tgt_intr_b = [] + for b_idx in range(cam2world.shape[0]): + tgt_c2w = [] + tgt_intr = [] + for cur_idx in range(cam2world.shape[1] - 1): + tgt_c2w.append( + interpolate_extrinsics( + cam2world[b_idx, cur_idx], cam2world[b_idx, cur_idx + 1], t + )[(0 if cur_idx == 0 else 1) :] + ) + tgt_intr.append( + interpolate_intrinsics( + intr_normed[b_idx, cur_idx], intr_normed[b_idx, cur_idx + 1], t + )[(0 if cur_idx == 0 else 1) :] + ) + tgt_c2w_b.append(torch.cat(tgt_c2w)) + tgt_intr_b.append(torch.cat(tgt_intr)) + tgt_c2w = torch.stack(tgt_c2w_b) # b v 4 4 + tgt_intr = torch.stack(tgt_intr_b) # b v 3 3 + else: + tgt_c2w = cam2world + tgt_intr = intr_normed + if trj_mode in ["interpolate_smooth", "extend"]: + tgt_c2w = _smooth_trj_fn_batch(tgt_c2w) + if trj_mode == "extend": + # apply dolly_zoom and wander in the middle frame + assert cam2world.shape[0] == 1, "extend only supports for batch_size=1 currently." + mid_idx = tgt_c2w.shape[1] // 2 + c2w_wd, intr_wd = render_wander_path( + tgt_c2w[0, mid_idx], + tgt_intr[0, mid_idx], + h=in_h, + w=in_w, + num_frames=max(36, min(60, mid_idx // 2)), + max_disp=24.0, + ) + c2w_dz, intr_dz = render_dolly_zoom_path( + tgt_c2w[0, mid_idx], + tgt_intr[0, mid_idx], + h=in_h, + w=in_w, + num_frames=max(36, min(60, mid_idx // 2)), + ) + tgt_c2w = torch.cat( + [ + tgt_c2w[:, :mid_idx], + c2w_wd.unsqueeze(0), + c2w_dz.unsqueeze(0), + tgt_c2w[:, mid_idx:], + ], + dim=1, + ) + tgt_intr = torch.cat( + [ + tgt_intr[:, :mid_idx], + intr_wd.unsqueeze(0), + intr_dz.unsqueeze(0), + tgt_intr[:, mid_idx:], + ], + dim=1, + ) + elif trj_mode in ["wander", "dolly_zoom"]: + if trj_mode == "wander": + render_fn = render_wander_path + extra_kwargs = {"max_disp": 24.0} + else: + render_fn = render_dolly_zoom_path + extra_kwargs = {"D_focus": 30.0, "max_disp": 2.0} + tgt_c2w = [] + tgt_intr = [] + for b_idx in range(cam2world.shape[0]): + c2w_i, intr_i = render_fn( + cam2world[b_idx, 0], intr_normed[b_idx, 0], h=in_h, w=in_w, **extra_kwargs + ) + tgt_c2w.append(c2w_i) + tgt_intr.append(intr_i) + tgt_c2w = torch.stack(tgt_c2w) + tgt_intr = torch.stack(tgt_intr) + elif trj_mode == "wobble_inter": + tgt_c2w, tgt_intr = render_wobble_inter_path( + cam2world=cam2world, + intr_normed=intr_normed, + inter_len=10, + n_skip=3, + ) + else: + raise Exception(f"trj mode [{trj_mode}] is not implemented.") + + _, v = tgt_c2w.shape[:2] + tgt_extr = affine_inverse(tgt_c2w) + if chunk_size is None: + chunk_size = v + chunk_size = min(v, chunk_size) + all_colors = [] + all_depths = [] + for chunk_idx in tqdm( + range(math.ceil(v / chunk_size)), + desc="Rendering novel views", + disable=(not enable_tqdm), + leave=False, + ): + s = int(chunk_idx * chunk_size) + e = int((chunk_idx + 1) * chunk_size) + cur_n_view = tgt_extr[:, s:e].shape[1] + color, depth = render_3dgs( + extrinsics=rearrange(tgt_extr[:, s:e], "b v ... -> (b v) ..."), # w2c + intrinsics=rearrange(tgt_intr[:, s:e], "b v ... -> (b v) ..."), # normed + image_shape=image_shape, + gaussian=gaussians, + num_view=cur_n_view, + **kwargs, + ) + all_colors.append(rearrange(color, "(b v) ... -> b v ...", v=cur_n_view)) + all_depths.append(rearrange(depth, "(b v) ... 
-> b v ...", v=cur_n_view)) + all_colors = torch.cat(all_colors, dim=1) + all_depths = torch.cat(all_depths, dim=1) + + return all_colors, all_depths diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/head_utils.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/head_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c1209582bc6e9b1658a0358df03f5ea79fe61daf --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/head_utils.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union +import torch +import torch.nn as nn +import torch.nn.functional as F + +# ----------------------------------------------------------------------------- +# Activation functions +# ----------------------------------------------------------------------------- + + +def activate_head_gs(out, activation="norm_exp", conf_activation="expp1", conf_dim=None): + """ + Process network output to extract GS params and density values. + Density could be view-dependent as SH coefficient + + + Args: + out: Network output tensor (B, C, H, W) + activation: Activation type for 3D points + conf_activation: Activation type for confidence values + + Returns: + Tuple of (3D points tensor, confidence tensor) + """ + # Move channels from last dim to the 4th dimension => (B, H, W, C) + fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected + + # Split into xyz (first C-1 channels) and confidence (last channel) + conf_dim = 1 if conf_dim is None else conf_dim + xyz = fmap[:, :, :, :-conf_dim] + conf = fmap[:, :, :, -1] if conf_dim == 1 else fmap[:, :, :, -conf_dim:] + + if activation == "norm_exp": + d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8) + xyz_normed = xyz / d + pts3d = xyz_normed * torch.expm1(d) + elif activation == "norm": + pts3d = xyz / xyz.norm(dim=-1, keepdim=True) + elif activation == "exp": + pts3d = torch.exp(xyz) + elif activation == "relu": + pts3d = F.relu(xyz) + elif activation == "sigmoid": + pts3d = torch.sigmoid(xyz) + elif activation == "linear": + pts3d = xyz + else: + raise ValueError(f"Unknown activation: {activation}") + + if conf_activation == "expp1": + conf_out = 1 + conf.exp() + elif conf_activation == "expp0": + conf_out = conf.exp() + elif conf_activation == "sigmoid": + conf_out = torch.sigmoid(conf) + elif conf_activation == "linear": + conf_out = conf + else: + raise ValueError(f"Unknown conf_activation: {conf_activation}") + + return pts3d, conf_out + + +# ----------------------------------------------------------------------------- +# Other utilities +# ----------------------------------------------------------------------------- + + +class Permute(nn.Module): + """nn.Module wrapper around Tensor.permute for cleaner nn.Sequential usage.""" + + dims: Tuple[int, ...] 
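+    # e.g. Permute((0, 2, 3, 1)) maps a (B, C, H, W) tensor to (B, H, W, C).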
+ + def __init__(self, dims: Tuple[int, ...]) -> None: + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override] + return x.permute(*self.dims) + + +def position_grid_to_embed( + pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100 +) -> torch.Tensor: + """ + Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC) + + Args: + pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates + embed_dim: Output channel dimension for embeddings + + Returns: + Tensor of shape (H, W, embed_dim) with positional embeddings + """ + H, W, grid_dim = pos_grid.shape + assert grid_dim == 2 + pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2) + + # Process x and y coordinates separately + emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2] + emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2] + + # Combine and reshape + emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D] + + return emb.view(H, W, embed_dim) # [H, W, D] + + +def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor: + """ + This function generates a 1D positional embedding from a given grid using sine and cosine functions. # noqa + + Args: + - embed_dim: The embedding dimension. + - pos: The position to generate the embedding from. + + Returns: + - emb: The generated 1D positional embedding. + """ + assert embed_dim % 2 == 0 + omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) + omega /= embed_dim / 2.0 + omega = 1.0 / omega_0**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + return emb.float() + + +# Inspired by https://github.com/microsoft/moge + + +def create_uv_grid( + width: int, + height: int, + aspect_ratio: float = None, + dtype: torch.dtype = None, + device: torch.device = None, +) -> torch.Tensor: + """ + Create a normalized UV grid of shape (width, height, 2). + + The grid spans horizontally and vertically according to an aspect ratio, + ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right + corner is at (x_span, y_span), normalized by the diagonal of the plane. + + Args: + width (int): Number of points horizontally. + height (int): Number of points vertically. + aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height. + dtype (torch.dtype, optional): Data type of the resulting tensor. + device (torch.device, optional): Device on which the tensor is created. + + Returns: + torch.Tensor: A (width, height, 2) tensor of UV coordinates. 
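+        Note: with torch.meshgrid(indexing="xy") the grid is laid out as (height, width, 2).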
+ """ + # Derive aspect ratio if not explicitly provided + if aspect_ratio is None: + aspect_ratio = float(width) / float(height) + + # Compute normalized spans for X and Y + diag_factor = (aspect_ratio**2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + + # Establish the linspace boundaries + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + + # Generate 1D coordinates + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + + # Create 2D meshgrid (width x height) and stack into UV + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + uv_grid = torch.stack((uu, vv), dim=-1) + + return uv_grid + + +# ----------------------------------------------------------------------------- +# Interpolation (safe interpolation, avoid INT_MAX overflow) +# ----------------------------------------------------------------------------- +def custom_interpolate( + x: torch.Tensor, + size: Union[Tuple[int, int], None] = None, + scale_factor: Union[float, None] = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + """ + Safe interpolation implementation to avoid INT_MAX overflow in torch.nn.functional.interpolate. + """ + if size is None: + assert scale_factor is not None, "Either size or scale_factor must be provided." + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + + INT_MAX = 1610612736 + total = size[0] * size[1] * x.shape[0] * x.shape[1] + + if total > INT_MAX: + chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0) + outs = [ + nn.functional.interpolate(c, size=size, mode=mode, align_corners=align_corners) + for c in chunks + ] + return torch.cat(outs, dim=0).contiguous() + + return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/transform.py b/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..8d732b093e5ad1578bc0ba5eb0e31b1b69b766ed --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/model/utils/transform.py @@ -0,0 +1,208 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
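+
+# Pose encoding layout used below (see extri_intri_to_pose_encoding):
+#   [tx, ty, tz, qx, qy, qz, qw, fov_h, fov_w]
+# i.e. translation, an XYZW (scalar-last) quaternion, and vertical/horizontal FoV.
+# Round-trip sketch (shapes as documented in the functions below):
+#   enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))
+#   extrinsics2, intrinsics2 = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))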
+ +import torch +import torch.nn.functional as F + + +def extri_intri_to_pose_encoding( + extrinsics, + intrinsics, + image_size_hw=None, +): + """Convert camera extrinsics and intrinsics to a compact pose encoding.""" + + # extrinsics: BxSx3x4 + # intrinsics: BxSx3x3 + R = extrinsics[:, :, :3, :3] # BxSx3x3 + T = extrinsics[:, :, :3, 3] # BxSx3 + + quat = mat_to_quat(R) + # Note the order of h and w here + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + + return pose_encoding + + +def pose_encoding_to_extri_intri( + pose_encoding, + image_size_hw=None, +): + """Convert a pose encoding back to camera extrinsics and intrinsics.""" + + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + + H, W = image_size_hw + fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6) + fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1 + + return extrinsics, intrinsics + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). 
+ Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape( + batch_dim + (4,) + ) + + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w): + # cam_quat_xyzw: (b, n, 4) in xyzw + # c2w: (b, n, 4, 4) + b, n = cam_quat_xyzw.shape[:2] + # 1. xyzw -> wxyz + cam_quat_wxyz = torch.cat( + [ + cam_quat_xyzw[..., 3:4], # w + cam_quat_xyzw[..., 0:1], # x + cam_quat_xyzw[..., 1:2], # y + cam_quat_xyzw[..., 2:3], # z + ], + dim=-1, + ) + # 2. Quaternion to matrix + cam_quat_wxyz_flat = cam_quat_wxyz.reshape(-1, 4) + rotmat_cam = quat_to_mat(cam_quat_wxyz_flat).reshape(b, n, 3, 3) + # 3. Transform to world space + rotmat_c2w = c2w[..., :3, :3] + rotmat_world = torch.matmul(rotmat_c2w, rotmat_cam) + # 4. Matrix to quaternion (wxyz) + rotmat_world_flat = rotmat_world.reshape(-1, 3, 3) + world_quat_wxyz_flat = mat_to_quat(rotmat_world_flat) + world_quat_wxyz = world_quat_wxyz_flat.reshape(b, n, 4) + return world_quat_wxyz diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/registry.py b/Depth-Anything-3-anysize/src/depth_anything_3/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..96450717d696b395503fedfe93af812e975d2671 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/registry.py @@ -0,0 +1,50 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+from pathlib import Path
+
+
+def get_all_models() -> OrderedDict:
+    """
+    Scans all YAML files in the configs directory and returns a sorted dictionary where:
+    - Keys are model names (YAML filenames without the .yaml extension)
+    - Values are absolute paths to the corresponding YAML files
+    """
+    # Get path to the configs directory within the da3 package
+    # Works both in development and after pip installation
+    # configs_dir = files("depth_anything_3").joinpath("configs")
+    configs_dir = Path(__file__).resolve().parent / "configs"
+
+    # Ensure path is a Path object for consistent cross-platform handling
+    configs_dir = Path(configs_dir)
+
+    model_entries = []
+    # Iterate through all items in the configs directory
+    for item in configs_dir.iterdir():
+        # Filter for YAML files (excluding directories)
+        if item.is_file() and item.suffix == ".yaml":
+            # Extract model name (filename without .yaml extension)
+            model_name = item.stem
+            # Get absolute path (resolve() handles symlinks)
+            file_abs_path = str(item.resolve())
+            model_entries.append((model_name, file_abs_path))
+
+    # Sort entries by model name and convert to OrderedDict
+    sorted_entries = sorted(model_entries, key=lambda x: x[0])
+    return OrderedDict(sorted_entries)
+
+
+# Global registry for external imports
+MODEL_REGISTRY = get_all_models()
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/services/__init__.py b/Depth-Anything-3-anysize/src/depth_anything_3/services/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ed8b6b96d68a3f3933dfa6651875f3f88b0aef
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/services/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Services module for Depth Anything 3.
+"""
+
+from depth_anything_3.services.backend import create_app, start_server
+
+__all__ = [
+    "start_server",
+    "create_app",
+]
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/services/backend.py b/Depth-Anything-3-anysize/src/depth_anything_3/services/backend.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e932a1023b856e4daaf026ce429ad4c7e3e77e4
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/services/backend.py
@@ -0,0 +1,1427 @@
+# flake8: noqa: E501
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model backend service for Depth Anything 3. +Provides HTTP API for model inference with persistent model loading. +""" + +import os +import posixpath +import time +import uuid + +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional +from urllib.parse import quote +import numpy as np + +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.responses import FileResponse, HTMLResponse +from pydantic import BaseModel +from PIL import Image + +from ..api import DepthAnything3 +from ..utils.memory import ( + get_gpu_memory_info, + cleanup_cuda_memory, + check_memory_availability, + estimate_memory_requirement, +) + + +class InferenceRequest(BaseModel): + """Request model for inference API.""" + + image_paths: List[str] + export_dir: Optional[str] = None + export_format: str = "mini_npz-glb" + extrinsics: Optional[List[List[List[float]]]] = None + intrinsics: Optional[List[List[List[float]]]] = None + process_res: Optional[int] = None + process_res_method: str = "keep" + export_feat_layers: List[int] = [] + align_to_input_ext_scale: bool = True + # GLB export parameters + conf_thresh_percentile: float = 40.0 + num_max_points: int = 1_000_000 + show_cameras: bool = True + # Feat_vis export parameters + feat_vis_fps: int = 15 + + +class InferenceResponse(BaseModel): + """Response model for inference API.""" + + success: bool + message: str + task_id: Optional[str] = None + export_dir: Optional[str] = None + export_format: str = "mini_npz-glb" + processing_time: Optional[float] = None + + +class TaskStatus(BaseModel): + """Task status model.""" + + task_id: str + status: str # "pending", "running", "completed", "failed" + message: str + progress: Optional[float] = None # 0.0 to 1.0 + created_at: float + started_at: Optional[float] = None + completed_at: Optional[float] = None + export_dir: Optional[str] = None + request: Optional[InferenceRequest] = None # Store the original request + + # Essential task parameters + num_images: Optional[int] = None # Number of input images + export_format: Optional[str] = None # Export format + process_res_method: Optional[str] = None # Processing resolution method + video_path: Optional[str] = None # Source video path + + +class ModelBackend: + """Model backend service with persistent model loading.""" + + def __init__(self, model_dir: str, device: str = "cuda"): + self.model_dir = model_dir + self.device = device + self.model = None + self.model_loaded = False + self.load_time = None + self.load_start_time = None # Time when model loading started + self.load_completed_time = None # Time when model loading completed + self.last_used = None + + def load_model(self): + """Load model if not already loaded.""" + if self.model_loaded and self.model is not None: + self.last_used = time.time() + return self.model + + try: + print(f"Loading model from {self.model_dir}...") + self.load_start_time = time.time() + start_time = time.time() + + self.model = DepthAnything3.from_pretrained(self.model_dir).to(self.device) + self.model.eval() + + self.model_loaded = True + self.load_time = time.time() - 
start_time + self.load_completed_time = time.time() + self.last_used = time.time() + + print(f"Model loaded successfully in {self.load_time:.2f}s") + return self.model + + except Exception as e: + print(f"Failed to load model: {e}") + raise e + + def get_model(self): + """Get model, loading if necessary.""" + if not self.model_loaded: + return self.load_model() + self.last_used = time.time() + return self.model + + def get_status(self) -> Dict[str, Any]: + """Get backend status information.""" + # Calculate uptime from when model loading completed + uptime = 0 + if self.model_loaded and self.load_completed_time: + uptime = time.time() - self.load_completed_time + + return { + "model_loaded": self.model_loaded, + "model_dir": self.model_dir, + "device": self.device, + "load_time": self.load_time, + "last_used": self.last_used, + "uptime": uptime, + } + + +# Global backend instance +_backend: Optional[ModelBackend] = None +_app: Optional[FastAPI] = None +_tasks: Dict[str, TaskStatus] = {} +_executor = ThreadPoolExecutor(max_workers=1) # Restrict to single-task execution +_running_task_id: Optional[str] = None # Currently running task ID +_task_queue: List[str] = [] # Pending task queue + +# Task cleanup configuration +MAX_TASK_HISTORY = 100 # Maximum number of tasks to keep in memory +CLEANUP_INTERVAL = 300 # Cleanup interval in seconds (5 minutes) + + +def _process_next_task(): + """Process the next task in the queue.""" + global _task_queue, _running_task_id + + if not _task_queue or _running_task_id is not None: + return + + # Get next task from queue + task_id = _task_queue.pop(0) + + # Get task request from tasks dict (we need to store the request) + if task_id not in _tasks: + return + + # Submit task to executor + _executor.submit(_run_inference_task, task_id) + + +# get_gpu_memory_info imported from depth_anything_3.utils.memory + + +# cleanup_cuda_memory imported from depth_anything_3.utils.memory + + +# check_memory_availability imported from depth_anything_3.utils.memory + + +# estimate_memory_requirement imported from depth_anything_3.utils.memory + + +def _run_inference_task(task_id: str): + """Run inference task in background thread with OOM protection.""" + global _tasks, _backend, _running_task_id, _task_queue + + model = None + inference_started = False + start_time = time.time() + + try: + # Get task request + if task_id not in _tasks or _tasks[task_id].request is None: + print(f"[{task_id}] Task not found or request missing") + return + + request = _tasks[task_id].request + num_images = len(request.image_paths) + + # Set current running task + _running_task_id = task_id + + # Update task status to running + _tasks[task_id].status = "running" + _tasks[task_id].started_at = start_time + _tasks[task_id].message = f"[{task_id}] Starting inference on {num_images} frames..." 
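+        # Lifecycle: pending -> running -> completed/failed; progress advances at
+        # fixed checkpoints (0.1-0.2 model load, 0.3 inference start, 0.9 done, 1.0 final).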
+ print(f"[{task_id}] Starting inference on {num_images} frames") + + # Pre-inference cleanup to ensure maximum available memory + print(f"[{task_id}] Pre-inference cleanup...") + cleanup_cuda_memory() + + # Check memory availability + effective_res = request.process_res + if not effective_res or effective_res <= 0: + try: + first_path = request.image_paths[0] + with Image.open(first_path) as img: + effective_res = max(img.size) + except Exception: + effective_res = 504 # Fall back to baseline heuristic + + estimated_memory = estimate_memory_requirement(num_images, effective_res) + mem_available, mem_msg = check_memory_availability(estimated_memory) + print(f"[{task_id}] {mem_msg}") + + if not mem_available: + # Try aggressive cleanup + print(f"[{task_id}] Insufficient memory, attempting aggressive cleanup...") + cleanup_cuda_memory() + time.sleep(0.5) # Give system time to reclaim memory + + # Check again + mem_available, mem_msg = check_memory_availability(estimated_memory) + if not mem_available: + raise RuntimeError( + f"Insufficient GPU memory after cleanup. {mem_msg}\n" + f"Suggestions:\n" + f" 1. Reduce process_res (current: {request.process_res})\n" + f" 2. Process fewer images at once (current: {num_images})\n" + f" 3. Clear other GPU processes" + ) + + # Get model (with error handling) + print(f"[{task_id}] Loading model...") + _tasks[task_id].message = f"[{task_id}] Loading model..." + _tasks[task_id].progress = 0.1 + + try: + model = _backend.get_model() + except RuntimeError as e: + if "out of memory" in str(e).lower(): + cleanup_cuda_memory() + raise RuntimeError( + f"OOM during model loading: {str(e)}\n" + f"Try reducing the batch size or resolution." + ) + raise + + print(f"[{task_id}] Model loaded successfully") + _tasks[task_id].progress = 0.2 + + # Prepare inference parameters + inference_kwargs = { + "image": request.image_paths, + "export_format": request.export_format, + "process_res": request.process_res, + "process_res_method": request.process_res_method, + "export_feat_layers": request.export_feat_layers, + "align_to_input_ext_scale": request.align_to_input_ext_scale, + "conf_thresh_percentile": request.conf_thresh_percentile, + "num_max_points": request.num_max_points, + "show_cameras": request.show_cameras, + "feat_vis_fps": request.feat_vis_fps, + } + + if request.export_dir: + inference_kwargs["export_dir"] = request.export_dir + + if request.extrinsics: + inference_kwargs["extrinsics"] = np.array(request.extrinsics, dtype=np.float32) + + if request.intrinsics: + inference_kwargs["intrinsics"] = np.array(request.intrinsics, dtype=np.float32) + + # Run inference with timing + inference_start_time = time.time() + print(f"[{task_id}] Running model inference...") + _tasks[task_id].message = f"[{task_id}] Running model inference on {num_images} images..." + _tasks[task_id].progress = 0.3 + + inference_started = True + + try: + model.inference(**inference_kwargs) + inference_time = time.time() - inference_start_time + avg_time_per_image = inference_time / num_images if num_images > 0 else 0 + + print( + f"[{task_id}] Inference completed in {inference_time:.2f}s " + f"({avg_time_per_image:.2f}s per image)" + ) + + except RuntimeError as e: + if "out of memory" in str(e).lower(): + cleanup_cuda_memory() + raise RuntimeError( + f"OOM during inference: {str(e)}\n" + f"Settings: {num_images} images, resolution={request.process_res}\n" + f"Suggestions:\n" + f" 1. Reduce process_res to {int(request.process_res * 0.75)}\n" + f" 2. Process images in smaller batches\n" + f" 3. 
Use process_res_method='resize' instead of 'upper_bound_resize'" + ) + raise + + _tasks[task_id].progress = 0.9 + + # Post-inference cleanup + print(f"[{task_id}] Post-inference cleanup...") + cleanup_cuda_memory() + + # Calculate total processing time + total_time = time.time() - start_time + + # Update task status to completed + _tasks[task_id].status = "completed" + _tasks[task_id].completed_at = time.time() + _tasks[task_id].message = ( + f"[{task_id}] Completed in {total_time:.2f}s " f"({avg_time_per_image:.2f}s per image)" + ) + _tasks[task_id].progress = 1.0 + _tasks[task_id].export_dir = request.export_dir + + # Clear running state + _running_task_id = None + + # Process next task in queue + _process_next_task() + + print(f"[{task_id}] Task completed successfully") + print( + f"[{task_id}] Total time: {total_time:.2f}s, " + f"Inference time: {inference_time:.2f}s, " + f"Avg per image: {avg_time_per_image:.2f}s" + ) + + except Exception as e: + # Update task status to failed + error_msg = str(e) + total_time = time.time() - start_time + + print(f"[{task_id}] Task failed after {total_time:.2f}s: {error_msg}") + + # Always attempt cleanup on failure + cleanup_cuda_memory() + + _tasks[task_id].status = "failed" + _tasks[task_id].completed_at = time.time() + _tasks[task_id].message = f"[{task_id}] Failed after {total_time:.2f}s: {error_msg}" + + # Clear running state + _running_task_id = None + + # Process next task in queue + _process_next_task() + + finally: + # Final cleanup in finally block to ensure it always runs + # This is critical for releasing resources even if unexpected errors occur + try: + if inference_started: + print(f"[{task_id}] Final cleanup in finally block...") + cleanup_cuda_memory() + except Exception as e: + print(f"[{task_id}] Warning: Finally block cleanup failed: {e}") + + # Schedule cleanup after task completion + _schedule_task_cleanup() + + +def _cleanup_old_tasks(): + """Clean up old completed/failed tasks to prevent memory buildup.""" + global _tasks + + current_time = time.time() + tasks_to_remove = [] + + # Find tasks to remove - more aggressive cleanup + for task_id, task in _tasks.items(): + # Remove completed/failed tasks older than 10 minutes (instead of 1 hour) + if ( + task.status in ["completed", "failed"] + and task.completed_at + and current_time - task.completed_at > 600 + ): # 10 minutes + tasks_to_remove.append(task_id) + + # Remove old tasks + for task_id in tasks_to_remove: + del _tasks[task_id] + print(f"[CLEANUP] Removed old task: {task_id}") + + # If still too many tasks, remove oldest completed/failed tasks + if len(_tasks) > MAX_TASK_HISTORY: + completed_tasks = [ + (task_id, task) + for task_id, task in _tasks.items() + if task.status in ["completed", "failed"] + ] + completed_tasks.sort(key=lambda x: x[1].completed_at or 0) + + excess_count = len(_tasks) - MAX_TASK_HISTORY + for i in range(min(excess_count, len(completed_tasks))): + task_id = completed_tasks[i][0] + del _tasks[task_id] + print(f"[CLEANUP] Removed excess task: {task_id}") + + # Count active tasks (only pending and running) + active_count = sum(1 for task in _tasks.values() if task.status in ["pending", "running"]) + print( + "[CLEANUP] Task cleanup completed. 
" + f"Total tasks: {len(_tasks)}, Active tasks: {active_count}" + ) + + +def _schedule_task_cleanup(): + """Schedule task cleanup in background.""" + + def cleanup_worker(): + try: + time.sleep(2) # Small delay to ensure task status is updated + _cleanup_old_tasks() + except Exception as e: + print(f"[CLEANUP] Cleanup worker failed: {e}") + + # Run cleanup in background thread + _executor.submit(cleanup_worker) + + +# ============================================================================ +# Gallery utilities (extracted from gallery.py) +# ============================================================================ + +GALLERY_IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".bmp") + + +def _load_gallery_html() -> str: + """ + Load and modify gallery HTML to work under /gallery/ subdirectory. + Replaces API paths from root to /gallery/ prefix. + """ + from ..services.gallery import HTML_PAGE + + # Replace API paths to be under /gallery/ subdirectory + html = ( + HTML_PAGE.replace("fetch('/manifest.json'", "fetch('/gallery/manifest.json'") + .replace("fetch('/manifest/'+", "fetch('/gallery/manifest/'+") + .replace( + "if(location.pathname!=\"/\")history.replaceState(null,'','/'+location.search)", + "if(!location.pathname.startsWith(\"/gallery\"))history.replaceState(null,'','/gallery/'+location.search)", + ) + ) + + return html + + +def _gallery_url_join(*parts: str) -> str: + """Join URL parts safely.""" + norm = posixpath.join(*[p.replace("\\", "/") for p in parts]) + segs = [s for s in norm.split("/") if s not in ("", ".")] + return "/".join(quote(s) for s in segs) + + +def _is_plain_name(name: str) -> bool: + """Check if name is safe for use in paths.""" + return all(c not in name for c in ("/", "\\")) and name not in (".", "..") + + +def build_group_list(root_dir: str) -> dict: + """Build list of groups from gallery directory.""" + groups = [] + try: + for gname in sorted(os.listdir(root_dir)): + gpath = os.path.join(root_dir, gname) + if not os.path.isdir(gpath): + continue + has_scene = False + try: + for sname in os.listdir(gpath): + spath = os.path.join(gpath, sname) + if not os.path.isdir(spath): + continue + if os.path.exists(os.path.join(spath, "scene.glb")) and os.path.exists( + os.path.join(spath, "scene.jpg") + ): + has_scene = True + break + except Exception: + pass + if has_scene: + groups.append({"id": gname, "title": gname}) + except Exception as e: + print(f"[warn] build_group_list failed: {e}") + return {"groups": groups} + + +def build_group_manifest(root_dir: str, group: str) -> dict: + """Build manifest for a specific group.""" + items = [] + gpath = os.path.join(root_dir, group) + try: + if not os.path.isdir(gpath): + return {"group": group, "items": []} + for sname in sorted(os.listdir(gpath)): + spath = os.path.join(gpath, sname) + if not os.path.isdir(spath): + continue + glb_fs = os.path.join(spath, "scene.glb") + jpg_fs = os.path.join(spath, "scene.jpg") + if not (os.path.exists(glb_fs) and os.path.exists(jpg_fs)): + continue + depth_images = [] + dpath = os.path.join(spath, "depth_vis") + if os.path.isdir(dpath): + files = [ + f + for f in os.listdir(dpath) + if os.path.splitext(f)[1].lower() in GALLERY_IMAGE_EXTS + ] + for fn in sorted(files): + depth_images.append( + "/gallery/" + _gallery_url_join(group, sname, "depth_vis", fn) + ) + items.append( + { + "id": sname, + "title": sname, + "model": "/gallery/" + _gallery_url_join(group, sname, "scene.glb"), + "thumbnail": "/gallery/" + _gallery_url_join(group, sname, "scene.jpg"), + "depth_images": 
depth_images, + } + ) + except Exception as e: + print(f"[warn] build_group_manifest failed for {group}: {e}") + return {"group": group, "items": items} + + +def create_app(model_dir: str, device: str = "cuda", gallery_dir: Optional[str] = None) -> FastAPI: + """Create FastAPI application with model backend.""" + global _backend, _app + + _backend = ModelBackend(model_dir, device) + _app = FastAPI( + title="Depth Anything 3 Backend", + description="Model inference service for Depth Anything 3", + version="1.0.0", + ) + + # Store gallery directory globally for use in routes + _gallery_dir = gallery_dir + + @_app.get("/", response_class=HTMLResponse) + async def root(): + """Home page with navigation to dashboard and gallery.""" + html_content = ( + """ + + + + + + Depth Anything 3 Backend + + + +
+    <body>
+        <div class="container">
+            <h1>Depth Anything 3</h1>
+            <p>Model Backend Service</p>
+            <p>
+                <a href="/dashboard">Dashboard</a> |
+                <a href="/gallery/">Gallery</a>
+            </p>
+        </div>
+    </body>
+    </html>
+ + + """ + ) + return HTMLResponse(html_content) + + @_app.get("/dashboard", response_class=HTMLResponse) + async def dashboard(): + """HTML dashboard for monitoring backend status and tasks.""" + if _backend is None: + return HTMLResponse("

<h1>Backend not initialized</h1>

", status_code=500) + + # Get backend status + status = _backend.get_status() + + # Safely format status values + if status["load_time"] is not None: + load_time_str = f"{status['load_time']:.2f}s" + else: + load_time_str = "Not loaded" + + if status["uptime"] is not None: + uptime_str = f"{status['uptime']:.2f}s" + else: + uptime_str = "Not running" + + # Get tasks information + active_tasks = [task for task in _tasks.values() if task.status in ["pending", "running"]] + completed_tasks = [ + task for task in _tasks.values() if task.status in ["completed", "failed"] + ] + + # Generate task HTML + active_tasks_html = "" + if active_tasks: + for task in active_tasks: + task_details = f""" +
+                <div class="task">
+                    <div class="task-header">
+                        <span class="task-id">{task.task_id}</span>
+                        <span class="task-status">{task.status}</span>
+                    </div>
+                    <div class="task-message">{task.message}</div>
+                    <div class="task-meta">
+                        Images: {task.num_images or 'N/A'} |
+                        Format: {task.export_format or 'N/A'} |
+                        Method: {task.process_res_method or 'N/A'} |
+                        Export Dir: {task.export_dir or 'N/A'}
+                        {f'<br>Video: {task.video_path}' if task.video_path else ''}
+                    </div>
+                </div>
+ """ + active_tasks_html += task_details + else: + active_tasks_html = "

<p>No active tasks</p>

" + + completed_tasks_html = "" + if completed_tasks: + for task in completed_tasks[-10:]: + task_details = f""" +
+                <div class="task">
+                    <div class="task-header">
+                        <span class="task-id">{task.task_id}</span>
+                        <span class="task-status">{task.status}</span>
+                    </div>
+                    <div class="task-message">{task.message}</div>
+                    <div class="task-meta">
+                        Images: {task.num_images or 'N/A'} |
+                        Format: {task.export_format or 'N/A'} |
+                        Method: {task.process_res_method or 'N/A'} |
+                        Export Dir: {task.export_dir or 'N/A'}
+                        {f'<br>Video: {task.video_path}' if task.video_path else ''}
+                    </div>
+                </div>
+ """ + completed_tasks_html += task_details + else: + completed_tasks_html = "

<p>No completed tasks</p>

" + + # Generate HTML + html_content = f""" + + + + + + Depth Anything 3 Backend Dashboard + + + +
+    <body>
+        <div class="container">
+            <header>
+                <h1>Depth Anything 3 Backend Dashboard</h1>
+                <p>Real-time monitoring of model status and inference tasks</p>
+            </header>
+
+            <div class="cards">
+                <div class="card">
+                    <h2>Model Status</h2>
+                    <div class="row">
+                        <span class="label">Status:</span>
+                        <span class="value">{'Online' if status['model_loaded'] else 'Offline'}</span>
+                    </div>
+                    <div class="row">
+                        <span class="label">Model Directory:</span>
+                        <span class="value">{status['model_dir']}</span>
+                    </div>
+                    <div class="row">
+                        <span class="label">Device:</span>
+                        <span class="value">{status['device']}</span>
+                    </div>
+                    <div class="row">
+                        <span class="label">Load Time:</span>
+                        <span class="value">{load_time_str}</span>
+                    </div>
+                    <div class="row">
+                        <span class="label">Uptime:</span>
+                        <span class="value">{uptime_str}</span>
+                    </div>
+                </div>
+
+                <div class="card">
+                    <h2>Task Summary</h2>
+                    <div class="row">
+                        <span class="label">Active Tasks:</span>
+                        <span class="value">{len(active_tasks)}</span>
+                    </div>
+                    <div class="row">
+                        <span class="label">Completed Tasks:</span>
+                        <span class="value">{len(completed_tasks)}</span>
+                    </div>
+                    <div class="row">
+                        <span class="label">Total Tasks:</span>
+                        <span class="value">{len(_tasks)}</span>
+                    </div>
+                </div>
+            </div>
+
+            <div class="section">
+                <h2>Active Tasks</h2>
+                <div class="updated">Last updated: {time.strftime('%Y-%m-%d %H:%M:%S')}</div>
+                {active_tasks_html}
+            </div>
+
+            <div class="section">
+                <h2>Recent Completed Tasks</h2>
+                {completed_tasks_html}
+            </div>
+        </div>
+    </body>
+    </html>
+ + + + + """ + + return HTMLResponse(html_content) + + @_app.get("/status") + async def get_status(): + """Get backend status with GPU memory information.""" + if _backend is None: + raise HTTPException(status_code=500, detail="Backend not initialized") + + status = _backend.get_status() + + # Add GPU memory information + gpu_memory = get_gpu_memory_info() + if gpu_memory: + status["gpu_memory"] = { + "total_gb": round(gpu_memory["total_gb"], 2), + "allocated_gb": round(gpu_memory["allocated_gb"], 2), + "reserved_gb": round(gpu_memory["reserved_gb"], 2), + "free_gb": round(gpu_memory["free_gb"], 2), + "utilization_percent": round(gpu_memory["utilization"], 1), + } + else: + status["gpu_memory"] = None + + return status + + @_app.post("/inference", response_model=InferenceResponse) + async def run_inference(request: InferenceRequest): + """Submit inference task and return task ID.""" + global _running_task_id + + if _backend is None: + raise HTTPException(status_code=500, detail="Backend not initialized") + + # Generate unique task ID + task_id = str(uuid.uuid4()) + + # Create task status + if _running_task_id is not None: + status_msg = f"[{task_id}] Task queued (waiting for {_running_task_id} to complete)" + else: + status_msg = f"[{task_id}] Task submitted" + + _tasks[task_id] = TaskStatus( + task_id=task_id, + status="pending", + message=status_msg, + created_at=time.time(), + export_dir=request.export_dir, + request=request, + # Record essential parameters + num_images=len(request.image_paths), + export_format=request.export_format, + process_res_method=request.process_res_method, + video_path=( + request.image_paths[0] if request.image_paths else None + ), # Use first image path as video reference + ) + + # Add task to queue + _task_queue.append(task_id) + + # If no task is running, start processing the queue + if _running_task_id is None: + _process_next_task() + + return InferenceResponse( + success=True, + message="Task submitted successfully", + task_id=task_id, + export_dir=request.export_dir, + export_format=request.export_format, + ) + + @_app.get("/task/{task_id}", response_model=TaskStatus) + async def get_task_status(task_id: str): + """Get task status by task ID.""" + if task_id not in _tasks: + raise HTTPException(status_code=404, detail="Task not found") + + return _tasks[task_id] + + @_app.get("/gpu-memory") + async def get_gpu_memory(): + """Get detailed GPU memory information.""" + gpu_memory = get_gpu_memory_info() + if gpu_memory is None: + return { + "available": False, + "message": "CUDA not available or memory info cannot be retrieved", + } + + return { + "available": True, + "total_gb": round(gpu_memory["total_gb"], 2), + "allocated_gb": round(gpu_memory["allocated_gb"], 2), + "reserved_gb": round(gpu_memory["reserved_gb"], 2), + "free_gb": round(gpu_memory["free_gb"], 2), + "utilization_percent": round(gpu_memory["utilization"], 1), + "status": ( + "healthy" + if gpu_memory["utilization"] < 80 + else "warning" if gpu_memory["utilization"] < 95 else "critical" + ), + } + + @_app.get("/tasks") + async def list_tasks(): + """List all tasks.""" + # Separate active and completed tasks + active_tasks = [task for task in _tasks.values() if task.status in ["pending", "running"]] + completed_tasks = [ + task for task in _tasks.values() if task.status in ["completed", "failed"] + ] + + return { + "tasks": list(_tasks.values()), + "active_tasks": active_tasks, + "completed_tasks": completed_tasks, + "active_count": len(active_tasks), + "total_count": len(_tasks), + } + + 
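+    # Client-side sketch (hypothetical host): submit a task, then poll its status.
+    #   resp = requests.post("http://127.0.0.1:8000/inference",
+    #                        json={"image_paths": ["/data/a.jpg"]}).json()
+    #   task = requests.get(f"http://127.0.0.1:8000/task/{resp['task_id']}").json()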
@_app.post("/cleanup") + async def manual_cleanup(): + """Manually trigger task cleanup.""" + try: + _cleanup_old_tasks() + return {"message": "Cleanup completed", "active_tasks": len(_tasks)} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}") + + @_app.delete("/task/{task_id}") + async def delete_task(task_id: str): + """Delete a specific task.""" + if task_id not in _tasks: + raise HTTPException(status_code=404, detail="Task not found") + + # Only allow deletion of completed/failed tasks + if _tasks[task_id].status not in ["completed", "failed"]: + raise HTTPException(status_code=400, detail="Cannot delete running or pending tasks") + + del _tasks[task_id] + return {"message": f"Task {task_id} deleted successfully"} + + @_app.post("/reload") + async def reload_model(): + """Reload the model.""" + if _backend is None: + raise HTTPException(status_code=500, detail="Backend not initialized") + + try: + _backend.model = None + _backend.model_loaded = False + _backend.load_model() + return {"message": "Model reloaded successfully"} + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to reload model: {str(e)}") + + # ============================================================================ + # Gallery routes + # ============================================================================ + + if _gallery_dir and os.path.exists(_gallery_dir): + # Load gallery HTML page (with modified paths for /gallery/ subdirectory) + _gallery_html = _load_gallery_html() + + @_app.get("/gallery/", response_class=HTMLResponse) + @_app.get("/gallery", response_class=HTMLResponse) + async def gallery_home(): + """Gallery home page.""" + return HTMLResponse(_gallery_html) + + @_app.get("/gallery/manifest.json") + async def gallery_manifest(): + """Get gallery group list.""" + try: + return build_group_list(_gallery_dir) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to build group list: {str(e)}" + ) + + @_app.get("/gallery/manifest/{group}.json") + async def gallery_group_manifest(group: str): + """Get manifest for a specific group.""" + if not _is_plain_name(group): + raise HTTPException(status_code=400, detail="Invalid group name") + try: + return build_group_manifest(_gallery_dir, group) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to build group manifest: {str(e)}" + ) + + @_app.get("/gallery/{path:path}") + async def gallery_files(path: str): + """Serve gallery static files (GLB, JPG, etc.).""" + # Security check: prevent directory traversal + path_parts = path.split("/") + if any(not _is_plain_name(part) for part in path_parts if part): + raise HTTPException(status_code=400, detail="Invalid path") + + file_path = os.path.join(_gallery_dir, *path_parts) + + # Ensure the file is within gallery directory + real_file_path = os.path.realpath(file_path) + real_gallery_dir = os.path.realpath(_gallery_dir) + if not real_file_path.startswith(real_gallery_dir): + raise HTTPException(status_code=403, detail="Access denied") + + if not os.path.exists(file_path) or not os.path.isfile(file_path): + raise HTTPException(status_code=404, detail="File not found") + + return FileResponse(file_path) + + return _app + + +def start_server( + model_dir: str, + device: str = "cuda", + host: str = "127.0.0.1", + port: int = 8000, + gallery_dir: Optional[str] = None, +): + """Start the backend server.""" + app = create_app(model_dir, device, gallery_dir) + + print("Starting Depth 
Anything 3 Backend...") + print(f"Model directory: {model_dir}") + print(f"Device: {device}") + print(f"Server: http://{host}:{port}") + print(f"Dashboard: http://{host}:{port}/dashboard") + print(f"API Status: http://{host}:{port}/status") + + if gallery_dir and os.path.exists(gallery_dir): + print(f"Gallery: http://{host}:{port}/gallery/") + + print("=" * 60) + print("Backend is running! You can now:") + print(f" β€’ Open home page: http://{host}:{port}") + print(f" β€’ Open dashboard: http://{host}:{port}/dashboard") + print(f" β€’ Check API status: http://{host}:{port}/status") + + if gallery_dir and os.path.exists(gallery_dir): + print(f" β€’ Browse gallery: http://{host}:{port}/gallery/") + + print(" β€’ Submit inference tasks via API") + print("=" * 60) + + uvicorn.run(app, host=host, port=port, log_level="info") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Depth Anything 3 Backend Server") + parser.add_argument("--model-dir", required=True, help="Model directory path") + parser.add_argument("--device", default="cuda", help="Device to use") + parser.add_argument("--host", default="127.0.0.1", help="Host to bind to") + parser.add_argument("--port", type=int, default=8000, help="Port to bind to") + parser.add_argument("--gallery-dir", help="Gallery directory path (optional)") + + args = parser.parse_args() + start_server(args.model_dir, args.device, args.host, args.port, args.gallery_dir) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/services/gallery.py b/Depth-Anything-3-anysize/src/depth_anything_3/services/gallery.py new file mode 100644 index 0000000000000000000000000000000000000000..f72bb5e5f6defbc24cf2278a53b7162a8ad5519d --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/services/gallery.py @@ -0,0 +1,806 @@ +#!/usr/bin/env python3 +# flake8: noqa: E501 +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Depth Anything 3 Gallery Server (two-level, single-file) +Now supports paginated depth preview (4 per page). +""" + +import argparse +import json +import mimetypes +import os +import posixpath +import sys +from functools import partial +from http import HTTPStatus +from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer +from urllib.parse import quote, unquote + +# ------------------------------ Embedded HTML ------------------------------ # + +HTML_PAGE = r""" + + + + Depth Anything 3 Gallery + + + + + + + +
+<body>
+    <header>
+        <h1>Depth Anything 3 Gallery</h1>
+        <div class="hint">Level 1 shows groups only; click a group to browse scenes and previews.</div>
+    </header>
+
+    <section class="hero">
+        <h2>🎯 Depth Anything 3 Gallery</h2>
+        <p>
+            Explore 3D reconstructions and depth visualizations from Depth Anything 3.
+            Browse through groups of scenes, preview 3D models, and examine depth maps interactively.
+        </p>
+    </section>
+
+    <div id="breadcrumb"></div>
+    <div id="grid"></div>
+
+    <script>
+    // Fetch the two-level manifests; card rendering is driven by these calls.
+    async function loadGroups(){
+        const res = await fetch('/manifest.json');
+        return (await res.json()).groups;
+    }
+    async function loadGroup(g){
+        const res = await fetch('/manifest/'+encodeURIComponent(g)+'.json');
+        return (await res.json()).items;
+    }
+    if(location.pathname!="/")history.replaceState(null,'','/'+location.search)
+    </script>
+</body>
+</html>
+ + + + + + + + +""" + +# ------------------------------ Utilities ------------------------------ # + +IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".bmp") + + +def _url_join(*parts: str) -> str: + norm = posixpath.join(*[p.replace("\\", "/") for p in parts]) + segs = [s for s in norm.split("/") if s not in ("", ".")] + return "/".join(quote(s) for s in segs) + + +def _is_plain_name(name: str) -> bool: + return all(c not in name for c in ("/", "\\")) and name not in (".", "..") + + +def build_group_list(root_dir: str) -> dict: + groups = [] + try: + for gname in sorted(os.listdir(root_dir)): + gpath = os.path.join(root_dir, gname) + if not os.path.isdir(gpath): + continue + has_scene = False + try: + for sname in os.listdir(gpath): + spath = os.path.join(gpath, sname) + if not os.path.isdir(spath): + continue + if os.path.exists(os.path.join(spath, "scene.glb")) and os.path.exists( + os.path.join(spath, "scene.jpg") + ): + has_scene = True + break + except Exception: + pass + if has_scene: + groups.append({"id": gname, "title": gname}) + except Exception as e: + print(f"[warn] build_group_list failed: {e}", file=sys.stderr) + return {"groups": groups} + + +def build_group_manifest(root_dir: str, group: str) -> dict: + items = [] + gpath = os.path.join(root_dir, group) + try: + if not os.path.isdir(gpath): + return {"group": group, "items": []} + for sname in sorted(os.listdir(gpath)): + spath = os.path.join(gpath, sname) + if not os.path.isdir(spath): + continue + glb_fs = os.path.join(spath, "scene.glb") + jpg_fs = os.path.join(spath, "scene.jpg") + if not (os.path.exists(glb_fs) and os.path.exists(jpg_fs)): + continue + depth_images = [] + dpath = os.path.join(spath, "depth_vis") + if os.path.isdir(dpath): + files = [ + f for f in os.listdir(dpath) if os.path.splitext(f)[1].lower() in IMAGE_EXTS + ] + for fn in sorted(files): + depth_images.append("/" + _url_join(group, sname, "depth_vis", fn)) + items.append( + { + "id": sname, + "title": sname, + "model": "/" + _url_join(group, sname, "scene.glb"), + "thumbnail": "/" + _url_join(group, sname, "scene.jpg"), + "depth_images": depth_images, + } + ) + except Exception as e: + print(f"[warn] build_group_manifest failed for {group}: {e}", file=sys.stderr) + return {"group": group, "items": items} + + +class GalleryHandler(SimpleHTTPRequestHandler): + def __init__(self, *args, directory=None, **kwargs): + super().__init__(*args, directory=directory, **kwargs) + + def do_GET(self): + if self.path in ("/", "/index.html") or self.path.startswith("/?"): + content = HTML_PAGE.encode("utf-8") + self.send_response(HTTPStatus.OK) + self.send_header("Content-Type", "text/html; charset=utf-8") + self.send_header("Content-Length", str(len(content))) + self.send_header("Cache-Control", "no-store") + self.end_headers() + self.wfile.write(content) + return + if self.path == "/manifest.json": + data = json.dumps( + build_group_list(self.directory), ensure_ascii=False, indent=2 + ).encode("utf-8") + self.send_response(HTTPStatus.OK) + self.send_header("Content-Type", "application/json; charset=utf-8") + self.send_header("Content-Length", str(len(data))) + self.send_header("Cache-Control", "no-store") + self.end_headers() + self.wfile.write(data) + return + if self.path.startswith("/manifest/") and self.path.endswith(".json"): + group_enc = self.path[len("/manifest/") : -len(".json")] + try: + group = unquote(group_enc) + except Exception: + group = group_enc + if not _is_plain_name(group): + self.send_error(HTTPStatus.BAD_REQUEST, "Invalid group name") + 
return + data = json.dumps( + build_group_manifest(self.directory, group), ensure_ascii=False, indent=2 + ).encode("utf-8") + self.send_response(HTTPStatus.OK) + self.send_header("Content-Type", "application/json; charset=utf-8") + self.send_header("Content-Length", str(len(data))) + self.send_header("Cache-Control", "no-store") + self.end_headers() + self.wfile.write(data) + return + if self.path == "/favicon.ico": + self.send_response(HTTPStatus.NO_CONTENT) + self.end_headers() + return + return super().do_GET() + + def list_directory(self, path): + self.send_error(HTTPStatus.NOT_FOUND, "Directory listing disabled") + return None + + +def gallery(): + parser = argparse.ArgumentParser( + description="Depth Anything 3 Gallery Server (two-level, with pagination)" + ) + parser.add_argument( + "-d", "--dir", required=True, help="Gallery root directory (two-level: group/scene)" + ) + parser.add_argument("-p", "--port", type=int, default=8000, help="Port (default 8000)") + parser.add_argument("--host", default="127.0.0.1", help="Host address (default 127.0.0.1)") + parser.add_argument("--open", action="store_true", help="Open browser after launch") + args = parser.parse_args() + + root_dir = os.path.abspath(args.dir) + if not os.path.isdir(root_dir): + print(f"[error] Directory not found: {root_dir}", file=sys.stderr) + sys.exit(1) + + Handler = partial(GalleryHandler, directory=root_dir) + server = ThreadingHTTPServer((args.host, args.port), Handler) + + addr = f"http://{args.host}:{args.port}/" + print(f"[info] Serving gallery from: {root_dir}") + print(f"[info] Open: {addr}") + + if args.open: + try: + import webbrowser + + webbrowser.open(addr) + except Exception as e: + print(f"[warn] Failed to open browser: {e}", file=sys.stderr) + + try: + server.serve_forever() + except KeyboardInterrupt: + print("\n[info] Shutting down...") + finally: + server.server_close() + + +def main(): + """Main entry point for gallery server.""" + mimetypes.add_type("model/gltf-binary", ".glb") + gallery() + + +if __name__ == "__main__": + main() diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/services/inference_service.py b/Depth-Anything-3-anysize/src/depth_anything_3/services/inference_service.py new file mode 100644 index 0000000000000000000000000000000000000000..07ca1657a43c7407bf8eaee080adde466480e417 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/services/inference_service.py @@ -0,0 +1,225 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
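+#
+# Usage sketch (model id is illustrative):
+#   from depth_anything_3.services.inference_service import run_inference
+#   run_inference(["a.jpg", "b.jpg"], export_dir="out", model_dir="depth-anything/DA3-LARGE")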
+ +""" +Unified Inference Service +Provides unified interface for local and remote inference +""" + +from typing import Any, Dict, List, Optional, Union +import numpy as np +import requests +import typer + +from ..api import DepthAnything3 + + +class InferenceService: + """Unified inference service class""" + + def __init__(self, model_dir: str, device: str = "cuda"): + self.model_dir = model_dir + self.device = device + self.model = None + + def load_model(self): + """Load model""" + if self.model is None: + typer.echo(f"Loading model from {self.model_dir}...") + self.model = DepthAnything3.from_pretrained(self.model_dir).to(self.device) + return self.model + + def run_local_inference( + self, + image_paths: List[str], + export_dir: str, + export_format: str = "mini_npz-glb", + process_res: Optional[int] = None, + process_res_method: str = "keep", + export_feat_layers: List[int] = None, + extrinsics: Optional[np.ndarray] = None, + intrinsics: Optional[np.ndarray] = None, + align_to_input_ext_scale: bool = True, + conf_thresh_percentile: float = 40.0, + num_max_points: int = 1_000_000, + show_cameras: bool = True, + feat_vis_fps: int = 15, + ) -> Any: + """Run local inference""" + if export_feat_layers is None: + export_feat_layers = [] + + model = self.load_model() + + # Prepare inference parameters + inference_kwargs = { + "image": image_paths, + "export_dir": export_dir, + "export_format": export_format, + "process_res": process_res, + "process_res_method": process_res_method, + "export_feat_layers": export_feat_layers, + "align_to_input_ext_scale": align_to_input_ext_scale, + "conf_thresh_percentile": conf_thresh_percentile, + "num_max_points": num_max_points, + "show_cameras": show_cameras, + "feat_vis_fps": feat_vis_fps, + } + + # Add pose data (if exists) + if extrinsics is not None: + inference_kwargs["extrinsics"] = extrinsics + if intrinsics is not None: + inference_kwargs["intrinsics"] = intrinsics + + # Run inference + typer.echo(f"Running inference on {len(image_paths)} images...") + prediction = model.inference(**inference_kwargs) + + typer.echo(f"Results saved to {export_dir}") + typer.echo(f"Export format: {export_format}") + + return prediction + + def run_backend_inference( + self, + image_paths: List[str], + export_dir: str, + backend_url: str, + export_format: str = "mini_npz-glb", + process_res: Optional[int] = None, + process_res_method: str = "keep", + export_feat_layers: List[int] = None, + extrinsics: Optional[np.ndarray] = None, + intrinsics: Optional[np.ndarray] = None, + align_to_input_ext_scale: bool = True, + conf_thresh_percentile: float = 40.0, + num_max_points: int = 1_000_000, + show_cameras: bool = True, + feat_vis_fps: int = 15, + ) -> Dict[str, Any]: + """Run backend inference""" + if export_feat_layers is None: + export_feat_layers = [] + + # Check backend status + if not self._check_backend_status(backend_url): + raise typer.BadParameter(f"Backend service is not running at {backend_url}") + + # Prepare payload + payload = { + "image_paths": image_paths, + "export_dir": export_dir, + "export_format": export_format, + "process_res": process_res, + "process_res_method": process_res_method, + "export_feat_layers": export_feat_layers, + "align_to_input_ext_scale": align_to_input_ext_scale, + "conf_thresh_percentile": conf_thresh_percentile, + "num_max_points": num_max_points, + "show_cameras": show_cameras, + "feat_vis_fps": feat_vis_fps, + } + + # Add pose data (if exists) + if extrinsics is not None: + payload["extrinsics"] = 
[ext.astype(np.float64).tolist() for ext in extrinsics] + if intrinsics is not None: + payload["intrinsics"] = [intr.astype(np.float64).tolist() for intr in intrinsics] + + # Submit task + typer.echo("Submitting inference task to backend...") + try: + response = requests.post(f"{backend_url}/inference", json=payload, timeout=30) + response.raise_for_status() + result = response.json() + + if result["success"]: + task_id = result["task_id"] + typer.echo("Task submitted successfully!") + typer.echo(f"Task ID: {task_id}") + typer.echo(f"Results will be saved to: {export_dir}") + typer.echo(f"Check backend logs for progress updates with task ID: {task_id}") + return result + else: + raise typer.BadParameter( + f"Backend inference submission failed: {result['message']}" + ) + except requests.exceptions.RequestException as e: + raise typer.BadParameter(f"Backend inference submission failed: {e}") + + def _check_backend_status(self, backend_url: str) -> bool: + """Check backend status""" + try: + response = requests.get(f"{backend_url}/status", timeout=5) + return response.status_code == 200 + except Exception: + return False + + +def run_inference( + image_paths: List[str], + export_dir: str, + model_dir: str, + device: str = "cuda", + backend_url: Optional[str] = None, + export_format: str = "mini_npz-glb", + process_res: Optional[int] = None, + process_res_method: str = "keep", + export_feat_layers: List[int] = None, + extrinsics: Optional[np.ndarray] = None, + intrinsics: Optional[np.ndarray] = None, + align_to_input_ext_scale: bool = True, + conf_thresh_percentile: float = 40.0, + num_max_points: int = 1_000_000, + show_cameras: bool = True, + feat_vis_fps: int = 15, +) -> Union[Any, Dict[str, Any]]: + """Unified inference interface""" + + service = InferenceService(model_dir, device) + + if backend_url: + return service.run_backend_inference( + image_paths=image_paths, + export_dir=export_dir, + backend_url=backend_url, + export_format=export_format, + process_res=process_res, + process_res_method=process_res_method, + export_feat_layers=export_feat_layers, + extrinsics=extrinsics, + intrinsics=intrinsics, + align_to_input_ext_scale=align_to_input_ext_scale, + conf_thresh_percentile=conf_thresh_percentile, + num_max_points=num_max_points, + show_cameras=show_cameras, + feat_vis_fps=feat_vis_fps, + ) + else: + return service.run_local_inference( + image_paths=image_paths, + export_dir=export_dir, + export_format=export_format, + process_res=process_res, + process_res_method=process_res_method, + export_feat_layers=export_feat_layers, + extrinsics=extrinsics, + intrinsics=intrinsics, + align_to_input_ext_scale=align_to_input_ext_scale, + conf_thresh_percentile=conf_thresh_percentile, + num_max_points=num_max_points, + show_cameras=show_cameras, + feat_vis_fps=feat_vis_fps, + ) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/services/input_handlers.py b/Depth-Anything-3-anysize/src/depth_anything_3/services/input_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0536b51e5c71bb6fb9d10f8e74fbb12fc42d31 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/services/input_handlers.py @@ -0,0 +1,266 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Input Processing Service +Handles different types of inputs (image, images, colmap, video) +""" + +import glob +import os +from typing import List, Tuple +import cv2 +import numpy as np +import typer + +from ..utils.read_write_model import read_model + + +class InputHandler: + """Base input handler class""" + + @staticmethod + def validate_path(path: str, path_type: str = "file") -> str: + """Validate path""" + if not os.path.exists(path): + raise typer.BadParameter(f"{path_type} not found: {path}") + return path + + @staticmethod + def handle_export_dir(export_dir: str, auto_cleanup: bool = False) -> str: + """Handle export directory""" + if os.path.exists(export_dir): + if auto_cleanup: + typer.echo(f"Auto-cleaning existing export directory: {export_dir}") + import shutil + + shutil.rmtree(export_dir) + os.makedirs(export_dir, exist_ok=True) + else: + typer.echo(f"Export directory '{export_dir}' already exists.") + if typer.confirm("Do you want to clean it and continue?"): + import shutil + + shutil.rmtree(export_dir) + os.makedirs(export_dir, exist_ok=True) + typer.echo(f"Cleaned export directory: {export_dir}") + else: + typer.echo("Operation cancelled.") + raise typer.Exit(0) + else: + os.makedirs(export_dir, exist_ok=True) + return export_dir + + +class ImageHandler(InputHandler): + """Single image handler""" + + @staticmethod + def process(image_path: str) -> List[str]: + """Process single image""" + InputHandler.validate_path(image_path, "Image file") + return [image_path] + + +class ImagesHandler(InputHandler): + """Image directory handler""" + + @staticmethod + def process(images_dir: str, image_extensions: str = "png,jpg,jpeg") -> List[str]: + """Process image directory""" + InputHandler.validate_path(images_dir, "Images directory") + + # Parse extensions + extensions = [ext.strip().lower() for ext in image_extensions.split(",")] + extensions = [ext if ext.startswith(".") else f".{ext}" for ext in extensions] + + # Find image files + image_files = [] + for ext in extensions: + pattern = f"*{ext}" + image_files.extend(glob.glob(os.path.join(images_dir, pattern))) + image_files.extend(glob.glob(os.path.join(images_dir, pattern.upper()))) + + image_files = sorted(list(set(image_files))) # Remove duplicates and sort + + if not image_files: + raise typer.BadParameter( + f"No image files found in {images_dir} with extensions: {extensions}" + ) + + typer.echo(f"Found {len(image_files)} images to process") + return image_files + + +class ColmapHandler(InputHandler): + """COLMAP data handler""" + + @staticmethod + def process( + colmap_dir: str, sparse_subdir: str = "" + ) -> Tuple[List[str], np.ndarray, np.ndarray]: + """Process COLMAP data""" + InputHandler.validate_path(colmap_dir, "COLMAP directory") + + # Build paths + images_dir = os.path.join(colmap_dir, "images") + if sparse_subdir: + sparse_dir = os.path.join(colmap_dir, "sparse", sparse_subdir) + else: + sparse_dir = os.path.join(colmap_dir, "sparse") + + InputHandler.validate_path(images_dir, "Images directory") + InputHandler.validate_path(sparse_dir, "Sparse reconstruction directory") + + # Load COLMAP data + 
typer.echo("Loading COLMAP reconstruction data...") + try: + cameras, images, points3D = read_model(sparse_dir) + + typer.echo( + f"Loaded COLMAP data: {len(cameras)} cameras, {len(images)} images, " + f"{len(points3D)} 3D points." + ) + + # Get image files and pose data + image_files = [] + extrinsics = [] + intrinsics = [] + + for image_id, image_data in images.items(): + image_name = image_data.name + image_path = os.path.join(images_dir, image_name) + + if os.path.exists(image_path): + image_files.append(image_path) + + # Get camera parameters + camera = cameras[image_data.camera_id] + + # Convert quaternion to rotation matrix + R = image_data.qvec2rotmat() + t = image_data.tvec + + # Create extrinsic matrix (world to camera) + extrinsic = np.eye(4) + extrinsic[:3, :3] = R + extrinsic[:3, 3] = t + extrinsics.append(extrinsic) + + # Create intrinsic matrix + if camera.model == "PINHOLE": + fx, fy, cx, cy = camera.params + elif camera.model == "SIMPLE_PINHOLE": + f, cx, cy = camera.params + fx = fy = f + else: + # For other models, use basic pinhole approximation + fx = fy = camera.params[0] if len(camera.params) > 0 else 1000 + cx = camera.width / 2 + cy = camera.height / 2 + + intrinsic = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + intrinsics.append(intrinsic) + + if not image_files: + raise typer.BadParameter("No valid images found in COLMAP data") + + typer.echo(f"Found {len(image_files)} valid images with pose data") + + return image_files, np.array(extrinsics), np.array(intrinsics) + + except Exception as e: + raise typer.BadParameter(f"Failed to load COLMAP data: {e}") + + +class VideoHandler(InputHandler): + """Video handler""" + + @staticmethod + def process(video_path: str, output_dir: str, fps: float = 1.0) -> List[str]: + """Process video, extract frames""" + InputHandler.validate_path(video_path, "Video file") + + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise typer.BadParameter(f"Cannot open video: {video_path}") + + # Get video properties + video_fps = cap.get(cv2.CAP_PROP_FPS) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + duration = total_frames / video_fps + + # Calculate frame interval (ensure at least 1) + frame_interval = max(1, int(video_fps / fps)) + actual_fps = video_fps / frame_interval + + typer.echo(f"Video FPS: {video_fps:.2f}, Duration: {duration:.2f}s") + + # Warn if requested FPS is higher than video FPS + if fps > video_fps: + typer.echo( + f"⚠️ Warning: Requested sampling FPS ({fps:.2f}) exceeds video FPS ({video_fps:.2f})", # noqa: E501 + err=True, + ) + typer.echo( + f"⚠️ Using maximum available FPS: {actual_fps:.2f} (extracting every frame)", + err=True, + ) + + typer.echo(f"Extracting frames at {actual_fps:.2f} FPS (every {frame_interval} frame(s))") + + # Create output directory + frames_dir = os.path.join(output_dir, "input_images") + os.makedirs(frames_dir, exist_ok=True) + + frame_count = 0 + saved_count = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + + if frame_count % frame_interval == 0: + frame_path = os.path.join(frames_dir, f"{saved_count:06d}.png") + cv2.imwrite(frame_path, frame) + saved_count += 1 + + frame_count += 1 + + cap.release() + typer.echo(f"Extracted {saved_count} frames to {frames_dir}") + + # Get frame file list + frame_files = sorted( + [f for f in os.listdir(frames_dir) if f.endswith((".png", ".jpg", ".jpeg"))] + ) + if not frame_files: + raise typer.BadParameter("No frames extracted from video") + + return [os.path.join(frames_dir, f) for f in frame_files] + + 
+def parse_export_feat(export_feat_str: str) -> List[int]: + """Parse export_feat parameter""" + if not export_feat_str: + return [] + + try: + return [int(x.strip()) for x in export_feat_str.split(",") if x.strip()] + except ValueError: + raise typer.BadParameter( + f"Invalid export_feat format: {export_feat_str}. " + "Use comma-separated integers like '0,1,2'" + ) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/specs.py b/Depth-Anything-3-anysize/src/depth_anything_3/specs.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5b30255e9fd48d988ff00c896fb3dbadf197ea --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/specs.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional +import numpy as np +import torch + + +@dataclass +class Gaussians: + """3DGS parameters, all in world space""" + + means: torch.Tensor # world points, "batch gaussian dim" + scales: torch.Tensor # scales_std, "batch gaussian 3" + rotations: torch.Tensor # world_quat_wxyz, "batch gaussian 4" + harmonics: torch.Tensor # world SH, "batch gaussian 3 d_sh" + opacities: torch.Tensor # opacity | opacity SH, "batch gaussian" | "batch gaussian 1 d_sh" + + +@dataclass +class Prediction: + depth: np.ndarray # N, H, W + is_metric: int + sky: np.ndarray | None = None # N, H, W + conf: np.ndarray | None = None # N, H, W + extrinsics: np.ndarray | None = None # N, 4, 4 + intrinsics: np.ndarray | None = None # N, 3, 3 + processed_images: np.ndarray | None = None # N, H, W, 3 - processed images for visualization + gaussians: Gaussians | None = None # 3D gaussians + aux: dict[str, Any] = None # + scale_factor: Optional[float] = None # metric scale diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/alignment.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/alignment.py new file mode 100644 index 0000000000000000000000000000000000000000..42f8e6571a8a49b17e1cf85175461c75f9fb1e80 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/alignment.py @@ -0,0 +1,161 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Alignment utilities for depth estimation and metric scaling. 
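+
+The scalar alignment here uses the closed-form least-squares fit
+s = <a, b> / <b, b>, i.e. the scale s that minimises ||a - s * b||^2.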
+""" + +from typing import Tuple +import torch + + +def least_squares_scale_scalar( + a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12 +) -> torch.Tensor: + """ + Compute least squares scale factor s such that a β‰ˆ s * b. + + Args: + a: First tensor + b: Second tensor + eps: Small epsilon for numerical stability + + Returns: + Scalar tensor containing the scale factor + + Raises: + ValueError: If tensors have mismatched shapes or devices + TypeError: If tensors are not floating point + """ + if a.shape != b.shape: + raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}") + if a.device != b.device: + raise ValueError(f"Device mismatch: {a.device} vs {b.device}") + if not a.is_floating_point() or not b.is_floating_point(): + raise TypeError("Tensors must be floating point type") + + # Compute dot products for least squares solution + num = torch.dot(a.reshape(-1), b.reshape(-1)) + den = torch.dot(b.reshape(-1), b.reshape(-1)).clamp_min(eps) + return num / den + + +def compute_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor: + """ + Compute non-sky mask from sky prediction. + + Args: + sky_prediction: Sky prediction tensor + threshold: Threshold for sky classification + + Returns: + Boolean mask where True indicates non-sky regions + """ + return sky_prediction < threshold + + +def compute_alignment_mask( + depth_conf: torch.Tensor, + non_sky_mask: torch.Tensor, + depth: torch.Tensor, + metric_depth: torch.Tensor, + median_conf: torch.Tensor, + min_depth_threshold: float = 1e-3, + min_metric_depth_threshold: float = 1e-2, +) -> torch.Tensor: + """ + Compute mask for depth alignment based on confidence and depth thresholds. + + Args: + depth_conf: Depth confidence tensor + non_sky_mask: Non-sky region mask + depth: Predicted depth tensor + metric_depth: Metric depth tensor + median_conf: Median confidence threshold + min_depth_threshold: Minimum depth threshold + min_metric_depth_threshold: Minimum metric depth threshold + + Returns: + Boolean mask for valid alignment regions + """ + return ( + (depth_conf >= median_conf) + & non_sky_mask + & (metric_depth > min_metric_depth_threshold) + & (depth > min_depth_threshold) + ) + + +def sample_tensor_for_quantile(tensor: torch.Tensor, max_samples: int = 100000) -> torch.Tensor: + """ + Sample tensor elements for quantile computation to reduce memory usage. + + Args: + tensor: Input tensor to sample + max_samples: Maximum number of samples to take + + Returns: + Sampled tensor + """ + if tensor.numel() <= max_samples: + return tensor + + idx = torch.randperm(tensor.numel(), device=tensor.device)[:max_samples] + return tensor.flatten()[idx] + + +def apply_metric_scaling( + depth: torch.Tensor, intrinsics: torch.Tensor, scale_factor: float = 300.0 +) -> torch.Tensor: + """ + Apply metric scaling to depth based on camera intrinsics. + + Args: + depth: Input depth tensor + intrinsics: Camera intrinsics tensor + scale_factor: Scaling factor for metric conversion + + Returns: + Scaled depth tensor + """ + focal_length = (intrinsics[:, :, 0, 0] + intrinsics[:, :, 1, 1]) / 2 + return depth * (focal_length[:, :, None, None] / scale_factor) + + +def set_sky_regions_to_max_depth( + depth: torch.Tensor, + depth_conf: torch.Tensor, + non_sky_mask: torch.Tensor, + max_depth: float = 200.0, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Set sky regions to maximum depth and high confidence. 
+ + Args: + depth: Depth tensor + depth_conf: Depth confidence tensor + non_sky_mask: Non-sky region mask + max_depth: Maximum depth value for sky regions + + Returns: + Tuple of (updated_depth, updated_depth_conf) + """ + depth = depth.clone() + depth_conf = depth_conf.clone() + + # Set sky regions to max depth and high confidence + depth[~non_sky_mask] = max_depth + depth_conf[~non_sky_mask] = 1.0 + + return depth, depth_conf diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/api_helpers.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/api_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..b327331d9ec61a3047be3e330f41156eb124adc4 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/api_helpers.py @@ -0,0 +1,58 @@ +import argparse + + +def parse_scalar(s): + if not isinstance(s, str): + return s + t = s.strip() + l = t.lower() + if l == "true": + return True + if l == "false": + return False + if l in ("none", "null"): + return None + try: + return int(t, 10) + except Exception: + pass + try: + return float(t) + except Exception: + return s + + +def fn_kv_csv(s: str) -> dict[str, dict[str, object]]: + """ + Parse a string of comma-separated triplets: fn:key:value + + Returns: + dict[fn_name] -> dict[key] = parsed_value + + Example: + "fn1:width:1920,fn1:height:1080,fn2:quality:0.8" + -> {"fn1": {"width": 1920, "height": 1080}, "fn2": {"quality": 0.8}} + """ + result: dict[str, dict[str, object]] = {} + if not s: + return result + + for item in s.split(","): + if not item: + continue + parts = item.split(":", 2) # allow value to contain ":" beyond first two separators + if len(parts) < 3: + raise argparse.ArgumentTypeError(f"Bad item '{item}', expected FN:KEY:VALUE") + fn, key, raw_val = parts[0], parts[1], parts[2] + # If you need to allow colons in values, join leftover parts: + # fn, key, raw_val = parts[0], parts[1], ":".join(parts[2:]) + + if not fn: + raise argparse.ArgumentTypeError(f"Bad item '{item}': empty function name") + if not key: + raise argparse.ArgumentTypeError(f"Bad item '{item}': empty key") + + val = parse_scalar(raw_val) + bucket = result.setdefault(fn, {}) + bucket[key] = val + return result diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/camera_trj_helpers.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/camera_trj_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..83624f359553908abd28af2d59f2066a7f7a7b15 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/camera_trj_helpers.py @@ -0,0 +1,479 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from einops import einsum, rearrange, reduce + +try: + from scipy.spatial.transform import Rotation as R +except ImportError: + from depth_anything_3.utils.logger import logger + + logger.warn("Dependency 'scipy' not found. 
Required for interpolating camera trajectory.") + +from depth_anything_3.utils.geometry import as_homogeneous + + +@torch.no_grad() +def render_stabilization_path(poses, k_size=45): + """Rendering stabilized camera path. + poses: [batch, 4, 4] or [batch, 3, 4], + return: + smooth path: [batch 4 4]""" + num_frames = poses.shape[0] + device = poses.device + dtype = poses.dtype + + # Early exit for trivial cases + if num_frames <= 1: + return as_homogeneous(poses) + + # Make k_size safe: positive odd and not larger than num_frames + # 1) Ensure odd + if k_size < 1: + k_size = 1 + if k_size % 2 == 0: + k_size += 1 + # 2) Cap to num_frames (keep odd) + max_odd = num_frames if (num_frames % 2 == 1) else (num_frames - 1) + if max_odd < 1: + max_odd = 1 # covers num_frames == 0 theoretically + k_size = min(k_size, max_odd) + # 3) enforce a minimum of 3 when possible (for better smoothing) + if num_frames >= 3 and k_size < 3: + k_size = 3 + + input_poses = [] + for i in range(num_frames): + input_poses.append( + torch.cat([poses[i, :3, 0:1], poses[i, :3, 1:2], poses[i, :3, 3:4]], dim=-1) + ) + input_poses = torch.stack(input_poses) # (num_frames, 3, 3) + + # Prepare Gaussian kernel + gaussian_kernel = cv2.getGaussianKernel(ksize=k_size, sigma=-1).astype(np.float32).squeeze() + gaussian_kernel = torch.tensor(gaussian_kernel, dtype=dtype, device=device).view(1, 1, -1) + pad = k_size // 2 + + output_vectors = [] + for idx in range(3): # For r1, r2, t + vec = ( + input_poses[:, :, idx].T.unsqueeze(0).unsqueeze(0) + ) # (1, 1, 3, num_frames) -> (1, 1, 3, num_frames) + # But actually, we want (batch=3, channel=1, width=num_frames) + # So: + vec = input_poses[:, :, idx].T.unsqueeze(1) # (3, 1, num_frames) + vec_padded = F.pad(vec, (pad, pad), mode="reflect") + filtered = F.conv1d(vec_padded, gaussian_kernel) + output_vectors.append(filtered.squeeze(1).T) # (num_frames, 3) + + output_r1, output_r2, output_t = output_vectors # Each is (num_frames, 3) + + # Normalize r1 and r2 + output_r1 = output_r1 / output_r1.norm(dim=-1, keepdim=True) + output_r2 = output_r2 / output_r2.norm(dim=-1, keepdim=True) + + output_poses = [] + for i in range(num_frames): + output_r3 = torch.linalg.cross(output_r1[i], output_r2[i]) + render_pose = torch.cat( + [ + output_r1[i].unsqueeze(-1), + output_r2[i].unsqueeze(-1), + output_r3.unsqueeze(-1), + output_t[i].unsqueeze(-1), + ], + dim=-1, + ) + output_poses.append(render_pose[:3, :]) + output_poses = as_homogeneous(torch.stack(output_poses, dim=0)) + + return output_poses + + +@torch.no_grad() +def render_wander_path( + cam2world: torch.Tensor, + intrinsic: torch.Tensor, + h: int, + w: int, + num_frames: int = 120, + max_disp: float = 48.0, +): + device, dtype = cam2world.device, cam2world.dtype + fx = intrinsic[0, 0] * w + r = max_disp / fx + th = torch.linspace(0, 2.0 * torch.pi, steps=num_frames, device=device, dtype=dtype) + x = r * torch.sin(th) + yz = r * torch.cos(th) / 3.0 + T = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(num_frames, 1, 1) + T[:, :3, 3] = torch.stack([x, yz, yz], dim=-1) * -1.0 + c2ws = cam2world.unsqueeze(0) @ T + # Start at reference pose and end back at reference pose + c2ws = torch.cat([cam2world.unsqueeze(0), c2ws, cam2world.unsqueeze(0)], dim=0) + Ks = intrinsic.unsqueeze(0).repeat(c2ws.shape[0], 1, 1) + return c2ws, Ks + + +@torch.no_grad() +def render_dolly_zoom_path( + cam2world: torch.Tensor, + intrinsic: torch.Tensor, + h: int, + w: int, + num_frames: int = 120, + max_disp: float = 0.1, + D_focus: float = 10.0, +): + device, 
dtype = cam2world.device, cam2world.dtype + fx0, fy0 = intrinsic[0, 0] * w, intrinsic[1, 1] * h + t = torch.linspace(0.0, 2.0, steps=num_frames, device=device, dtype=dtype) + z = 0.5 * (1.0 - torch.cos(torch.pi * t)) * max_disp + T = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(num_frames, 1, 1) + T[:, 2, 3] = -z + c2ws = cam2world.unsqueeze(0) @ T + Df = torch.as_tensor(D_focus, device=device, dtype=dtype) + scale = (Df / (Df + z)).clamp(min=1e-6) + Ks = intrinsic.unsqueeze(0).repeat(num_frames, 1, 1) + Ks[:, 0, 0] = (fx0 * scale) / w + Ks[:, 1, 1] = (fy0 * scale) / h + return c2ws, Ks + + +@torch.no_grad() +def interpolate_intrinsics( + initial: torch.Tensor, # "*#batch 3 3" + final: torch.Tensor, # "*#batch 3 3" + t: torch.Tensor, # " time_step" +) -> torch.Tensor: # "*batch time_step 3 3" + initial = rearrange(initial, "... i j -> ... () i j") + final = rearrange(final, "... i j -> ... () i j") + t = rearrange(t, "t -> t () ()") + return initial + (final - initial) * t + + +def intersect_rays( + a_origins: torch.Tensor, # "*#batch dim" + a_directions: torch.Tensor, # "*#batch dim" + b_origins: torch.Tensor, # "*#batch dim" + b_directions: torch.Tensor, # "*#batch dim" +) -> torch.Tensor: # "*batch dim" + """Compute the least-squares intersection of rays. Uses the math from here: + https://math.stackexchange.com/a/1762491/286022 + """ + + # Broadcast and stack the tensors. + a_origins, a_directions, b_origins, b_directions = torch.broadcast_tensors( + a_origins, a_directions, b_origins, b_directions + ) + origins = torch.stack((a_origins, b_origins), dim=-2) + directions = torch.stack((a_directions, b_directions), dim=-2) + + # Compute n_i * n_i^T - eye(3) from the equation. + n = einsum(directions, directions, "... n i, ... n j -> ... n i j") + n = n - torch.eye(3, dtype=origins.dtype, device=origins.device) + + # Compute the left-hand side of the equation. + lhs = reduce(n, "... n i j -> ... i j", "sum") + + # Compute the right-hand side of the equation. + rhs = einsum(n, origins, "... n i j, ... n j -> ... n i") + rhs = reduce(rhs, "... n i -> ... i", "sum") + + # Left-matrix-multiply both sides by the inverse of lhs to find p. + return torch.linalg.lstsq(lhs, rhs).solution + + +def normalize(a: torch.Tensor) -> torch.Tensor: # "*#batch dim" -> "*#batch dim" + return a / a.norm(dim=-1, keepdim=True) + + +def generate_coordinate_frame( + y: torch.Tensor, # "*#batch 3" + z: torch.Tensor, # "*#batch 3" +) -> torch.Tensor: # "*batch 3 3" + """Generate a coordinate frame given perpendicular, unit-length Y and Z vectors.""" + y, z = torch.broadcast_tensors(y, z) + return torch.stack([y.cross(z, dim=-1), y, z], dim=-1) + + +def generate_rotation_coordinate_frame( + a: torch.Tensor, # "*#batch 3" + b: torch.Tensor, # "*#batch 3" + eps: float = 1e-4, +) -> torch.Tensor: # "*batch 3 3" + """Generate a coordinate frame where the Y direction is normal to the plane defined + by unit vectors a and b. The other axes are arbitrary.""" + device = a.device + + # Replace every entry in b that's parallel to the corresponding entry in a with an + # arbitrary vector. + b = b.detach().clone() + parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps + b[parallel] = torch.tensor([0, 0, 1], dtype=b.dtype, device=device) + parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps + b[parallel] = torch.tensor([0, 1, 0], dtype=b.dtype, device=device) + + # Generate the coordinate frame. The initial cross product defines the plane. 
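+    # normalize(cross(a, b)) is the unit normal of that plane; it becomes the Y
+    # axis of the returned frame, while `a` itself is used as the Z axis.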
+ return generate_coordinate_frame(normalize(torch.linalg.cross(a, b)), a) + + +def matrix_to_euler( + rotations: torch.Tensor, # "*batch 3 3" + pattern: str, +) -> torch.Tensor: # "*batch 3" + *batch, _, _ = rotations.shape + rotations = rotations.reshape(-1, 3, 3) + angles_np = R.from_matrix(rotations.detach().cpu().numpy()).as_euler(pattern) + rotations = torch.tensor(angles_np, dtype=rotations.dtype, device=rotations.device) + return rotations.reshape(*batch, 3) + + +def euler_to_matrix( + rotations: torch.Tensor, # "*batch 3" + pattern: str, +) -> torch.Tensor: # "*batch 3 3" + *batch, _ = rotations.shape + rotations = rotations.reshape(-1, 3) + matrix_np = R.from_euler(pattern, rotations.detach().cpu().numpy()).as_matrix() + rotations = torch.tensor(matrix_np, dtype=rotations.dtype, device=rotations.device) + return rotations.reshape(*batch, 3, 3) + + +def extrinsics_to_pivot_parameters( + extrinsics: torch.Tensor, # "*#batch 4 4" + pivot_coordinate_frame: torch.Tensor, # "*#batch 3 3" + pivot_point: torch.Tensor, # "*#batch 3" +) -> torch.Tensor: # "*batch 5" + """Convert the extrinsics to a representation with 5 degrees of freedom: + 1. Distance from pivot point in the "X" (look cross pivot axis) direction. + 2. Distance from pivot point in the "Y" (pivot axis) direction. + 3. Distance from pivot point in the Z (look) direction + 4. Angle in plane + 5. Twist (rotation not in plane) + """ + + # The pivot coordinate frame's Z axis is normal to the plane. + pivot_axis = pivot_coordinate_frame[..., :, 1] + + # Compute the translation elements of the pivot parametrization. + translation_frame = generate_coordinate_frame(pivot_axis, extrinsics[..., :3, 2]) + origin = extrinsics[..., :3, 3] + delta = pivot_point - origin + translation = einsum(translation_frame, delta, "... i j, ... i -> ... j") + + # Add the rotation elements of the pivot parametrization. + inverted = pivot_coordinate_frame.inverse() @ extrinsics[..., :3, :3] + y, _, z = matrix_to_euler(inverted, "YXZ").unbind(dim=-1) + + return torch.cat([translation, y[..., None], z[..., None]], dim=-1) + + +def pivot_parameters_to_extrinsics( + parameters: torch.Tensor, # "*#batch 5" + pivot_coordinate_frame: torch.Tensor, # "*#batch 3 3" + pivot_point: torch.Tensor, # "*#batch 3" +) -> torch.Tensor: # "*batch 4 4" + translation, y, z = parameters.split((3, 1, 1), dim=-1) + + euler = torch.cat((y, torch.zeros_like(y), z), dim=-1) + rotation = pivot_coordinate_frame @ euler_to_matrix(euler, "YXZ") + + # The pivot coordinate frame's Z axis is normal to the plane. + pivot_axis = pivot_coordinate_frame[..., :, 1] + + translation_frame = generate_coordinate_frame(pivot_axis, rotation[..., :3, 2]) + delta = einsum(translation_frame, translation, "... i j, ... j -> ... i") + origin = pivot_point - delta + + *batch, _ = origin.shape + extrinsics = torch.eye(4, dtype=parameters.dtype, device=parameters.device) + extrinsics = extrinsics.broadcast_to((*batch, 4, 4)).clone() + extrinsics[..., 3, 3] = 1 + extrinsics[..., :3, :3] = rotation + extrinsics[..., :3, 3] = origin + return extrinsics + + +def interpolate_circular( + a: torch.Tensor, # "*#batch" + b: torch.Tensor, # "*#batch" + t: torch.Tensor, # "*#batch" +) -> torch.Tensor: # " *batch" + a, b, t = torch.broadcast_tensors(a, b, t) + + tau = 2 * torch.pi + a = a % tau + b = b % tau + + # Consider piecewise edge cases. 
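+    # Compare the direct path (b - a) against the two wrapped candidates obtained
+    # by shifting a by -2*pi and +2*pi, then interpolate along the shortest one.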
+ d = (b - a).abs() + a_left = a - tau + d_left = (b - a_left).abs() + a_right = a + tau + d_right = (b - a_right).abs() + use_d = (d < d_left) & (d < d_right) + use_d_left = (d_left < d_right) & (~use_d) + use_d_right = (~use_d) & (~use_d_left) + + result = a + (b - a) * t + result[use_d_left] = (a_left + (b - a_left) * t)[use_d_left] + result[use_d_right] = (a_right + (b - a_right) * t)[use_d_right] + + return result + + +def interpolate_pivot_parameters( + initial: torch.Tensor, # "*#batch 5" + final: torch.Tensor, # "*#batch 5" + t: torch.Tensor, # " time_step" +) -> torch.Tensor: # "*batch time_step 5" + initial = rearrange(initial, "... d -> ... () d") + final = rearrange(final, "... d -> ... () d") + t = rearrange(t, "t -> t ()") + ti, ri = initial.split((3, 2), dim=-1) + tf, rf = final.split((3, 2), dim=-1) + + t_lerp = ti + (tf - ti) * t + r_lerp = interpolate_circular(ri, rf, t) + + return torch.cat((t_lerp, r_lerp), dim=-1) + + +@torch.no_grad() +def interpolate_extrinsics( + initial: torch.Tensor, # "*#batch 4 4" + final: torch.Tensor, # "*#batch 4 4" + t: torch.Tensor, # " time_step" + eps: float = 1e-4, +) -> torch.Tensor: # "*batch time_step 4 4" + """Interpolate extrinsics by rotating around their "focus point," which is the + least-squares intersection between the look vectors of the initial and final + extrinsics. + """ + + initial = initial.type(torch.float64) + final = final.type(torch.float64) + t = t.type(torch.float64) + + # Based on the dot product between the look vectors, pick from one of two cases: + # 1. Look vectors are parallel: interpolate about their origins' midpoint. + # 3. Look vectors aren't parallel: interpolate about their focus point. + initial_look = initial[..., :3, 2] + final_look = final[..., :3, 2] + dot_products = einsum(initial_look, final_look, "... i, ... i -> ...") + parallel_mask = (dot_products.abs() - 1).abs() < eps + + # Pick focus points. + initial_origin = initial[..., :3, 3] + final_origin = final[..., :3, 3] + pivot_point = 0.5 * (initial_origin + final_origin) + pivot_point[~parallel_mask] = intersect_rays( + initial_origin[~parallel_mask], + initial_look[~parallel_mask], + final_origin[~parallel_mask], + final_look[~parallel_mask], + ) + + # Convert to pivot parameters. + pivot_frame = generate_rotation_coordinate_frame(initial_look, final_look, eps=eps) + initial_params = extrinsics_to_pivot_parameters(initial, pivot_frame, pivot_point) + final_params = extrinsics_to_pivot_parameters(final, pivot_frame, pivot_point) + + # Interpolate the pivot parameters. + interpolated_params = interpolate_pivot_parameters(initial_params, final_params, t) + + # Convert back. + return pivot_parameters_to_extrinsics( + interpolated_params.type(torch.float32), + rearrange(pivot_frame, "... i j -> ... () i j").type(torch.float32), + rearrange(pivot_point, "... xyz -> ... () xyz").type(torch.float32), + ) + + +@torch.no_grad() +def generate_wobble_transformation( + radius: torch.Tensor, # "*#batch" + t: torch.Tensor, # " time_step" + num_rotations: int = 1, + scale_radius_with_t: bool = True, +) -> torch.Tensor: # "*batch time_step 4 4"]: + # Generate a translation in the image plane. 
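+    # The sin/cos offsets below trace `num_rotations` circles of the given radius
+    # in the camera XY plane; with scale_radius_with_t the radius grows linearly
+    # with t instead of staying fixed.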
+ tf = torch.eye(4, dtype=torch.float32, device=t.device) + tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone() + radius = radius[..., None] + if scale_radius_with_t: + radius = radius * t + tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius + tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius + return tf + + +@torch.no_grad() +def render_wobble_inter_path( + cam2world: torch.Tensor, intr_normed: torch.Tensor, inter_len: int, n_skip: int = 3 +): + """ + cam2world: [batch, 4, 4], + intr_normed: [batch, 3, 3] + """ + frame_per_round = n_skip * inter_len + num_rotations = 1 + + t = torch.linspace(0, 1, frame_per_round, dtype=torch.float32, device=cam2world.device) + # t = (torch.cos(torch.pi * (t + 1)) + 1) / 2 + tgt_c2w_b = [] + tgt_intr_b = [] + for b_idx in range(cam2world.shape[0]): + tgt_c2w = [] + tgt_intr = [] + for cur_idx in range(0, cam2world.shape[1] - n_skip, n_skip): + origin_a = cam2world[b_idx, cur_idx, :3, 3] + origin_b = cam2world[b_idx, cur_idx + n_skip, :3, 3] + delta = (origin_a - origin_b).norm(dim=-1) + if cur_idx == 0: + delta_prev = delta + else: + delta = (delta_prev + delta) / 2 + delta_prev = delta + tf = generate_wobble_transformation( + radius=delta * 0.5, + t=t, + num_rotations=num_rotations, + scale_radius_with_t=False, + ) + cur_extrs = ( + interpolate_extrinsics( + cam2world[b_idx, cur_idx], + cam2world[b_idx, cur_idx + n_skip], + t, + ) + @ tf + ) + tgt_c2w.append(cur_extrs[(0 if cur_idx == 0 else 1) :]) + tgt_intr.append( + interpolate_intrinsics( + intr_normed[b_idx, cur_idx], + intr_normed[b_idx, cur_idx + n_skip], + t, + )[(0 if cur_idx == 0 else 1) :] + ) + tgt_c2w_b.append(torch.cat(tgt_c2w)) + tgt_intr_b.append(torch.cat(tgt_intr)) + tgt_c2w = torch.stack(tgt_c2w_b) # b v 4 4 + tgt_intr = torch.stack(tgt_intr_b) # b v 3 3 + return tgt_c2w, tgt_intr diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/constants.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..f3fb4f11827c9b8f9373df7ce10141071d89300e --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/constants.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +DEFAULT_MODEL = "depth-anything/DA3NESTED-GIANT-LARGE" +DEFAULT_EXPORT_DIR = "workspace/gallery/scene" +DEFAULT_GALLERY_DIR = "workspace/gallery" +DEFAULT_GRADIO_DIR = "workspace/gradio" diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/__init__.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3e4c657983b19a75865ad7d3329f9f037f60cd6 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/__init__.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025 ByteDance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from depth_anything_3.specs import Prediction +from depth_anything_3.utils.export.gs import export_to_gs_ply, export_to_gs_video + +from .colmap import export_to_colmap +from .depth_vis import export_to_depth_vis +from .feat_vis import export_to_feat_vis +from .glb import export_to_glb +from .npz import export_to_mini_npz, export_to_npz + + +def export( + prediction: Prediction, + export_format: str, + export_dir: str, + **kwargs, +): + if "-" in export_format: + export_formats = export_format.split("-") + for export_format in export_formats: + export(prediction, export_format, export_dir, **kwargs) + return # Prevent falling through to single-format handling + + if export_format == "glb": + export_to_glb(prediction, export_dir, **kwargs.get(export_format, {})) + elif export_format == "mini_npz": + export_to_mini_npz(prediction, export_dir) + elif export_format == "npz": + export_to_npz(prediction, export_dir) + elif export_format == "feat_vis": + export_to_feat_vis(prediction, export_dir, **kwargs.get(export_format, {})) + elif export_format == "depth_vis": + export_to_depth_vis(prediction, export_dir) + elif export_format == "gs_ply": + export_to_gs_ply(prediction, export_dir, **kwargs.get(export_format, {})) + elif export_format == "gs_video": + export_to_gs_video(prediction, export_dir, **kwargs.get(export_format, {})) + elif export_format == "colmap": + export_to_colmap(prediction, export_dir, **kwargs.get(export_format, {})) + else: + raise ValueError(f"Unsupported export format: {export_format}") + + +__all__ = [ + export, +] diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/colmap.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..81aa71365be3154dcfb5467d723fb35955418560 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/colmap.py @@ -0,0 +1,150 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import pycolmap +import cv2 as cv +import numpy as np + +from PIL import Image + +from depth_anything_3.specs import Prediction +from depth_anything_3.utils.logger import logger + +from .glb import _depths_to_world_points_with_colors + + +def export_to_colmap( + prediction: Prediction, + export_dir: str, + image_paths: list[str], + conf_thresh_percentile: float = 40.0, + process_res_method: str = "keep", +) -> None: + # 1. Data preparation + conf_thresh = np.percentile(prediction.conf, conf_thresh_percentile) + points, colors = _depths_to_world_points_with_colors( + prediction.depth, + prediction.intrinsics, + prediction.extrinsics, # w2c + prediction.processed_images, + prediction.conf, + conf_thresh, + ) + num_points = len(points) + logger.info(f"Exporting to COLMAP with {num_points} points") + num_frames = len(prediction.processed_images) + h, w = prediction.processed_images.shape[1:3] + points_xyf = _create_xyf(num_frames, h, w) + points_xyf = points_xyf[prediction.conf >= conf_thresh] + + # 2. Set Reconstruction + reconstruction = pycolmap.Reconstruction() + + point3d_ids = [] + for vidx in range(num_points): + point3d_id = reconstruction.add_point3D(points[vidx], pycolmap.Track(), colors[vidx]) + point3d_ids.append(point3d_id) + + for fidx in range(num_frames): + orig_w, orig_h = Image.open(image_paths[fidx]).size + + intrinsic = prediction.intrinsics[fidx] + if process_res_method.endswith("resize") or process_res_method in ("keep", "original"): + intrinsic[:1] *= orig_w / w + intrinsic[1:2] *= orig_h / h + elif process_res_method == "crop": + raise NotImplementedError("COLMAP export for crop method is not implemented") + else: + raise ValueError(f"Unknown process_res_method: {process_res_method}") + + pycolmap_intri = np.array( + [intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2]] + ) + + extrinsic = prediction.extrinsics[fidx] + cam_from_world = pycolmap.Rigid3d(pycolmap.Rotation3d(extrinsic[:3, :3]), extrinsic[:3, 3]) + + # set and add camera + camera = pycolmap.Camera() + camera.camera_id = fidx + 1 + camera.model = pycolmap.CameraModelId.PINHOLE + camera.width = orig_w + camera.height = orig_h + camera.params = pycolmap_intri + reconstruction.add_camera(camera) + + # set and add rig (from camera) + rig = pycolmap.Rig() + rig.rig_id = camera.camera_id + rig.add_ref_sensor(camera.sensor_id) + reconstruction.add_rig(rig) + + # set image + image = pycolmap.Image() + image.image_id = fidx + 1 + image.camera_id = camera.camera_id + + # set and add frame (from image) + frame = pycolmap.Frame() + frame.frame_id = image.image_id + frame.rig_id = camera.camera_id + frame.add_data_id(image.data_id) + frame.rig_from_world = cam_from_world + reconstruction.add_frame(frame) + + # set point2d and update track + point2d_list = [] + points_in_frame = points_xyf[:, 2].astype(np.int32) == fidx + for vidx in np.where(points_in_frame)[0]: + point2d = points_xyf[vidx][:2] + point2d[0] *= orig_w / w + point2d[1] *= orig_h / h + point3d_id = point3d_ids[vidx] + point2d_list.append(pycolmap.Point2D(point2d, point3d_id)) + reconstruction.point3D(point3d_id).track.add_element( + image.image_id, len(point2d_list) - 1 + ) + + # set and add image + image.frame_id = image.image_id + image.name = os.path.basename(image_paths[fidx]) + image.points2D = pycolmap.Point2DList(point2d_list) + reconstruction.add_image(image) + + # 3. 
Export + reconstruction.write(export_dir) + + +def _create_xyf(num_frames, height, width): + """ + Creates a grid of pixel coordinates and frame indices (fidx) for all frames. + """ + # Create coordinate grids for a single frame + y_grid, x_grid = np.indices((height, width), dtype=np.int32) + x_grid = x_grid[np.newaxis, :, :] + y_grid = y_grid[np.newaxis, :, :] + + # Broadcast to all frames + x_coords = np.broadcast_to(x_grid, (num_frames, height, width)) + y_coords = np.broadcast_to(y_grid, (num_frames, height, width)) + + # Create frame indices and broadcast + f_idx = np.arange(num_frames, dtype=np.int32)[:, np.newaxis, np.newaxis] + f_coords = np.broadcast_to(f_idx, (num_frames, height, width)) + + # Stack coordinates and frame indices + points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1) + + return points_xyf diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/depth_vis.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/depth_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..8accc04e92985e26b8d78a56db80989be515aac7 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/depth_vis.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import imageio +import numpy as np + +from depth_anything_3.specs import Prediction +from depth_anything_3.utils.visualize import visualize_depth + + +def export_to_depth_vis( + prediction: Prediction, + export_dir: str, +): + # Use prediction.processed_images, which is already processed image data + if prediction.processed_images is None: + raise ValueError("prediction.processed_images is required but not available") + + images_u8 = prediction.processed_images # (N,H,W,3) uint8 + + os.makedirs(os.path.join(export_dir, "depth_vis"), exist_ok=True) + for idx in range(prediction.depth.shape[0]): + depth_vis = visualize_depth(prediction.depth[idx]) + image_vis = images_u8[idx] + depth_vis = depth_vis.astype(np.uint8) + image_vis = image_vis.astype(np.uint8) + vis_image = np.concatenate([image_vis, depth_vis], axis=1) + save_path = os.path.join(export_dir, f"depth_vis/{idx:04d}.jpg") + imageio.imwrite(save_path, vis_image, quality=95) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/feat_vis.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/feat_vis.py new file mode 100644 index 0000000000000000000000000000000000000000..e3dc780ea509d8b5f4660212c2914d3e81f2364a --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/feat_vis.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import imageio +import numpy as np +from tqdm.auto import tqdm + +from depth_anything_3.utils.parallel_utils import async_call +from depth_anything_3.utils.pca_utils import PCARGBVisualizer + + +@async_call +def export_to_feat_vis( + prediction, + export_dir, + fps=15, +): + """Export feature visualization with PCA. + + Args: + prediction: Model prediction containing feature maps + export_dir: Directory to export results + fps: Frame rate for output video (default: 15) + """ + out_dir = os.path.join(export_dir, "feat_vis") + os.makedirs(out_dir, exist_ok=True) + + images = prediction.processed_images + for k, v in prediction.aux.items(): + if not k.startswith("feat_layer_"): + continue + os.makedirs(os.path.join(out_dir, k), exist_ok=True) + viz = PCARGBVisualizer(basis_mode="fixed", percentile_mode="global", clip_percent=10.0) + viz.fit_reference(v) + feats_vis = viz.transform_video(v) + for idx in tqdm(range(len(feats_vis))): + img = images[idx] + feat_vis = (feats_vis[idx] * 255).astype(np.uint8) + feat_vis = cv2.resize( + feat_vis, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST + ) + save_path = os.path.join(out_dir, f"{k}/{idx:06d}.jpg") + save = np.concatenate([img, feat_vis], axis=1) + imageio.imwrite(save_path, save, quality=95) + cmd = ( + "ffmpeg -loglevel error -hide_banner -y " + f"-framerate {fps} -start_number 0 " + f"-i {out_dir}/{k}/%06d.jpg " + f"-c:v libx264 -pix_fmt yuv420p " + f"{out_dir}/{k}.mp4" + ) + os.system(cmd) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/glb.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/glb.py new file mode 100644 index 0000000000000000000000000000000000000000..ece1379d98fedceba03cb53ee4cb62bd49a1ae4f --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/glb.py @@ -0,0 +1,432 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import os +import numpy as np +import trimesh + +from depth_anything_3.specs import Prediction +from depth_anything_3.utils.logger import logger + +from .depth_vis import export_to_depth_vis + + +def set_sky_depth(prediction: Prediction, sky_mask: np.ndarray, sky_depth_def: float = 98.0): + non_sky_mask = ~sky_mask + valid_depth = prediction.depth[non_sky_mask] + if valid_depth.size > 0: + max_depth = np.percentile(valid_depth, sky_depth_def) + prediction.depth[sky_mask] = max_depth + + +def get_conf_thresh( + prediction: Prediction, + sky_mask: np.ndarray, + conf_thresh: float, + conf_thresh_percentile: float = 10.0, + ensure_thresh_percentile: float = 90.0, +): + if sky_mask is not None and (~sky_mask).sum() > 10: + conf_pixels = prediction.conf[~sky_mask] + else: + conf_pixels = prediction.conf + lower = np.percentile(conf_pixels, conf_thresh_percentile) + upper = np.percentile(conf_pixels, ensure_thresh_percentile) + conf_thresh = min(max(conf_thresh, lower), upper) + return conf_thresh + + +def export_to_glb( + prediction: Prediction, + export_dir: str, + num_max_points: int = 1_000_000, + conf_thresh: float = 1.05, + filter_black_bg: bool = False, + filter_white_bg: bool = False, + conf_thresh_percentile: float = 40.0, + ensure_thresh_percentile: float = 90.0, + sky_depth_def: float = 98.0, + show_cameras: bool = True, + camera_size: float = 0.03, + export_depth_vis: bool = True, +) -> str: + """Generate a 3D point cloud and camera wireframes and export them as a ``.glb`` file. + + The function builds a point cloud from the predicted depth maps, aligns it to the + first camera in glTF coordinates (X-right, Y-up, Z-backward), optionally draws + camera wireframes, and writes the result to ``scene.glb``. Auxiliary assets such as + depth visualizations can also be generated alongside the main export. + + Args: + prediction: Model prediction containing depth, confidence, intrinsics, extrinsics, + and pre-processed images. + export_dir: Output directory where the glTF assets will be written. + num_max_points: Maximum number of points retained after downsampling. + conf_thresh: Base confidence threshold used before percentile adjustments. + filter_black_bg: Mark near-black background pixels for removal during confidence filtering. + filter_white_bg: Mark near-white background pixels for removal during confidence filtering. + conf_thresh_percentile: Lower percentile used when adapting the confidence threshold. + ensure_thresh_percentile: Upper percentile clamp for the adaptive threshold. + sky_depth_def: Percentile used to fill sky pixels with plausible depth values. + show_cameras: Whether to render camera wireframes in the exported scene. + camera_size: Relative camera wireframe scale as a fraction of the scene diagonal. + export_depth_vis: Whether to export raster depth visualisations alongside the glTF. + + Returns: + Path to the exported ``scene.glb`` file. 
+ """ + # 1) Use prediction.processed_images, which is already processed image data + assert ( + prediction.processed_images is not None + ), "Export to GLB: prediction.processed_images is required but not available" + assert ( + prediction.depth is not None + ), "Export to GLB: prediction.depth is required but not available" + assert ( + prediction.intrinsics is not None + ), "Export to GLB: prediction.intrinsics is required but not available" + assert ( + prediction.extrinsics is not None + ), "Export to GLB: prediction.extrinsics is required but not available" + assert ( + prediction.conf is not None + ), "Export to GLB: prediction.conf is required but not available" + logger.info(f"conf_thresh_percentile: {conf_thresh_percentile}") + logger.info(f"num max points: {num_max_points}") + logger.info(f"Exporting to GLB with num_max_points: {num_max_points}") + if prediction.processed_images is None: + raise ValueError("prediction.processed_images is required but not available") + + images_u8 = prediction.processed_images # (N,H,W,3) uint8 + + # 2) Sky processing (if sky_mask is provided) + if getattr(prediction, "sky_mask", None) is not None: + set_sky_depth(prediction, prediction.sky_mask, sky_depth_def) + + # 3) Confidence threshold (if no conf, then no filtering) + if filter_black_bg: + prediction.conf[(prediction.processed_images < 16).all(axis=-1)] = 1.0 + if filter_white_bg: + prediction.conf[(prediction.processed_images >= 240).all(axis=-1)] = 1.0 + conf_thr = get_conf_thresh( + prediction, + getattr(prediction, "sky_mask", None), + conf_thresh, + conf_thresh_percentile, + ensure_thresh_percentile, + ) + + # 4) Back-project to world coordinates and get colors (world frame) + points, colors = _depths_to_world_points_with_colors( + prediction.depth, + prediction.intrinsics, + prediction.extrinsics, # w2c + images_u8, + prediction.conf, + conf_thr, + ) + + # 5) Based on first camera orientation + glTF axis system, center by point cloud, + # construct alignment transform, and apply to point cloud + A = _compute_alignment_transform_first_cam_glTF_center_by_points( + prediction.extrinsics[0], points + ) # (4,4) + + if points.shape[0] > 0: + points = trimesh.transform_points(points, A) + + # 6) Clean + downsample + points, colors = _filter_and_downsample(points, colors, num_max_points) + + # 7) Assemble scene (add point cloud first) + scene = trimesh.Scene() + if scene.metadata is None: + scene.metadata = {} + scene.metadata["hf_alignment"] = A # For camera wireframes and external reuse + + if points.shape[0] > 0: + pc = trimesh.points.PointCloud(vertices=points, colors=colors) + scene.add_geometry(pc) + + # 8) Draw cameras (wireframe pyramids), using the same transform A + if show_cameras and prediction.intrinsics is not None and prediction.extrinsics is not None: + scene_scale = _estimate_scene_scale(points, fallback=1.0) + H, W = prediction.depth.shape[1:] + _add_cameras_to_scene( + scene=scene, + K=prediction.intrinsics, + ext_w2c=prediction.extrinsics, + image_sizes=[(H, W)] * prediction.depth.shape[0], + scale=scene_scale * camera_size, + ) + + # 9) Export + os.makedirs(export_dir, exist_ok=True) + out_path = os.path.join(export_dir, "scene.glb") + scene.export(out_path) + + if export_depth_vis: + export_to_depth_vis(prediction, export_dir) + os.system(f"cp -r {export_dir}/depth_vis/0000.jpg {export_dir}/scene.jpg") + return out_path + + +# ========================= +# utilities +# ========================= + + +def _as_homogeneous44(ext: np.ndarray) -> np.ndarray: + """ + Accept 
(4,4) or (3,4) extrinsic parameters, return (4,4) homogeneous matrix. + """ + if ext.shape == (4, 4): + return ext + if ext.shape == (3, 4): + H = np.eye(4, dtype=ext.dtype) + H[:3, :4] = ext + return H + raise ValueError(f"extrinsic must be (4,4) or (3,4), got {ext.shape}") + + +def _depths_to_world_points_with_colors( + depth: np.ndarray, + K: np.ndarray, + ext_w2c: np.ndarray, + images_u8: np.ndarray, + conf: np.ndarray | None, + conf_thr: float, +) -> tuple[np.ndarray, np.ndarray]: + """ + For each frame, transform (u,v,1) through K^{-1} to get rays, + multiply by depth to camera frame, then use (w2c)^{-1} to transform to world frame. + Simultaneously extract colors. + """ + N, H, W = depth.shape + us, vs = np.meshgrid(np.arange(W), np.arange(H)) + ones = np.ones_like(us) + pix = np.stack([us, vs, ones], axis=-1).reshape(-1, 3) # (H*W,3) + + pts_all, col_all = [], [] + + for i in range(N): + d = depth[i] # (H,W) + valid = np.isfinite(d) & (d > 0) + if conf is not None: + valid &= conf[i] >= conf_thr + if not np.any(valid): + continue + + d_flat = d.reshape(-1) + vidx = np.flatnonzero(valid.reshape(-1)) + + K_inv = np.linalg.inv(K[i]) # (3,3) + c2w = np.linalg.inv(_as_homogeneous44(ext_w2c[i])) # (4,4) + + rays = K_inv @ pix[vidx].T # (3,M) + Xc = rays * d_flat[vidx][None, :] # (3,M) + Xc_h = np.vstack([Xc, np.ones((1, Xc.shape[1]))]) + Xw = (c2w @ Xc_h)[:3].T.astype(np.float32) # (M,3) + + cols = images_u8[i].reshape(-1, 3)[vidx].astype(np.uint8) # (M,3) + + pts_all.append(Xw) + col_all.append(cols) + + if len(pts_all) == 0: + return np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8) + + return np.concatenate(pts_all, 0), np.concatenate(col_all, 0) + + +def _filter_and_downsample(points: np.ndarray, colors: np.ndarray, num_max: int): + if points.shape[0] == 0: + return points, colors + finite = np.isfinite(points).all(axis=1) + points, colors = points[finite], colors[finite] + if points.shape[0] > num_max: + idx = np.random.choice(points.shape[0], num_max, replace=False) + points, colors = points[idx], colors[idx] + return points, colors + + +def _estimate_scene_scale(points: np.ndarray, fallback: float = 1.0) -> float: + if points.shape[0] < 2: + return fallback + lo = np.percentile(points, 5, axis=0) + hi = np.percentile(points, 95, axis=0) + diag = np.linalg.norm(hi - lo) + return float(diag if np.isfinite(diag) and diag > 0 else fallback) + + +def _compute_alignment_transform_first_cam_glTF_center_by_points( + ext_w2c0: np.ndarray, + points_world: np.ndarray, +) -> np.ndarray: + """Computes the transformation matrix to align the scene with glTF standards. + + This function calculates a 4x4 homogeneous matrix that centers the scene's + point cloud and transforms its coordinate system from the computer vision (CV) + standard to the glTF standard. + + The transformation process involves three main steps: + 1. **Initial Alignment**: Orients the world coordinate system to match the + first camera's view (x-right, y-down, z-forward). + 2. **Coordinate System Conversion**: Converts the CV camera frame to the + glTF frame (x-right, y-up, z-backward) by flipping the Y and Z axes. + 3. **Centering**: Translates the entire scene so that the median of the + point cloud becomes the new origin (0,0,0). + + Returns: + A 4x4 homogeneous transformation matrix (torch.Tensor or np.ndarray) + that applies these transformations. 
A: X' = A @ [X;1] + """ + + w2c0 = _as_homogeneous44(ext_w2c0).astype(np.float64) + + # CV -> glTF axis transformation + M = np.eye(4, dtype=np.float64) + M[1, 1] = -1.0 # flip Y + M[2, 2] = -1.0 # flip Z + + # Don't center first + A_no_center = M @ w2c0 + + # Calculate point cloud center in new coordinate system (use median to resist outliers) + if points_world.shape[0] > 0: + pts_tmp = trimesh.transform_points(points_world, A_no_center) + center = np.median(pts_tmp, axis=0) + else: + center = np.zeros(3, dtype=np.float64) + + T_center = np.eye(4, dtype=np.float64) + T_center[:3, 3] = -center + + A = T_center @ A_no_center + return A + + +def _add_cameras_to_scene( + scene: trimesh.Scene, + K: np.ndarray, + ext_w2c: np.ndarray, + image_sizes: list[tuple[int, int]], + scale: float, +) -> None: + """Draws camera frustums to visualize their position and orientation. + + This function renders each camera as a wireframe pyramid, originating from + the camera's center and extending to the corners of its imaging plane. + + It reads the 'hf_alignment' metadata from the scene to ensure the + wireframes are correctly aligned with the 3D point cloud. + """ + N = K.shape[0] + if N == 0: + return + + # Alignment matrix consistent with point cloud (use identity matrix if missing) + A = None + try: + A = scene.metadata.get("hf_alignment", None) if scene.metadata else None + except Exception: + A = None + if A is None: + A = np.eye(4, dtype=np.float64) + + for i in range(N): + H, W = image_sizes[i] + segs = _camera_frustum_lines(K[i], ext_w2c[i], W, H, scale) # (8,2,3) world frame + # Apply unified transformation + segs = trimesh.transform_points(segs.reshape(-1, 3), A).reshape(-1, 2, 3) + path = trimesh.load_path(segs) + color = _index_color_rgb(i, N) + if hasattr(path, "colors"): + path.colors = np.tile(color, (len(path.entities), 1)) + scene.add_geometry(path) + + +def _camera_frustum_lines( + K: np.ndarray, ext_w2c: np.ndarray, W: int, H: int, scale: float +) -> np.ndarray: + corners = np.array( + [ + [0, 0, 1.0], + [W - 1, 0, 1.0], + [W - 1, H - 1, 1.0], + [0, H - 1, 1.0], + ], + dtype=float, + ) # (4,3) + + K_inv = np.linalg.inv(K) + c2w = np.linalg.inv(_as_homogeneous44(ext_w2c)) + + # camera center in world + Cw = (c2w @ np.array([0, 0, 0, 1.0]))[:3] + + # rays -> z=1 plane points (camera frame) + rays = (K_inv @ corners.T).T + z = rays[:, 2:3] + z[z == 0] = 1.0 + plane_cam = (rays / z) * scale # (4,3) + + # to world + plane_w = [] + for p in plane_cam: + pw = (c2w @ np.array([p[0], p[1], p[2], 1.0]))[:3] + plane_w.append(pw) + plane_w = np.stack(plane_w, 0) # (4,3) + + segs = [] + # center to corners + for k in range(4): + segs.append(np.stack([Cw, plane_w[k]], 0)) + # rectangle edges + order = [0, 1, 2, 3, 0] + for a, b in zip(order[:-1], order[1:]): + segs.append(np.stack([plane_w[a], plane_w[b]], 0)) + + return np.stack(segs, 0) # (8,2,3) + + +def _index_color_rgb(i: int, n: int) -> np.ndarray: + h = (i + 0.5) / max(n, 1) + s, v = 0.85, 0.95 + r, g, b = _hsv_to_rgb(h, s, v) + return (np.array([r, g, b]) * 255).astype(np.uint8) + + +def _hsv_to_rgb(h: float, s: float, v: float) -> tuple[float, float, float]: + i = int(h * 6.0) + f = h * 6.0 - i + p = v * (1.0 - s) + q = v * (1.0 - f * s) + t = v * (1.0 - (1.0 - f) * s) + i = i % 6 + if i == 0: + r, g, b = v, t, p + elif i == 1: + r, g, b = q, v, p + elif i == 2: + r, g, b = p, v, t + elif i == 3: + r, g, b = p, q, v + elif i == 4: + r, g, b = t, p, v + else: + r, g, b = v, p, q + return r, g, b diff --git 
a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/gs.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/gs.py new file mode 100644 index 0000000000000000000000000000000000000000..90077cf25651c7977c1c1da320b7c0931fdb18ba --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/gs.py @@ -0,0 +1,154 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Literal, Optional +import moviepy.editor as mpy +import torch + +from depth_anything_3.model.utils.gs_renderer import run_renderer_in_chunk_w_trj_mode +from depth_anything_3.specs import Prediction +from depth_anything_3.utils.gsply_helpers import save_gaussian_ply +from depth_anything_3.utils.layout_helpers import hcat, vcat +from depth_anything_3.utils.visualize import vis_depth_map_tensor + +VIDEO_QUALITY_MAP = { + "low": {"crf": "28", "preset": "veryfast"}, + "medium": {"crf": "23", "preset": "medium"}, + "high": {"crf": "18", "preset": "slow"}, +} + + +def export_to_gs_ply( + prediction: Prediction, + export_dir: str, + gs_views_interval: Optional[ + int + ] = 1, # export GS every N views, useful for extremely dense inputs +): + gs_world = prediction.gaussians + pred_depth = torch.from_numpy(prediction.depth).unsqueeze(-1).to(gs_world.means) # v h w 1 + idx = 0 + os.makedirs(os.path.join(export_dir, "gs_ply"), exist_ok=True) + save_path = os.path.join(export_dir, f"gs_ply/{idx:04d}.ply") + if gs_views_interval is None: # select around 12 views in total + gs_views_interval = max(pred_depth.shape[0] // 12, 1) + save_gaussian_ply( + gaussians=gs_world, + save_path=save_path, + ctx_depth=pred_depth, + shift_and_scale=False, + save_sh_dc_only=True, + gs_views_interval=gs_views_interval, + inv_opacity=True, + prune_by_depth_percent=0.9, + prune_border_gs=True, + match_3dgs_mcmc_dev=False, + ) + + +def export_to_gs_video( + prediction: Prediction, + export_dir: str, + extrinsics: Optional[torch.Tensor] = None, # render views' world2cam, "b v 4 4" + intrinsics: Optional[torch.Tensor] = None, # render views' unnormed intrinsics, "b v 3 3" + out_image_hw: Optional[tuple[int, int]] = None, # render views' resolution, (h, w) + chunk_size: Optional[int] = 4, + trj_mode: Literal[ + "original", + "smooth", + "interpolate", + "interpolate_smooth", + "wander", + "dolly_zoom", + "extend", + "wobble_inter", + ] = "extend", + color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+ED", + vis_depth: Optional[Literal["hcat", "vcat"]] = "hcat", + enable_tqdm: Optional[bool] = True, + output_name: Optional[str] = None, + video_quality: Literal["low", "medium", "high"] = "high", +) -> None: + gs_world = prediction.gaussians + # if target poses are not provided, render the (smooth/interpolate) input poses + if extrinsics is not None: + tgt_extrs = extrinsics + else: + tgt_extrs = torch.from_numpy(prediction.extrinsics).unsqueeze(0).to(gs_world.means) + if prediction.is_metric: + scale_factor = prediction.scale_factor + if scale_factor is not 
None: + tgt_extrs[:, :, :3, 3] /= scale_factor + tgt_intrs = ( + intrinsics + if intrinsics is not None + else torch.from_numpy(prediction.intrinsics).unsqueeze(0).to(gs_world.means) + ) + # if render resolution is not provided, render the input ones + if out_image_hw is not None: + H, W = out_image_hw + else: + H, W = prediction.depth.shape[-2:] + # if single views, render wander trj + if tgt_extrs.shape[1] <= 1: + trj_mode = "wander" + # trj_mode = "dolly_zoom" + + color, depth = run_renderer_in_chunk_w_trj_mode( + gaussians=gs_world, + extrinsics=tgt_extrs, + intrinsics=tgt_intrs, + image_shape=(H, W), + chunk_size=chunk_size, + trj_mode=trj_mode, + use_sh=True, + color_mode=color_mode, + enable_tqdm=enable_tqdm, + ) + + # save as video + ffmpeg_params = [ + "-crf", + VIDEO_QUALITY_MAP[video_quality]["crf"], + "-preset", + VIDEO_QUALITY_MAP[video_quality]["preset"], + "-pix_fmt", + "yuv420p", + ] # best compatibility + + os.makedirs(os.path.join(export_dir, "gs_video"), exist_ok=True) + for idx in range(color.shape[0]): + video_i = color[idx] + if vis_depth is not None: + depth_i = vis_depth_map_tensor(depth[0]) + cat_fn = hcat if vis_depth == "hcat" else vcat + video_i = torch.stack([cat_fn(c, d) for c, d in zip(video_i, depth_i)]) + frames = list( + (video_i.clamp(0, 1) * 255).byte().permute(0, 2, 3, 1).cpu().numpy() + ) # T x H x W x C, uint8, numpy() + + fps = 24 + clip = mpy.ImageSequenceClip(frames, fps=fps) + output_name = f"{idx:04d}_{trj_mode}" if output_name is None else output_name + save_path = os.path.join(export_dir, f"gs_video/{output_name}.mp4") + # clip.write_videofile(save_path, codec="libx264", audio=False, bitrate="4000k") + clip.write_videofile( + save_path, + codec="libx264", + audio=False, + fps=fps, + ffmpeg_params=ffmpeg_params, + ) + return diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/npz.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/npz.py new file mode 100644 index 0000000000000000000000000000000000000000..35ff250571b6cbc8bc4175add37011ce29db116d --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/npz.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
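+
+"""NPZ exporters for Depth Anything 3 predictions.
+
+Both exporters write a compressed ``results.npz`` under the export directory. A
+minimal loading sketch for downstream use (the relative path is illustrative):
+
+    import numpy as np
+
+    data = np.load("exports/npz/results.npz")
+    depth = data["depth"]                            # (N, H, W)
+    conf = data["conf"] if "conf" in data else None  # present when confidence was predicted
+"""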
+ +import os +import numpy as np + +from depth_anything_3.specs import Prediction +from depth_anything_3.utils.parallel_utils import async_call + + +@async_call +def export_to_npz( + prediction: Prediction, + export_dir: str, +): + output_file = os.path.join(export_dir, "exports", "npz", "results.npz") + os.makedirs(os.path.dirname(output_file), exist_ok=True) + + # Use prediction.processed_images, which is already processed image data + if prediction.processed_images is None: + raise ValueError("prediction.processed_images is required but not available") + + image = prediction.processed_images # (N,H,W,3) uint8 + + # Build save dict with only non-None values + save_dict = { + "image": image, + "depth": np.round(prediction.depth, 6), + } + + if prediction.conf is not None: + save_dict["conf"] = np.round(prediction.conf, 2) + if prediction.extrinsics is not None: + save_dict["extrinsics"] = prediction.extrinsics + if prediction.intrinsics is not None: + save_dict["intrinsics"] = prediction.intrinsics + + # aux = {k: np.round(v, 4) for k, v in prediction.aux.items()} + np.savez_compressed(output_file, **save_dict) + + +@async_call +def export_to_mini_npz( + prediction: Prediction, + export_dir: str, +): + output_file = os.path.join(export_dir, "exports", "mini_npz", "results.npz") + os.makedirs(os.path.dirname(output_file), exist_ok=True) + + # Build save dict with only non-None values + save_dict = { + "depth": np.round(prediction.depth, 6), + } + + if prediction.conf is not None: + save_dict["conf"] = np.round(prediction.conf, 2) + if prediction.extrinsics is not None: + save_dict["extrinsics"] = prediction.extrinsics + if prediction.intrinsics is not None: + save_dict["intrinsics"] = prediction.intrinsics + + np.savez_compressed(output_file, **save_dict) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/utils.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..81f45fb563ce595bf547bebe829c9b83eb175f1c --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
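+
+"""Shared helpers for the export modules.
+
+``_denorm_and_to_uint8`` inverts the ImageNet normalization applied during input
+processing (per-channel ``x * std + mean``) and returns images as ``(N, H, W, 3)``
+uint8 arrays ready for export or visualization.
+"""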
+ +import numpy as np +import torch + + +def _denorm_and_to_uint8(image_tensor: torch.Tensor) -> np.ndarray: + """Denormalize to [0,255] and output (N, H, W, 3) uint8.""" + resnet_mean = torch.tensor( + [0.485, 0.456, 0.406], dtype=image_tensor.dtype, device=image_tensor.device + ) + resnet_std = torch.tensor( + [0.229, 0.224, 0.225], dtype=image_tensor.dtype, device=image_tensor.device + ) + img = image_tensor * resnet_std[None, :, None, None] + resnet_mean[None, :, None, None] + img = torch.clamp(img, 0.0, 1.0) + img = (img.permute(0, 2, 3, 1).cpu().numpy() * 255.0).round().astype(np.uint8) # (N,H,W,3) + return img diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/geometry.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..a88289eb0243a3b337a06933922fd6c038b54b00 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/geometry.py @@ -0,0 +1,349 @@ +# flake8: noqa: F722 +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from types import SimpleNamespace +from typing import Optional +import numpy as np +import torch +import torch.nn.functional as F +from einops import einsum + + +def as_homogeneous(ext): + """ + Accept (..., 3,4) or (..., 4,4) extrinsics, return (...,4,4) homogeneous matrix. + Supports torch.Tensor or np.ndarray. 
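+
+    Example (shape behavior only):
+        >>> as_homogeneous(np.eye(3, 4)).shape
+        (4, 4)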
+ """ + if isinstance(ext, torch.Tensor): + # If already in homogeneous form + if ext.shape[-2:] == (4, 4): + return ext + elif ext.shape[-2:] == (3, 4): + # Create a new homogeneous matrix + ones = torch.zeros_like(ext[..., :1, :4]) + ones[..., 0, 3] = 1.0 + return torch.cat([ext, ones], dim=-2) + else: + raise ValueError(f"Invalid shape for torch.Tensor: {ext.shape}") + + elif isinstance(ext, np.ndarray): + if ext.shape[-2:] == (4, 4): + return ext + elif ext.shape[-2:] == (3, 4): + ones = np.zeros_like(ext[..., :1, :4]) + ones[..., 0, 3] = 1.0 + return np.concatenate([ext, ones], axis=-2) + else: + raise ValueError(f"Invalid shape for np.ndarray: {ext.shape}") + + else: + raise TypeError("Input must be a torch.Tensor or np.ndarray.") + + +@torch.jit.script +def affine_inverse(A: torch.Tensor): + R = A[..., :3, :3] # ..., 3, 3 + T = A[..., :3, 3:] # ..., 3, 1 + P = A[..., 3:, :] # ..., 1, 4 + return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2) + + +def transpose_last_two_axes(arr): + """ + for np < 2 + """ + if arr.ndim < 2: + return arr + axes = list(range(arr.ndim)) + # swap the last two + axes[-2], axes[-1] = axes[-1], axes[-2] + return arr.transpose(axes) + + +def affine_inverse_np(A: np.ndarray): + R = A[..., :3, :3] + T = A[..., :3, 3:] + P = A[..., 3:, :] + return np.concatenate( + [ + np.concatenate([transpose_last_two_axes(R), -transpose_last_two_axes(R) @ T], axis=-1), + P, + ], + axis=-2, + ) + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """ + Quaternion Order: XYZW or say ijkr, scalar-last + + Convert rotations given as quaternions to rotation matrices. + Args: + quaternions: quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + i, j, k, r = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + + Returns: + quaternions with real part last, as tensor of shape (..., 4). + Quaternion Order: XYZW or say ijkr, scalar-last + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. 
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and + # `int`. + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape( + batch_dim + (4,) + ) + + # Convert from rijk to ijkr + out = out[..., [1, 2, 3, 0]] + + out = standardize_quaternion(out) + + return out + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part last, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def sample_image_grid( + shape: tuple[int, ...], + device: torch.device = torch.device("cpu"), +) -> tuple[ + torch.Tensor, # float coordinates (xy indexing), "*shape dim" + torch.Tensor, # integer indices (ij indexing), "*shape dim" +]: + """Get normalized (range 0 to 1) coordinates and integer indices for an image.""" + + # Each entry is a pixel-wise integer coordinate. In the 2D case, each entry is a + # (row, col) coordinate. + indices = [torch.arange(length, device=device) for length in shape] + stacked_indices = torch.stack(torch.meshgrid(*indices, indexing="ij"), dim=-1) + + # Each entry is a floating-point coordinate in the range (0, 1). In the 2D case, + # each entry is an (x, y) coordinate. 
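+    # (Coordinates are pixel centers: for a length of 4 the values are
+    # 0.125, 0.375, 0.625, 0.875.)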
+ coordinates = [(idx + 0.5) / length for idx, length in zip(indices, shape)] + coordinates = reversed(coordinates) + coordinates = torch.stack(torch.meshgrid(*coordinates, indexing="xy"), dim=-1) + + return coordinates, stacked_indices + + +def homogenize_points(points: torch.Tensor) -> torch.Tensor: # "*batch dim" # "*batch dim+1" + """Convert batched points (xyz) to (xyz1).""" + return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + + +def homogenize_vectors(vectors: torch.Tensor) -> torch.Tensor: # "*batch dim" # "*batch dim+1" + """Convert batched vectors (xyz) to (xyz0).""" + return torch.cat([vectors, torch.zeros_like(vectors[..., :1])], dim=-1) + + +def transform_rigid( + homogeneous_coordinates: torch.Tensor, # "*#batch dim" + transformation: torch.Tensor, # "*#batch dim dim" +) -> torch.Tensor: # "*batch dim" + """Apply a rigid-body transformation to points or vectors.""" + return einsum( + transformation, + homogeneous_coordinates.to(transformation.dtype), + "... i j, ... j -> ... i", + ) + + +def transform_cam2world( + homogeneous_coordinates: torch.Tensor, # "*#batch dim" + extrinsics: torch.Tensor, # "*#batch dim dim" +) -> torch.Tensor: # "*batch dim" + """Transform points from 3D camera coordinates to 3D world coordinates.""" + return transform_rigid(homogeneous_coordinates, extrinsics) + + +def unproject( + coordinates: torch.Tensor, # "*#batch dim" + z: torch.Tensor, # "*#batch" + intrinsics: torch.Tensor, # "*#batch dim+1 dim+1" +) -> torch.Tensor: # "*batch dim+1" + """Unproject 2D camera coordinates with the given Z values.""" + + # Apply the inverse intrinsics to the coordinates. + coordinates = homogenize_points(coordinates) + ray_directions = einsum( + intrinsics.float().inverse().to(intrinsics), + coordinates.to(intrinsics.dtype), + "... i j, ... j -> ... i", + ) + + # Apply the supplied depth values. + return ray_directions * z[..., None] + + +def get_world_rays( + coordinates: torch.Tensor, # "*#batch dim" + extrinsics: torch.Tensor, # "*#batch dim+2 dim+2" + intrinsics: torch.Tensor, # "*#batch dim+1 dim+1" +) -> tuple[ + torch.Tensor, # origins, "*batch dim+1" + torch.Tensor, # directions, "*batch dim+1" +]: + # Get camera-space ray directions. + directions = unproject( + coordinates, + torch.ones_like(coordinates[..., 0]), + intrinsics, + ) + directions = directions / directions.norm(dim=-1, keepdim=True) + + # Transform ray directions to world coordinates. + directions = homogenize_vectors(directions) + directions = transform_cam2world(directions, extrinsics)[..., :-1] + + # Tile the ray origins to have the same shape as the ray directions. 
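+    # The translation column of the camera-to-world extrinsics is the camera center,
+    # which is the origin shared by all rays of that view.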
+ origins = extrinsics[..., :-1, -1].broadcast_to(directions.shape) + + return origins, directions + + +def get_fov(intrinsics: torch.Tensor) -> torch.Tensor: # "batch 3 3" -> "batch 2" + intrinsics_inv = intrinsics.float().inverse().to(intrinsics) + + def process_vector(vector): + vector = torch.tensor(vector, dtype=intrinsics.dtype, device=intrinsics.device) + vector = einsum(intrinsics_inv, vector, "b i j, j -> b i") + return vector / vector.norm(dim=-1, keepdim=True) + + left = process_vector([0, 0.5, 1]) + right = process_vector([1, 0.5, 1]) + top = process_vector([0.5, 0, 1]) + bottom = process_vector([0.5, 1, 1]) + fov_x = (left * right).sum(dim=-1).acos() + fov_y = (top * bottom).sum(dim=-1).acos() + return torch.stack((fov_x, fov_y), dim=-1) + + +def map_pdf_to_opacity( + pdf: torch.Tensor, # " *batch" + global_step: int = 0, + opacity_mapping: Optional[dict] = None, +) -> torch.Tensor: # " *batch" + # https://www.desmos.com/calculator/opvwti3ba9 + + # Figure out the exponent. + if opacity_mapping is not None: + cfg = SimpleNamespace(**opacity_mapping) + x = cfg.initial + min(global_step / cfg.warm_up, 1) * (cfg.final - cfg.initial) + else: + x = 0.0 + exponent = 2**x + + # Map the probability density to an opacity. + return 0.5 * (1 - (1 - pdf) ** exponent + pdf ** (1 / exponent)) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/gsply_helpers.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/gsply_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..5733009e4e91ad80ab59179c80e2df8a0430ab5f --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/gsply_helpers.py @@ -0,0 +1,173 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pathlib import Path +from typing import Optional +import numpy as np +import torch +from einops import rearrange, repeat +from plyfile import PlyData, PlyElement +from torch import Tensor + +from depth_anything_3.specs import Gaussians + + +def construct_list_of_attributes(num_rest: int) -> list[str]: + attributes = ["x", "y", "z", "nx", "ny", "nz"] + for i in range(3): + attributes.append(f"f_dc_{i}") + for i in range(num_rest): + attributes.append(f"f_rest_{i}") + attributes.append("opacity") + for i in range(3): + attributes.append(f"scale_{i}") + for i in range(4): + attributes.append(f"rot_{i}") + return attributes + + +def export_ply( + means: Tensor, # "gaussian 3" + scales: Tensor, # "gaussian 3" + rotations: Tensor, # "gaussian 4" + harmonics: Tensor, # "gaussian 3 d_sh" + opacities: Tensor, # "gaussian" + path: Path, + shift_and_scale: bool = False, + save_sh_dc_only: bool = True, + match_3dgs_mcmc_dev: Optional[bool] = False, +): + if shift_and_scale: + # Shift the scene so that the median Gaussian is at the origin. + means = means - means.median(dim=0).values + + # Rescale the scene so that most Gaussians are within range [-1, 1]. 
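+        # (The 95th-percentile absolute coordinate is used as the scale so that a few
+        # distant outliers do not dominate the normalization.)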
+ scale_factor = means.abs().quantile(0.95, dim=0).max() + means = means / scale_factor + scales = scales / scale_factor + + rotations = rotations.detach().cpu().numpy() + + # Since current model use SH_degree = 4, + # which require large memory to store, we can only save the DC band to save memory. + f_dc = harmonics[..., 0] + f_rest = harmonics[..., 1:].flatten(start_dim=1) + + if match_3dgs_mcmc_dev: + sh_degree = 3 + n_rest = 3 * (sh_degree + 1) ** 2 - 3 + f_rest = repeat( + torch.zeros_like(harmonics[..., :1]), "... i -> ... (n i)", n=(n_rest // 3) + ).flatten(start_dim=1) + dtype_full = [ + (attribute, "f4") + for attribute in construct_list_of_attributes(num_rest=n_rest) + if attribute not in ("nx", "ny", "nz") + ] + else: + dtype_full = [ + (attribute, "f4") + for attribute in construct_list_of_attributes( + 0 if save_sh_dc_only else f_rest.shape[1] + ) + ] + elements = np.empty(means.shape[0], dtype=dtype_full) + attributes = [ + means.detach().cpu().numpy(), + torch.zeros_like(means).detach().cpu().numpy(), + f_dc.detach().cpu().contiguous().numpy(), + f_rest.detach().cpu().contiguous().numpy(), + opacities[..., None].detach().cpu().numpy(), + scales.log().detach().cpu().numpy(), + rotations, + ] + if match_3dgs_mcmc_dev: + attributes.pop(1) # dummy normal is not needed + elif save_sh_dc_only: + attributes.pop(3) # remove f_rest from attributes + + attributes = np.concatenate(attributes, axis=1) + elements[:] = list(map(tuple, attributes)) + path.parent.mkdir(exist_ok=True, parents=True) + PlyData([PlyElement.describe(elements, "vertex")]).write(path) + + +def inverse_sigmoid(x): + return torch.log(x / (1 - x)) + + +def save_gaussian_ply( + gaussians: Gaussians, + save_path: str, + ctx_depth: torch.Tensor, # depth of input views; for getting shape and filtering, "v h w 1" + shift_and_scale: bool = False, + save_sh_dc_only: bool = True, + gs_views_interval: int = 1, + inv_opacity: Optional[bool] = True, + prune_by_depth_percent: Optional[float] = 1.0, + prune_border_gs: Optional[bool] = True, + match_3dgs_mcmc_dev: Optional[bool] = False, +): + b = gaussians.means.shape[0] + assert b == 1, "must set batch_size=1 when exporting 3D gaussians" + src_v, out_h, out_w, _ = ctx_depth.shape + + # extract gs params + world_means = gaussians.means + world_shs = gaussians.harmonics + world_rotations = gaussians.rotations + gs_scales = gaussians.scales + gs_opacities = inverse_sigmoid(gaussians.opacities) if inv_opacity else gaussians.opacities + + # Create a mask to filter the Gaussians. + + # TODO: prune the sky region here + + # throw away Gaussians at the borders, since they're generally of lower quality. + if prune_border_gs: + mask = torch.zeros_like(ctx_depth, dtype=torch.bool) + gstrim_h = int(8 / 256 * out_h) + gstrim_w = int(8 / 256 * out_w) + mask[:, gstrim_h:-gstrim_h, gstrim_w:-gstrim_w, :] = 1 + else: + mask = torch.ones_like(ctx_depth, dtype=torch.bool) + + # trim the far away point based on depth; + if prune_by_depth_percent is not None and prune_by_depth_percent < 1: + in_depths = ctx_depth + d_percentile = torch.quantile( + in_depths.view(in_depths.shape[0], -1), q=prune_by_depth_percent, dim=1 + ).view(-1, 1, 1) + d_mask = (in_depths[..., 0] <= d_percentile).unsqueeze(-1) + mask = mask & d_mask + mask = mask.squeeze(-1) # v h w + + # helper fn, must place after mask + def trim_select_reshape(element): + selected_element = rearrange( + element[0], "(v h w) ... 
-> v h w ...", v=src_v, h=out_h, w=out_w + ) + selected_element = selected_element[::gs_views_interval][mask[::gs_views_interval]] + return selected_element + + export_ply( + means=trim_select_reshape(world_means), + scales=trim_select_reshape(gs_scales), + rotations=trim_select_reshape(world_rotations), + harmonics=trim_select_reshape(world_shs), + opacities=trim_select_reshape(gs_opacities), + path=Path(save_path), + shift_and_scale=shift_and_scale, + save_sh_dc_only=save_sh_dc_only, + match_3dgs_mcmc_dev=match_3dgs_mcmc_dev, + ) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/input_processor.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/input_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..83114044fa6daad78abdf96412e259c9dc8fb04d --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/input_processor.py @@ -0,0 +1,579 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Input processor for Depth Anything 3 (parallelized). + +This version removes the square center-crop step for "*crop" methods (same as your note). +In addition, it parallelizes per-image preprocessing using the provided `parallel_execution`. +""" + +from __future__ import annotations + +from typing import Optional, Sequence, Tuple +import cv2 +import numpy as np +import torch +import torchvision.transforms as T +from PIL import Image, ImageOps + +from depth_anything_3.utils.logger import logger +from depth_anything_3.utils.parallel_utils import parallel_execution + + +class InputProcessor: + """Prepares a batch of images for model inference. + This processor converts a list of image file paths into a single, model-ready + tensor. The processing pipeline is executed in parallel across multiple workers + for efficiency. + + Pipeline: + 1) Load image and convert to RGB + 2) Boundary resize (upper/lower bound, preserving aspect ratio) + 3) Enforce divisibility by PATCH_SIZE: + - "*resize" methods: each dimension is rounded to nearest multiple + (may up/downscale a few px) + - "*crop" methods: each dimension is floored to nearest multiple via center crop + 4) Convert to tensor and apply ImageNet normalization + 5) Stack into (1, N, 3, H, W) + + Parallelization: + - Each image is processed independently in a worker. + - Order of outputs matches the input order. 
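+
+    Example:
+        A minimal sketch (file names are illustrative). With the default
+        ``process_res_method="keep"``, each image is padded up to a multiple of
+        ``PATCH_SIZE`` (14):
+
+            proc = InputProcessor()
+            tensor, exts, ixts, pad_meta = proc(
+                image=["frame_000.jpg", "frame_001.jpg"],
+                process_res=None,
+                process_res_method="keep",
+            )
+            # tensor: (1, 2, 3, H, W) with H and W divisible by 14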
+ """ + + NORMALIZE = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + PATCH_SIZE = 14 + + def __init__(self): + pass + + # ----------------------------- + # Public API + # ----------------------------- + def __call__( + self, + image: list[np.ndarray | Image.Image | str], + extrinsics: np.ndarray | None = None, + intrinsics: np.ndarray | None = None, + process_res: Optional[int] = None, + process_res_method: str = "keep", + *, + num_workers: int = 8, + print_progress: bool = False, + sequential: bool | None = None, + desc: str | None = "Preprocess", + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, list[dict]]: + """ + Returns: + (tensor, extrinsics_list, intrinsics_list, pad_meta) + tensor shape: (1, N, 3, H, W) + """ + sequential = self._resolve_sequential(sequential, num_workers) + exts_list, ixts_list = self._validate_and_pack_meta(image, extrinsics, intrinsics) + + results = self._run_parallel( + image=image, + exts_list=exts_list, + ixts_list=ixts_list, + process_res=process_res, + process_res_method=process_res_method, + num_workers=num_workers, + print_progress=print_progress, + sequential=sequential, + desc=desc, + ) + + proc_imgs, out_sizes, out_ixts, out_exts, pad_meta = self._unpack_results(results) + proc_imgs, out_sizes, out_ixts, pad_meta = self._unify_batch_shapes( + proc_imgs, out_sizes, out_ixts, pad_meta + ) + + batch_tensor = self._stack_batch(proc_imgs) + out_exts = ( + torch.from_numpy(np.asarray(out_exts)).float() + if out_exts is not None and out_exts[0] is not None + else None + ) + out_ixts = ( + torch.from_numpy(np.asarray(out_ixts)).float() + if out_ixts is not None and out_ixts[0] is not None + else None + ) + return (batch_tensor, out_exts, out_ixts, pad_meta) + + # ----------------------------- + # __call__ helpers + # ----------------------------- + def _resolve_sequential(self, sequential: bool | None, num_workers: int) -> bool: + return (num_workers <= 1) if sequential is None else sequential + + def _validate_and_pack_meta( + self, + images: list[np.ndarray | Image.Image | str], + extrinsics: np.ndarray | None, + intrinsics: np.ndarray | None, + ) -> tuple[list[np.ndarray | None] | None, list[np.ndarray | None] | None]: + if extrinsics is not None and len(extrinsics) != len(images): + raise ValueError("Length of extrinsics must match images when provided.") + if intrinsics is not None and len(intrinsics) != len(images): + raise ValueError("Length of intrinsics must match images when provided.") + exts_list = [e for e in extrinsics] if extrinsics is not None else None + ixts_list = [k for k in intrinsics] if intrinsics is not None else None + return exts_list, ixts_list + + def _run_parallel( + self, + *, + image: list[np.ndarray | Image.Image | str], + exts_list: list[np.ndarray | None] | None, + ixts_list: list[np.ndarray | None] | None, + process_res: int, + process_res_method: str, + num_workers: int, + print_progress: bool, + sequential: bool, + desc: str | None, + ): + results = parallel_execution( + image, + exts_list, + ixts_list, + action=self._process_one, # (img, extrinsic, intrinsic, ...) + num_processes=num_workers, + print_progress=print_progress, + sequential=sequential, + desc=desc, + process_res=process_res, + process_res_method=process_res_method, + ) + if not results: + raise RuntimeError( + "No preprocessing results returned. Check inputs and parallel_execution." 
+ ) + return results + + def _unpack_results(self, results): + """ + results: List[ + Tuple[ + torch.Tensor, + Tuple[H, W], + Optional[np.ndarray], + Optional[np.ndarray], + dict, + ] + ] + -> processed_images, out_sizes, out_intrinsics, out_extrinsics, pad_meta + """ + try: + processed_images, out_sizes, out_intrinsics, out_extrinsics, pad_meta = zip(*results) + except Exception as e: + raise RuntimeError( + "Unexpected results structure from parallel_execution: " + f"{type(results)} / sample: {results[0]}" + ) from e + + return ( + list(processed_images), + list(out_sizes), + list(out_intrinsics), + list(out_extrinsics), + list(pad_meta), + ) + + def _unify_batch_shapes( + self, + processed_images: list[torch.Tensor], + out_sizes: list[tuple[int, int]], + out_intrinsics: list[np.ndarray | None], + pad_meta: list[dict], + ) -> tuple[list[torch.Tensor], list[tuple[int, int]], list[np.ndarray | None], list[dict]]: + """Center-crop all tensors to the smallest H, W; adjust intrinsics' cx, cy accordingly.""" + if len(set(out_sizes)) <= 1: + return processed_images, out_sizes, out_intrinsics, pad_meta + + min_h = min(h for h, _ in out_sizes) + min_w = min(w for _, w in out_sizes) + logger.warn( + f"Images in batch have different sizes {out_sizes}; " + f"center-cropping all to smallest ({min_h},{min_w})" + ) + + center_crop = T.CenterCrop((min_h, min_w)) + new_imgs, new_sizes, new_ixts, new_meta = [], [], [], [] + for img_t, (H, W), K, meta in zip(processed_images, out_sizes, out_intrinsics, pad_meta): + crop_top = max(0, (H - min_h) // 2) + crop_left = max(0, (W - min_w) // 2) + new_imgs.append(center_crop(img_t)) + new_sizes.append((min_h, min_w)) + if K is None: + new_ixts.append(None) + else: + K_adj = K.copy() + K_adj[0, 2] -= crop_left + K_adj[1, 2] -= crop_top + new_ixts.append(K_adj) + # Cropping invalidates padding meta; reset so we do not apply another crop later. 
+ new_meta.append({"orig_size": (min_h, min_w), "pad": (0, 0, 0, 0)}) + return new_imgs, new_sizes, new_ixts, new_meta + + def _stack_batch(self, processed_images: list[torch.Tensor]) -> torch.Tensor: + return torch.stack(processed_images) + + # ----------------------------- + # Per-item worker + # ----------------------------- + def _process_one( + self, + img: np.ndarray | Image.Image | str, + extrinsic: np.ndarray | None = None, + intrinsic: np.ndarray | None = None, + *, + process_res: Optional[int], + process_res_method: str, + ) -> tuple[torch.Tensor, tuple[int, int], np.ndarray | None, np.ndarray | None, dict]: + # Load & remember original size + pil_img = self._load_image(img) + orig_w, orig_h = pil_img.size + + # Boundary resize + pil_img = self._resize_image(pil_img, process_res, process_res_method) + w, h = pil_img.size + intrinsic = self._resize_ixt(intrinsic, orig_w, orig_h, w, h) + pad_left = pad_right = pad_top = pad_bottom = 0 + + # Enforce divisibility by PATCH_SIZE + if process_res_method in ("keep", "original"): + pil_img, pad_left, pad_right, pad_top, pad_bottom = self._make_divisible_by_pad( + pil_img, self.PATCH_SIZE + ) + if any((pad_left, pad_right, pad_top, pad_bottom)): + intrinsic = self._pad_ixt(intrinsic, pad_left, pad_top) + w, h = pil_img.size + elif process_res_method.endswith("resize"): + pil_img = self._make_divisible_by_resize(pil_img, self.PATCH_SIZE) + new_w, new_h = pil_img.size + intrinsic = self._resize_ixt(intrinsic, w, h, new_w, new_h) + w, h = new_w, new_h + elif process_res_method.endswith("crop"): + pil_img = self._make_divisible_by_crop(pil_img, self.PATCH_SIZE) + new_w, new_h = pil_img.size + intrinsic = self._crop_ixt(intrinsic, w, h, new_w, new_h) + w, h = new_w, new_h + else: + raise ValueError(f"Unsupported process_res_method: {process_res_method}") + + # Convert to tensor & normalize + img_tensor = self._normalize_image(pil_img) + _, H, W = img_tensor.shape + assert (W, H) == (w, h), "Tensor size mismatch with PIL image size after processing." 
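+
+        # Record the original size and padding so downstream consumers (e.g. the API's
+        # output cropping) can restore predictions to the source resolution.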
+ + meta = { + "orig_size": (orig_h, orig_w), + "pad": (pad_top, pad_bottom, pad_left, pad_right), + } + + # Return: (img_tensor, (H, W), intrinsic, extrinsic, meta) + return img_tensor, (H, W), intrinsic, extrinsic, meta + + # ----------------------------- + # Intrinsics transforms + # ----------------------------- + def _resize_ixt( + self, + intrinsic: np.ndarray | None, + orig_w: int, + orig_h: int, + w: int, + h: int, + ) -> np.ndarray | None: + if intrinsic is None: + return None + K = intrinsic.copy() + # scale fx, cx by w ratio; fy, cy by h ratio + K[:1] *= w / float(orig_w) + K[1:2] *= h / float(orig_h) + return K + + def _crop_ixt( + self, + intrinsic: np.ndarray | None, + orig_w: int, + orig_h: int, + w: int, + h: int, + ) -> np.ndarray | None: + if intrinsic is None: + return None + K = intrinsic.copy() + crop_h = (orig_h - h) // 2 + crop_w = (orig_w - w) // 2 + K[0, 2] -= crop_w + K[1, 2] -= crop_h + return K + + def _pad_ixt( + self, + intrinsic: np.ndarray | None, + pad_left: int, + pad_top: int, + ) -> np.ndarray | None: + if intrinsic is None or (pad_left == 0 and pad_top == 0): + return intrinsic + K = intrinsic.copy() + K[0, 2] += pad_left + K[1, 2] += pad_top + return K + + # ----------------------------- + # I/O & normalization + # ----------------------------- + def _load_image(self, img: np.ndarray | Image.Image | str) -> Image.Image: + if isinstance(img, str): + return Image.open(img).convert("RGB") + elif isinstance(img, np.ndarray): + # Assume HxWxC uint8/RGB + return Image.fromarray(img).convert("RGB") + elif isinstance(img, Image.Image): + return img.convert("RGB") + else: + raise ValueError(f"Unsupported image type: {type(img)}") + + def _normalize_image(self, img: Image.Image) -> torch.Tensor: + img_tensor = T.ToTensor()(img) + return self.NORMALIZE(img_tensor) + + # ----------------------------- + # Boundary resizing + # ----------------------------- + def _resize_image( + self, img: Image.Image, target_size: Optional[int], method: str + ) -> Image.Image: + if method in ("keep", "original"): + return img + + if target_size is None or target_size <= 0: + raise ValueError( + f"process_res must be set when using '{method}'. 
Received: {target_size}" + ) + + if method in ("upper_bound_resize", "upper_bound_crop"): + return self._resize_longest_side(img, target_size) + elif method in ("lower_bound_resize", "lower_bound_crop"): + return self._resize_shortest_side(img, target_size) + else: + raise ValueError(f"Unsupported resize method: {method}") + + def _resize_longest_side(self, img: Image.Image, target_size: int) -> Image.Image: + w, h = img.size + longest = max(w, h) + if longest == target_size: + return img + scale = target_size / float(longest) + new_w = max(1, int(round(w * scale))) + new_h = max(1, int(round(h * scale))) + interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA + arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation) + return Image.fromarray(arr) + + def _resize_shortest_side(self, img: Image.Image, target_size: int) -> Image.Image: + w, h = img.size + shortest = min(w, h) + if shortest == target_size: + return img + scale = target_size / float(shortest) + new_w = max(1, int(round(w * scale))) + new_h = max(1, int(round(h * scale))) + interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA + arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation) + return Image.fromarray(arr) + + # ----------------------------- + # Make divisible by PATCH_SIZE + # ----------------------------- + def _make_divisible_by_crop(self, img: Image.Image, patch: int) -> Image.Image: + """ + Floor each dimension to the nearest multiple of PATCH_SIZE via center crop. + Example: 504x377 -> 504x364 + """ + w, h = img.size + new_w = (w // patch) * patch + new_h = (h // patch) * patch + if new_w == w and new_h == h: + return img + left = (w - new_w) // 2 + top = (h - new_h) // 2 + return img.crop((left, top, left + new_w, top + new_h)) + + def _make_divisible_by_resize(self, img: Image.Image, patch: int) -> Image.Image: + """ + Round each dimension to nearest multiple of PATCH_SIZE via small resize. + """ + w, h = img.size + + def nearest_multiple(x: int, p: int) -> int: + down = (x // p) * p + up = down + p + return up if abs(up - x) <= abs(x - down) else down + + new_w = max(1, nearest_multiple(w, patch)) + new_h = max(1, nearest_multiple(h, patch)) + if new_w == w and new_h == h: + return img + upscale = (new_w > w) or (new_h > h) + interpolation = cv2.INTER_CUBIC if upscale else cv2.INTER_AREA + arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation) + return Image.fromarray(arr) + + def _make_divisible_by_pad( + self, img: Image.Image, patch: int + ) -> tuple[Image.Image, int, int, int, int]: + """ + Pad each dimension up to the nearest multiple of PATCH_SIZE. + Returns: (padded_img, pad_left, pad_right, pad_top, pad_bottom) + """ + w, h = img.size + new_w = ((w + patch - 1) // patch) * patch + new_h = ((h + patch - 1) // patch) * patch + pad_w = new_w - w + pad_h = new_h - h + if pad_w == 0 and pad_h == 0: + return img, 0, 0, 0, 0 + + pad_left = pad_w // 2 + pad_right = pad_w - pad_left + pad_top = pad_h // 2 + pad_bottom = pad_h - pad_top + + padded = ImageOps.expand(img, border=(pad_left, pad_top, pad_right, pad_bottom)) + return padded, pad_left, pad_right, pad_top, pad_bottom + + +# Backward compatibility alias +InputAdapter = InputProcessor + + +# =========================== +# Minimal test runner (parallel execution) +# =========================== +if __name__ == "__main__": + """ + Minimal test suite: + - Creates pairs of images so batch shapes match. + - Tests all four process_res_methods. 
+ - Prints fx fy cx cy IN->OUT per image. + - Includes cases with K/E provided and with None. + """ + + def fmt_k_line(K: np.ndarray | None) -> str: + if K is None: + return "None" + fx, fy, cx, cy = float(K[0, 0]), float(K[1, 1]), float(K[0, 2]), float(K[1, 2]) + return f"fx={fx:.3f} fy={fy:.3f} cx={cx:.3f} cy={cy:.3f}" + + def show_result( + tag: str, + tensor: torch.Tensor, + Ks_in: Sequence[np.ndarray | None] | None = None, + Ks_out: Sequence[np.ndarray | None] | None = None, + ): + B, N, C, H, W = tensor.shape + print(f"[{tag}] shape={tuple(tensor.shape)} HxW=({H},{W}) div14=({H%14==0},{W%14==0})") + assert H % 14 == 0 and W % 14 == 0, f"{tag}: output size not divisible by 14!" + if Ks_in is not None or Ks_out is not None: + Ks_in = Ks_in or [None] * N + Ks_out = Ks_out or [None] * N + for i in range(N): + print(f" K[{i}]: {fmt_k_line(Ks_in[i])} -> {fmt_k_line(Ks_out[i])}") + + proc = InputProcessor() + process_res = 504 + methods = ["upper_bound_resize", "upper_bound_crop", "lower_bound_resize", "lower_bound_crop"] + + # Example sizes (two orientations) + small_sizes = [(680, 1208), (1208, 680)] + large_sizes = [(1208, 680), (680, 1208)] + + def make_K(w, h, fx=1200.0, fy=1100.0): + cx, cy = w / 2.0, h / 2.0 + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + return K + + def run_suite(suite_name: str, sizes: list[tuple[int, int]]): + print(f"\n===== {suite_name} =====") + for w, h in sizes: + img = Image.new("RGB", (w, h), color=(123, 222, 100)) + batch_imgs = [img, img] + + # intrinsics / extrinsics examples + Ks_in = [make_K(w, h), make_K(w, h)] + Es_in = [np.eye(4, dtype=np.float32), np.eye(4, dtype=np.float32)] + + for m in methods: + tensor, Es_out, Ks_out = proc( + image=batch_imgs, + process_res=process_res, + process_res_method=m, + num_workers=8, + print_progress=False, + intrinsics=Ks_in, # test with non-None + extrinsics=Es_in, + ) + show_result(f"{suite_name} size=({w},{h}) | {m}", tensor, Ks_in, Ks_out) + + # Also test None path + tensor2, Es_out2, Ks_out2 = proc( + image=batch_imgs, + process_res=process_res, + process_res_method="upper_bound_resize", + num_workers=8, + intrinsics=None, + extrinsics=None, + ) + show_result( + f"{suite_name} size=({w},{h}) | upper_bound_resize | no K/E", + tensor2, + None, + Ks_out2, + ) + + run_suite("SMALL", small_sizes) + run_suite("LARGE", large_sizes) + + # Extra sanity for 504x376 + print("\n===== EXTRA sanity for 504x376 =====") + img_example = Image.new("RGB", (504, 376), color=(10, 20, 30)) + Ks_in_extra = [make_K(504, 376, fx=900.0, fy=900.0), make_K(504, 376, fx=900.0, fy=900.0)] + + out_r, _, Ks_out_r = proc( + image=[img_example, img_example], + process_res=504, + process_res_method="upper_bound_resize", + num_workers=8, + intrinsics=Ks_in_extra, + ) + out_c, _, Ks_out_c = proc( + image=[img_example, img_example], + process_res=504, + process_res_method="upper_bound_crop", + num_workers=8, + intrinsics=Ks_in_extra, + ) + _, _, _, Hr, Wr = out_r.shape + _, _, _, Hc, Wc = out_c.shape + print(f"upper_bound_resize -> ({Hr},{Wr}) (rounded to nearest multiple of 14)") + show_result("Ks after upper_bound_resize", out_r, Ks_in_extra, Ks_out_r) + print(f"upper_bound_crop -> ({Hc},{Wc}) (floored to multiple of 14)") + show_result("Ks after upper_bound_crop", out_c, Ks_in_extra, Ks_out_c) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/output_processor.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/output_processor.py new file mode 100644 index 
0000000000000000000000000000000000000000..c317eb9d596c1687b5281891a035993868cc5f8c --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/output_processor.py @@ -0,0 +1,172 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Output processor for Depth Anything 3. + +This module handles model output processing, including tensor-to-numpy conversion, +batch dimension removal, and Prediction object creation. +""" + +from __future__ import annotations + +import numpy as np +import torch +from addict import Dict as AddictDict + +from depth_anything_3.specs import Prediction + + +class OutputProcessor: + """ + Output processor for converting model outputs to Prediction objects. + + Handles tensor-to-numpy conversion, batch dimension removal, + and creates structured Prediction objects with proper data types. + """ + + def __init__(self) -> None: + """Initialize the output processor.""" + + def __call__(self, model_output: dict[str, torch.Tensor]) -> Prediction: + """ + Convert model output to Prediction object. + + Args: + model_output: Model output dictionary containing depth, conf, extrinsics, intrinsics + Expected shapes: depth (B, N, 1, H, W), conf (B, N, 1, H, W), + extrinsics (B, N, 4, 4), intrinsics (B, N, 3, 3) + + Returns: + Prediction: Object containing depth estimation results with shapes: + depth (N, H, W), conf (N, H, W), extrinsics (N, 4, 4), intrinsics (N, 3, 3) + """ + # Extract data from batch dimension (B=1, N=number of images) + depth = self._extract_depth(model_output) + conf = self._extract_conf(model_output) + extrinsics = self._extract_extrinsics(model_output) + intrinsics = self._extract_intrinsics(model_output) + sky = self._extract_sky(model_output) + aux = self._extract_aux(model_output) + gaussians = model_output.get("gaussians", None) + scale_factor = model_output.get("scale_factor", None) + + return Prediction( + depth=depth, + sky=sky, + conf=conf, + extrinsics=extrinsics, + intrinsics=intrinsics, + is_metric=getattr(model_output, "is_metric", 0), + gaussians=gaussians, + aux=aux, + scale_factor=scale_factor, + ) + + def _extract_depth(self, model_output: dict[str, torch.Tensor]) -> np.ndarray: + """ + Extract depth tensor from model output and convert to numpy. + + Args: + model_output: Model output dictionary + + Returns: + Depth array with shape (N, H, W) + """ + depth = model_output["depth"].squeeze(0).squeeze(-1).cpu().numpy() # (N, H, W) + return depth + + def _extract_conf(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None: + """ + Extract confidence tensor from model output and convert to numpy. 
+ + Args: + model_output: Model output dictionary + + Returns: + Confidence array with shape (N, H, W) or None + """ + conf = model_output.get("depth_conf", None) + if conf is not None: + conf = conf.squeeze(0).cpu().numpy() # (N, H, W) + return conf + + def _extract_extrinsics(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None: + """ + Extract extrinsics tensor from model output and convert to numpy. + + Args: + model_output: Model output dictionary + + Returns: + Extrinsics array with shape (N, 4, 4) or None + """ + extrinsics = model_output.get("extrinsics", None) + if extrinsics is not None: + extrinsics = extrinsics.squeeze(0).cpu().numpy() # (N, 4, 4) + return extrinsics + + def _extract_intrinsics(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None: + """ + Extract intrinsics tensor from model output and convert to numpy. + + Args: + model_output: Model output dictionary + + Returns: + Intrinsics array with shape (N, 3, 3) or None + """ + intrinsics = model_output.get("intrinsics", None) + if intrinsics is not None: + intrinsics = intrinsics.squeeze(0).cpu().numpy() # (N, 3, 3) + return intrinsics + + def _extract_sky(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None: + """ + Extract sky tensor from model output and convert to numpy. + + Args: + model_output: Model output dictionary + + Returns: + Sky mask array with shape (N, H, W) or None + """ + sky = model_output.get("sky", None) + if sky is not None: + sky = sky.squeeze(0).cpu().numpy() >= 0.5 # (N, H, W) + return sky + + def _extract_aux(self, model_output: dict[str, torch.Tensor]) -> AddictDict: + """ + Extract auxiliary data from model output and convert to numpy. + + Args: + model_output: Model output dictionary + + Returns: + Dictionary containing auxiliary data + """ + aux = model_output.get("aux", None) + ret = AddictDict() + if aux is not None: + for k in aux.keys(): + if isinstance(aux[k], torch.Tensor): + ret[k] = aux[k].squeeze(0).cpu().numpy() + else: + ret[k] = aux[k] + return ret + + +# Backward compatibility alias +OutputAdapter = OutputProcessor diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/layout_helpers.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/layout_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..189c170b2007c979e580b69ca929638560923fb2 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/layout_helpers.py @@ -0,0 +1,216 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains useful layout utilities for images. They are: + +- add_border: Add a border to an image. +- cat/hcat/vcat: Join images by arranging them in a line. If the images have different + sizes, they are aligned as specified (start, end, center). Allows you to specify a gap + between images. + +Images are assumed to be float32 tensors with shape (channel, height, width). 
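+
+Example (illustrative shapes):
+
+    left = torch.rand(3, 64, 48)
+    right = torch.rand(3, 32, 80)
+    row = hcat(left, right)  # (3, 64, 136): heights aligned, 8-pixel gap in between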
+""" + +from typing import Any, Generator, Iterable, Literal, Union +import torch +from torch import Tensor + +Alignment = Literal["start", "center", "end"] +Axis = Literal["horizontal", "vertical"] +Color = Union[ + int, + float, + Iterable[int], + Iterable[float], + Tensor, + Tensor, +] + + +def _sanitize_color(color: Color) -> Tensor: # "#channel" + # Convert tensor to list (or individual item). + if isinstance(color, torch.Tensor): + color = color.tolist() + + # Turn iterators and individual items into lists. + if isinstance(color, Iterable): + color = list(color) + else: + color = [color] + + return torch.tensor(color, dtype=torch.float32) + + +def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]: + it = iter(iterable) + yield next(it) + for item in it: + yield delimiter + yield item + + +def _get_main_dim(main_axis: Axis) -> int: + return { + "horizontal": 2, + "vertical": 1, + }[main_axis] + + +def _get_cross_dim(main_axis: Axis) -> int: + return { + "horizontal": 1, + "vertical": 2, + }[main_axis] + + +def _compute_offset(base: int, overlay: int, align: Alignment) -> slice: + assert base >= overlay + offset = { + "start": 0, + "center": (base - overlay) // 2, + "end": base - overlay, + }[align] + return slice(offset, offset + overlay) + + +def overlay( + base: Tensor, # "channel base_height base_width" + overlay: Tensor, # "channel overlay_height overlay_width" + main_axis: Axis, + main_axis_alignment: Alignment, + cross_axis_alignment: Alignment, +) -> Tensor: # "channel base_height base_width" + # The overlay must be smaller than the base. + _, base_height, base_width = base.shape + _, overlay_height, overlay_width = overlay.shape + assert base_height >= overlay_height and base_width >= overlay_width + + # Compute spacing on the main dimension. + main_dim = _get_main_dim(main_axis) + main_slice = _compute_offset( + base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment + ) + + # Compute spacing on the cross dimension. + cross_dim = _get_cross_dim(main_axis) + cross_slice = _compute_offset( + base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment + ) + + # Combine the slices and paste the overlay onto the base accordingly. + selector = [..., None, None] + selector[main_dim] = main_slice + selector[cross_dim] = cross_slice + result = base.clone() + result[selector] = overlay + return result + + +def cat( + main_axis: Axis, + *images: Iterable[Tensor], # "channel _ _" + align: Alignment = "center", + gap: int = 8, + gap_color: Color = 1, +) -> Tensor: # "channel height width" + """Arrange images in a line. The interface resembles a CSS div with flexbox.""" + device = images[0].device + gap_color = _sanitize_color(gap_color).to(device) + + # Find the maximum image side length in the cross axis dimension. + cross_dim = _get_cross_dim(main_axis) + cross_axis_length = max(image.shape[cross_dim] for image in images) + + # Pad the images. + padded_images = [] + for image in images: + # Create an empty image with the correct size. + padded_shape = list(image.shape) + padded_shape[cross_dim] = cross_axis_length + base = torch.ones(padded_shape, dtype=torch.float32, device=device) + base = base * gap_color[:, None, None] + padded_images.append(overlay(base, image, main_axis, "start", align)) + + # Intersperse separators if necessary. + if gap > 0: + # Generate a separator. 
+ c, _, _ = images[0].shape + separator_size = [gap, gap] + separator_size[cross_dim - 1] = cross_axis_length + separator = torch.ones((c, *separator_size), dtype=torch.float32, device=device) + separator = separator * gap_color[:, None, None] + + # Intersperse the separator between the images. + padded_images = list(_intersperse(padded_images, separator)) + + return torch.cat(padded_images, dim=_get_main_dim(main_axis)) + + +def hcat( + *images: Iterable[Tensor], # "channel _ _" + align: Literal["start", "center", "end", "top", "bottom"] = "start", + gap: int = 8, + gap_color: Color = 1, +): + """Shorthand for a horizontal linear concatenation.""" + return cat( + "horizontal", + *images, + align={ + "start": "start", + "center": "center", + "end": "end", + "top": "start", + "bottom": "end", + }[align], + gap=gap, + gap_color=gap_color, + ) + + +def vcat( + *images: Iterable[Tensor], # "channel _ _" + align: Literal["start", "center", "end", "left", "right"] = "start", + gap: int = 8, + gap_color: Color = 1, +): + """Shorthand for a vertical linear concatenation.""" + return cat( + "vertical", + *images, + align={ + "start": "start", + "center": "center", + "end": "end", + "left": "start", + "right": "end", + }[align], + gap=gap, + gap_color=gap_color, + ) + + +def add_border( + image: Tensor, # "channel height width" + border: int = 8, + color: Color = 1, +) -> Tensor: # "channel new_height new_width" + color = _sanitize_color(color).to(image) + c, h, w = image.shape + result = torch.empty( + (c, h + 2 * border, w + 2 * border), dtype=torch.float32, device=image.device + ) + result[:] = color[:, None, None] + result[:, border : h + border, border : w + border] = image + return result diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/logger.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..0eb4f60696a085001cf4866ccfe1654170702a2d --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/logger.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
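"""
Minimal colored console logger for Depth Anything 3.

A hypothetical usage sketch; the log level is read from the DA3_LOG_LEVEL environment
variable (ERROR, WARN, INFO or DEBUG, defaulting to INFO):

    from depth_anything_3.utils.logger import logger

    logger.info("Loaded model:", "depth-anything/DA3-LARGE")
    logger.debug("Depth shape:", (1, 504, 504))  # printed only when DA3_LOG_LEVEL=DEBUG
"""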
+ +import os +import sys + + +class Color: + RED = "\033[91m" + YELLOW = "\033[93m" + WHITE = "\033[97m" + GREEN = "\033[92m" + RESET = "\033[0m" + + +LOG_LEVELS = {"ERROR": 0, "WARN": 1, "INFO": 2, "DEBUG": 3} + +COLOR_MAP = {"ERROR": Color.RED, "WARN": Color.YELLOW, "INFO": Color.WHITE, "DEBUG": Color.GREEN} + + +def get_env_log_level(): + level = os.environ.get("DA3_LOG_LEVEL", "INFO").upper() + return LOG_LEVELS.get(level, LOG_LEVELS["INFO"]) + + +class Logger: + def __init__(self): + self.level = get_env_log_level() + + def log(self, level_str, *args, **kwargs): + level_key = level_str.split(":")[0].strip() + level_val = LOG_LEVELS.get(level_key) + if level_val is None: + raise ValueError(f"Unknown log level: {level_str}") + if self.level >= level_val: + color = COLOR_MAP[level_key] + msg = " ".join(str(arg) for arg in args) + + # Align log level output in square brackets + # ERROR and DEBUG are 5 characters, INFO and WARN have an extra space for alignment + tag = level_key + if tag in ("INFO", "WARN"): + tag += " " + print( + f"{color}[{tag}] {msg}{Color.RESET}", + file=sys.stderr if level_key == "ERROR" else sys.stdout, + **kwargs, + ) + + def error(self, *args, **kwargs): + self.log("ERROR:", *args, **kwargs) + + def warn(self, *args, **kwargs): + self.log("WARN:", *args, **kwargs) + + def info(self, *args, **kwargs): + self.log("INFO:", *args, **kwargs) + + def debug(self, *args, **kwargs): + self.log("DEBUG:", *args, **kwargs) + + +logger = Logger() + +__all__ = ["logger"] + +if __name__ == "__main__": + logger.info("This is an info message") + logger.warn("This is a warning message") + logger.error("This is an error message") + logger.debug("This is a debug message") diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/memory.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..f5f61595d21a6c221b9be3c7954fe56ff83e5300 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/memory.py @@ -0,0 +1,128 @@ +""" +GPU memory utility helpers. + +Shared cleanup and memory checking logic used by both the backend API and +the Gradio UI to keep memory-management behavior consistent. +""" +from __future__ import annotations + +import gc + +from typing import Any, Dict, Optional + +import torch + + +def get_gpu_memory_info() -> Optional[Dict[str, Any]]: + """Return a snapshot of current GPU memory usage or None if CUDA not available. + + Keys in returned dict: total_gb, allocated_gb, reserved_gb, free_gb, utilization + """ + if not torch.cuda.is_available(): + return None + + try: + device = torch.cuda.current_device() + total_memory = torch.cuda.get_device_properties(device).total_memory + allocated_memory = torch.cuda.memory_allocated(device) + reserved_memory = torch.cuda.memory_reserved(device) + free_memory = total_memory - reserved_memory + + return { + "total_gb": total_memory / 1024 ** 3, + "allocated_gb": allocated_memory / 1024 ** 3, + "reserved_gb": reserved_memory / 1024 ** 3, + "free_gb": free_memory / 1024 ** 3, + "utilization": (reserved_memory / total_memory) * 100, + } + except Exception: + return None + + +def cleanup_cuda_memory() -> None: + """Perform a robust GPU cleanup sequence. + + This includes synchronizing, emptying caches, collecting IPC handles and + running the Python garbage collector. Use this instead of a raw + ``torch.cuda.empty_cache()`` where you need reliable freeing of GPU memory + between model loads or in error handling paths. 
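    A hypothetical call-site sketch:

        del model                # drop the last reference to the previous model
        cleanup_cuda_memory()    # then reclaim its GPU memory before loading the next one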
+ """ + try: + if torch.cuda.is_available(): + mem_before = get_gpu_memory_info() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + # Collect cross-process cuda resources + try: + torch.cuda.ipc_collect() + except Exception: + # Older PyTorch versions or non-cuda devices may not support + # ipc_collect (no-op if not available) + pass + gc.collect() + + mem_after = get_gpu_memory_info() + if mem_before and mem_after: + freed = mem_before["reserved_gb"] - mem_after["reserved_gb"] + print( + f"CUDA cleanup: freed {freed:.2f}GB, " + f"available: {mem_after['free_gb']:.2f}GB/{mem_after['total_gb']:.2f}GB" + ) + else: + print("CUDA memory cleanup completed") + except Exception as e: + print(f"Warning: CUDA cleanup failed: {e}") + + +def check_memory_availability(required_gb: float = 2.0) -> tuple[bool, str]: + """Return whether at least ``required_gb`` seems available on the current GPU. + + The returned tuple is (is_available, message) with a human-friendly message. + """ + try: + if not torch.cuda.is_available(): + return False, "CUDA is not available" + + mem_info = get_gpu_memory_info() + if mem_info is None: + return True, "Cannot check memory, proceeding anyway" + + if mem_info["free_gb"] < required_gb: + return ( + False, + ( + f"Insufficient GPU memory: {mem_info['free_gb']:.2f}GB available, " + f"{required_gb:.2f}GB required. Total: {mem_info['total_gb']:.2f}GB, " + f"Used: {mem_info['reserved_gb']:.2f}GB ({mem_info['utilization']:.1f}%)" + ), + ) + + return ( + True, + ( + f"Memory check passed: {mem_info['free_gb']:.2f}GB available, " + f"{required_gb:.2f}GB required" + ), + ) + except Exception as e: + return True, f"Memory check failed: {e}, proceeding anyway" +def estimate_memory_requirement(num_images: int, process_res: int | None) -> float: + """Heuristic estimate for memory usage (GB) based on image count and resolution. + + This mirrors the simple policy used by the backend service so other code + (e.g., Gradio UI) can make consistent decisions when checking available + memory before loading a model or running inference. + + Args: + num_images: Number of images to process. + process_res: Processing resolution. + + Returns: + Estimated memory requirement in GB. + """ + base_memory = 2.0 + effective_res = 504 if process_res is None or process_res <= 0 else process_res + per_image_memory = (effective_res / 504) ** 2 * 0.5 + total_memory = base_memory + (num_images * per_image_memory * 0.1) + return total_memory diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/model_loading.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d43b5bbab5a0989eae272422192103cd4d5bee --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/model_loading.py @@ -0,0 +1,149 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model loading and state dict conversion utilities. 
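
A hypothetical usage sketch (the model object and checkpoint path are placeholders):

    missed, unexpected = load_pretrained_weights(model, "checkpoints/da3_large.pth", is_metric=False)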
+""" + +from typing import Dict, Tuple +import torch + +from depth_anything_3.utils.logger import logger + + +def convert_general_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert general model state dict to match current model architecture. + + Args: + state_dict: Original state dictionary + + Returns: + Converted state dictionary + """ + # Replace module prefixes + state_dict = {k.replace("module.", "model."): v for k, v in state_dict.items()} + state_dict = {k.replace(".net.", ".backbone."): v for k, v in state_dict.items()} + + # Remove camera token if present + if "model.backbone.pretrained.camera_token" in state_dict: + del state_dict["model.backbone.pretrained.camera_token"] + + # Replace camera token naming + state_dict = { + k.replace(".camera_token_extra", ".camera_token"): v for k, v in state_dict.items() + } + + # Replace head naming + state_dict = { + k.replace("model.all_heads.camera_cond_head", "model.cam_enc"): v + for k, v in state_dict.items() + } + state_dict = { + k.replace("model.all_heads.camera_head", "model.cam_dec"): v for k, v in state_dict.items() + } + state_dict = {k.replace(".more_mlps.", ".backbone."): v for k, v in state_dict.items()} + state_dict = {k.replace(".fc_rot.", ".fc_qvec."): v for k, v in state_dict.items()} + state_dict = { + k.replace("model.all_heads.head", "model.head"): v for k, v in state_dict.items() + } + + # Replace output naming + state_dict = { + k.replace("output_conv2_additional.sky_mask", "sky_output_conv2"): v + for k, v in state_dict.items() + } + state_dict = {k.replace("_ray.", "_aux."): v for k, v in state_dict.items()} + + # Update GS-DPT head naming and value + state_dict = {k.replace("gaussian_param_head.", "gs_head."): v for k, v in state_dict.items()} + + return state_dict + + +def convert_metric_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Convert metric model state dict to match current model architecture. + + Args: + state_dict: Original metric state dictionary + + Returns: + Converted state dictionary + """ + # Add module prefix for metric models + state_dict = {"module." + k: v for k, v in state_dict.items()} + return convert_general_state_dict(state_dict) + + +def load_pretrained_weights(model, model_path: str, is_metric: bool = False) -> Tuple[list, list]: + """ + Load pretrained weights for a single model. + + Args: + model: Model instance to load weights into + model_path: Path to the pretrained weights + is_metric: Whether this is a metric model + + Returns: + Tuple of (missed_keys, unexpected_keys) + """ + state_dict = torch.load(model_path, map_location="cpu") + + if is_metric: + state_dict = convert_metric_state_dict(state_dict) + else: + state_dict = convert_general_state_dict(state_dict) + + missed, unexpected = model.load_state_dict(state_dict, strict=False) + logger.info("Missed keys:", missed) + logger.info("Unexpected keys:", unexpected) + + return missed, unexpected + + +def load_pretrained_nested_weights( + model, main_model_path: str, metric_model_path: str +) -> Tuple[list, list]: + """ + Load pretrained weights for a nested model with both main and metric branches. 
+ + Args: + model: Nested model instance + main_model_path: Path to main model weights + metric_model_path: Path to metric model weights + + Returns: + Tuple of (missed_keys, unexpected_keys) + """ + # Load main model weights + state_dict0 = torch.load(main_model_path, map_location="cpu") + state_dict0 = convert_general_state_dict(state_dict0) + state_dict0 = {k.replace("model.", "model.da3."): v for k, v in state_dict0.items()} + + # Load metric model weights + state_dict1 = torch.load(metric_model_path, map_location="cpu") + state_dict1 = convert_metric_state_dict(state_dict1) + state_dict1 = {k.replace("model.", "model.da3_metric."): v for k, v in state_dict1.items()} + + # Combine state dictionaries + combined_state_dict = state_dict0.copy() + combined_state_dict.update(state_dict1) + + missed, unexpected = model.load_state_dict(combined_state_dict, strict=False) + + print("Missed keys:", missed) + print("Unexpected keys:", unexpected) + + return missed, unexpected diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/parallel_utils.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/parallel_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff108e95d205097f9f2012ab87ad9e265d58d8d --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/parallel_utils.py @@ -0,0 +1,133 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +from functools import wraps +from multiprocessing.pool import ThreadPool +from threading import Thread +from typing import Callable, Dict, List +import imageio +from tqdm import tqdm + + +def async_call_func(func): + @wraps(func) + async def wrapper(*args, **kwargs): + loop = asyncio.get_event_loop() + # Use run_in_executor to run the blocking function in a separate thread + return await loop.run_in_executor(None, func, *args, **kwargs) + + return wrapper + + +slice_func = lambda chunk_index, chunk_dim, chunk_size: [slice(None)] * chunk_dim + [ + slice(chunk_index, chunk_index + chunk_size) +] + + +def async_call(fn): + def wrapper(*args, **kwargs): + Thread(target=fn, args=args, kwargs=kwargs).start() + + return wrapper + + +def _save_image_impl(save_img, save_path): + """Common implementation for saving images synchronously or asynchronously""" + os.makedirs(os.path.dirname(save_path), exist_ok=True) + imageio.imwrite(save_path, save_img) + + +@async_call +def save_image_async(save_img, save_path): + """Save image asynchronously""" + _save_image_impl(save_img, save_path) + + +def save_image(save_img, save_path): + """Save image synchronously""" + _save_image_impl(save_img, save_path) + + +def parallel_execution( + *args, + action: Callable, + num_processes=32, + print_progress=False, + sequential=False, + async_return=False, + desc=None, + **kwargs, +): + # Partially copy from EasyVolumetricVideo (parallel_execution) + # NOTE: we expect first arg / or kwargs to be distributed + # NOTE: print_progress arg is reserved. 
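    # Hypothetical usage sketch (names are placeholders): list arguments whose length matches
    # the batch are indexed per call, everything else is passed through unchanged, e.g.
    #     parallel_execution(depth_maps, out_paths, action=save_image,
    #                        num_processes=16, print_progress=True, desc="saving")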
+ # `*args` packs all positional arguments passed to the function into a tuple + args = list(args) + + def get_length(args: List, kwargs: Dict): + for a in args: + if isinstance(a, list): + return len(a) + for v in kwargs.values(): + if isinstance(v, list): + return len(v) + raise NotImplementedError + + def get_action_args(length: int, args: List, kwargs: Dict, i: int): + action_args = [ + (arg[i] if isinstance(arg, list) and len(arg) == length else arg) for arg in args + ] + # TODO: Support all types of iterable + action_kwargs = { + key: ( + kwargs[key][i] + if isinstance(kwargs[key], list) and len(kwargs[key]) == length + else kwargs[key] + ) + for key in kwargs + } + return action_args, action_kwargs + + if not sequential: + # Create ThreadPool + pool = ThreadPool(processes=num_processes) + + # Spawn threads + results = [] + asyncs = [] + length = get_length(args, kwargs) + for i in range(length): + action_args, action_kwargs = get_action_args(length, args, kwargs, i) + async_result = pool.apply_async(action, action_args, action_kwargs) + asyncs.append(async_result) + + # Join threads and get return values + if not async_return: + for async_result in tqdm(asyncs, desc=desc, disable=not print_progress): + results.append(async_result.get()) # will sync the corresponding thread + pool.close() + pool.join() + return results + else: + return pool + else: + results = [] + length = get_length(args, kwargs) + for i in tqdm(range(length), desc=desc, disable=not print_progress): + action_args, action_kwargs = get_action_args(length, args, kwargs, i) + async_result = action(*action_args, **action_kwargs) + results.append(async_result) + return results diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/pca_utils.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/pca_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9eee268cd8692d885bf093700a5752077b9d7d --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/pca_utils.py @@ -0,0 +1,284 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +PCA utilities for feature visualization and dimensionality reduction (video-friendly). +- Support frame-by-frame: transform_frame / transform_video +- Support one-time global PCA fitting and reuse (mean, V3) for stable colors +- Support Procrustes alignment (solving principal component order/sign/rotation jumps) +- Support global fixed or temporal EMA for percentiles (time dimension only, no spatial) +""" + +import numpy as np +import torch + + +def pca_to_rgb_4d_bf16_percentile( + x_np: np.ndarray, + device=None, + q_oversample: int = 6, + clip_percent: float = 10.0, # Percentage to clip from top and bottom (0~49.9) + return_uint8: bool = False, + enable_autocast_bf16: bool = True, +): + """ + Reduce numpy array of shape (49, 27, 36, 3072) to 3D via PCA and visualize as (49, 27, 36, 3). + - PCA uses torch.pca_lowrank (randomized SVD), defaults to GPU. 
+ - Uses CUDA bf16 autocast in computation (if available), + then per-channel percentile clipping and normalization. + - Default removes 5% outliers from top and bottom (adjustable via clip_percent) to + improve visualization contrast. + + Parameters + ---------- + x_np : np.ndarray + Shape must be (49, 27, 36, 3072). dtype recommended float32/float64. + device : str | None + Specify 'cuda' or 'cpu'. Auto-select if None (prefer cuda). + q_oversample : int + Oversampling q for pca_lowrank, must be >= 3. + Slightly larger than target dim (3) is more stable, default 6. + clip_percent : float + Percentage to clip from top and bottom (0~49.9), + e.g. 5.0 means clip lowest 5% and highest 5% per channel. + return_uint8 : bool + True returns uint8(0~255), otherwise returns float32(0~1). + enable_autocast_bf16 : bool + Enable bf16 autocast on CUDA. + + Returns + ------- + np.ndarray + Array of shape (49, 27, 36, 3), float32[0,1] or uint8[0,255]. + """ + assert ( + x_np.ndim == 4 + ) # and x_np.shape[-1] == 3072, f"expect (49,27,36,3072), got {x_np.shape}" + B1, B2, B3, D = x_np.shape + N = B1 * B2 * B3 + + # Device selection + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Convert input to torch, unified float32 + X = torch.from_numpy(x_np.reshape(N, D)).to(device=device, dtype=torch.float32) + + # Parameter and safety checks + k = 3 + q = max(int(q_oversample), k) + clip_percent = float(clip_percent) + if not (0.0 <= clip_percent < 50.0): + raise ValueError( + "clip_percent must be in [0, 50), e.g. 5.0 means clip 5% from top and bottom" + ) + low = clip_percent / 100.0 + high = 1.0 - low + + with torch.no_grad(): + # Zero mean + mean = X.mean(dim=0, keepdim=True) + Xc = X - mean + + # Main computation: PCA + projection, try to use bf16 + # (auto-fallback if operator not supported) + device.startswith("cuda") and enable_autocast_bf16 + U, S, V = torch.pca_lowrank(Xc, q=q, center=False) # V: (D, q) + V3 = V[:, :k] # (3072, 3) + PCs = Xc @ V3 # (N, 3) + + # === Per-channel percentile clipping and normalization to [0,1] === + # Vectorized one-time calculation of low/high percentiles for each channel + qs = torch.tensor([low, high], device=PCs.device, dtype=PCs.dtype) + qvals = torch.quantile(PCs, q=qs, dim=0) # Shape (2, 3) + lo = qvals[0] # (3,) + hi = qvals[1] # (3,) + + # Avoid degenerate case where hi==lo + denom = torch.clamp(hi - lo, min=1e-8) + + # Broadcast clipping + normalization + PCs = torch.clamp(PCs, lo, hi) + PCs = (PCs - lo) / denom # (N, 3) in [0,1] + + # Restore 4D + PCs = PCs.reshape(B1, B2, B3, k) + + # Output + if return_uint8: + out = (PCs * 255.0).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + else: + out = PCs.clamp(0, 1).to(torch.float32).cpu().numpy() + + return out + + +class PCARGBVisualizer: + """ + Stable PCAβ†’RGB for video features shaped (T, H, W, D) or a single frame (H, W, D). 
+ - Global mean/V3 reference for stable colors + - Per-frame PCA with Procrustes alignment to V3_ref (basis_mode='procrustes') + - Percentile normalization with global or EMA stats (time-only, no spatial smoothing) + """ + + def __init__( + self, + device=None, + q_oversample: int = 16, + clip_percent: float = 10.0, + return_uint8: bool = False, + enable_autocast_bf16: bool = True, + basis_mode: str = "procrustes", # 'fixed' | 'procrustes' + percentile_mode: str = "ema", # 'global' | 'ema' + ema_alpha: float = 0.1, + denom_eps: float = 1e-4, + ): + assert 0.0 <= clip_percent < 50.0 + assert basis_mode in ("fixed", "procrustes") + assert percentile_mode in ("global", "ema") + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.q = max(int(q_oversample), 6) + self.clip_percent = float(clip_percent) + self.return_uint8 = return_uint8 + self.enable_autocast_bf16 = enable_autocast_bf16 + self.basis_mode = basis_mode + self.percentile_mode = percentile_mode + self.ema_alpha = float(ema_alpha) + self.denom_eps = float(denom_eps) + + # reference state + self.mean_ref = None # (1, D) + self.V3_ref = None # (D, 3) + self.lo_ref = None # (3,) + self.hi_ref = None # (3,) + + @torch.no_grad() + def fit_reference(self, frames): + """ + Fit global mean/V3 and initialize percentiles from a reference set. + frames: ndarray (T,H,W,D) or list of (H,W,D) + """ + if isinstance(frames, np.ndarray): + if frames.ndim != 4: + raise ValueError("fit_reference expects (T,H,W,D) ndarray.") + T, H, W, D = frames.shape + X = torch.from_numpy(frames.reshape(T * H * W, D)) + else: # list of (H,W,D) + xs = [torch.from_numpy(x.reshape(-1, x.shape[-1])) for x in frames] + D = xs[0].shape[-1] + X = torch.cat(xs, dim=0) + + X = X.to(self.device, dtype=torch.float32) + X = torch.nan_to_num(X, nan=0.0, posinf=1e6, neginf=-1e6) + + mean = X.mean(0, keepdim=True) + Xc = X - mean + + U, S, V = torch.pca_lowrank(Xc, q=max(self.q, 8), center=False) + V3 = V[:, :3] # (D,3) + + PCs = Xc @ V3 + low = self.clip_percent / 100.0 + high = 1.0 - low + qs = torch.tensor([low, high], device=PCs.device, dtype=PCs.dtype) + qvals = torch.quantile(PCs, q=qs, dim=0) + lo, hi = qvals[0], qvals[1] + + self.mean_ref = mean + self.V3_ref = V3 + if self.percentile_mode == "global": + self.lo_ref, self.hi_ref = lo, hi + else: + self.lo_ref = lo.clone() + self.hi_ref = hi.clone() + + @torch.no_grad() + def _project_with_stable_colors(self, X: torch.Tensor) -> torch.Tensor: + """ + X: (N,D) where N = H*W + Returns PCs_raw: (N,3) using stable basis (fixed or Procrustes-aligned) + """ + assert self.mean_ref is not None and self.V3_ref is not None, "Call fit_reference() first." 
+ X = torch.nan_to_num(X, nan=0.0, posinf=1e6, neginf=-1e6) + Xc = X - self.mean_ref + + if self.basis_mode == "fixed": + V3_used = self.V3_ref + else: + U, S, V = torch.pca_lowrank(Xc, q=max(self.q, 6), center=False) + V3 = V[:, :3] # (D,3) + M = V3.T @ self.V3_ref + Uo, So, Vh = torch.linalg.svd(M) + R = Uo @ Vh + V3_used = V3 @ R + # Optional polarity fix via anchor + a = self.V3_ref.mean(0, keepdim=True) + sign = torch.sign((V3_used * a).sum(0, keepdim=True)).clamp(min=-1) + V3_used = V3_used * sign + + return Xc @ V3_used + + @torch.no_grad() + def _normalize_rgb(self, PCs_raw: torch.Tensor) -> torch.Tensor: + assert self.lo_ref is not None and self.hi_ref is not None + if self.percentile_mode == "global": + lo, hi = self.lo_ref, self.hi_ref + else: + low = self.clip_percent / 100.0 + high = 1.0 - low + qs = torch.tensor([low, high], device=PCs_raw.device, dtype=PCs_raw.dtype) + qvals = torch.quantile(PCs_raw, q=qs, dim=0) + lo_now, hi_now = qvals[0], qvals[1] + a = self.ema_alpha + self.lo_ref = (1 - a) * self.lo_ref + a * lo_now + self.hi_ref = (1 - a) * self.hi_ref + a * hi_now + lo, hi = self.lo_ref, self.hi_ref + + denom = torch.clamp(hi - lo, min=self.denom_eps) + PCs = torch.clamp(PCs_raw, lo, hi) + PCs = (PCs - lo) / denom + return PCs.clamp_(0, 1) + + @torch.no_grad() + def transform_frame(self, frame: np.ndarray) -> np.ndarray: + """ + frame: (H,W,D) -> (H,W,3) + """ + if frame.ndim != 3: + raise ValueError("transform_frame expects (H,W,D).") + H, W, D = frame.shape + X = torch.from_numpy(frame.reshape(H * W, D)).to(self.device, dtype=torch.float32) + PCs_raw = self._project_with_stable_colors(X) + PCs = self._normalize_rgb(PCs_raw).reshape(H, W, 3) + if self.return_uint8: + return (PCs * 255.0).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + return PCs.to(torch.float32).cpu().numpy() + + @torch.no_grad() + def transform_video(self, frames) -> np.ndarray: + """ + frames: (T,H,W,D) or list of (H,W,D) + returns: (T,H,W,3) + """ + outs = [] + if isinstance(frames, np.ndarray): + if frames.ndim != 4: + raise ValueError("transform_video expects (T,H,W,D).") + T, H, W, D = frames.shape + for t in range(T): + outs.append(self.transform_frame(frames[t])) + else: + for f in frames: + outs.append(self.transform_frame(f)) + return np.stack(outs, axis=0) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/pose_align.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/pose_align.py new file mode 100644 index 0000000000000000000000000000000000000000..695d07fc1210e3cc3614c9b22c56d765b1106500 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/pose_align.py @@ -0,0 +1,347 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
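"""
Umeyama Sim(3) trajectory alignment utilities (optionally RANSAC-robustified).

A hypothetical usage sketch, assuming ``ext_ref`` and ``ext_est`` are (N, 4, 4) or (N, 3, 4)
world-to-camera extrinsics as numpy arrays and ``points_est`` is an (M, 3) point cloud:

    rot, trans, scale, ext_aligned = align_poses_umeyama(ext_ref, ext_est, return_aligned=True)
    points_in_ref = transform_points_sim3(points_est, rot, trans, scale)  # est -> ref frame
"""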
+ +from typing import List +import numpy as np +import torch +from evo.core.trajectory import PosePath3D + +from depth_anything_3.utils.geometry import affine_inverse, affine_inverse_np + + +def batch_apply_alignment_to_enc( + rots: torch.Tensor, trans: torch.Tensor, scales: torch.Tensor, enc_list: List[torch.Tensor] +): + pass + + +def batch_apply_alignment_to_ext( + rots: torch.Tensor, trans: torch.Tensor, scales: torch.Tensor, ext: torch.Tensor +): + device, _ = ext.device, ext.dtype + if ext.shape[-2:] == (3, 4): + pad = torch.zeros((*ext.shape[:-2], 4, 4), dtype=ext.dtype, device=device) + pad[..., :3, :4] = ext + pad[..., 3, 3] = 1.0 + ext = pad + pose_est = affine_inverse(ext) + pose_new_align_rot = rots[:, None] @ pose_est[..., :3, :3] + pose_new_align_trans = ( + scales[:, None, None] * (rots[:, None] @ pose_est[..., :3, 3:])[..., 0] + trans[:, None] + ) + pose_new_align = torch.zeros_like(ext) + pose_new_align[..., :3, :3] = pose_new_align_rot + pose_new_align[..., :3, 3] = pose_new_align_trans + pose_new_align[..., 3, 3] = 1.0 + return affine_inverse(pose_new_align)[:, :3] + + +def batch_align_poses_umeyama(ext_ref: torch.Tensor, ext_est: torch.Tensor): + device, dtype = ext_ref.device, ext_ref.dtype + assert ext_ref.dtype in [torch.float32, torch.float64] + assert ext_est.dtype in [torch.float32, torch.float64] + assert ext_ref.requires_grad is False + assert ext_est.requires_grad is False + rots, trans, scales = [], [], [] + for b in range(ext_ref.shape[0]): + r, t, s = align_poses_umeyama(ext_ref[b].cpu().numpy(), ext_est[b].cpu().numpy()) + rots.append(torch.from_numpy(r).to(device=device, dtype=dtype)) + trans.append(torch.from_numpy(t).to(device=device, dtype=dtype)) + scales.append(torch.tensor(s, device=device, dtype=dtype)) + return torch.stack(rots), torch.stack(trans), torch.stack(scales) + + +# Dependencies: affine_inverse_np, PosePath3D (maintain consistency with your existing project) + + +def _to44(ext): + if ext.shape[1] == 3: + out = np.eye(4)[None].repeat(len(ext), 0) + out[:, :3, :4] = ext + return out + return ext + + +def _poses_from_ext(ext_ref, ext_est): + ext_ref = _to44(ext_ref) + ext_est = _to44(ext_est) + pose_ref = affine_inverse_np(ext_ref) + pose_est = affine_inverse_np(ext_est) + return pose_ref, pose_est + + +def _umeyama_sim3_from_paths(pose_ref, pose_est): + path_ref = PosePath3D(poses_se3=pose_ref.copy()) + path_est = PosePath3D(poses_se3=pose_est.copy()) + r, t, s = path_est.align(path_ref, correct_scale=True) + pose_est_aligned = np.stack(path_est.poses_se3) + return r, t, s, pose_est_aligned + + +def _apply_sim3_to_poses(poses, r, t, s): + out = poses.copy() + Ri = poses[:, :3, :3] + ti = poses[:, :3, 3] + out[:, :3, :3] = r @ Ri + out[:, :3, 3] = (r @ (s * ti.T)).T + t + return out + + +def _median_nn_thresh(pose_ref, pose_est_aligned): + P_ref = pose_ref[:, :3, 3] + P_est = pose_est_aligned[:, :3, 3] + dists = [] + for p in P_est: + dd = np.linalg.norm(P_ref - p[None, :], axis=1) + dists.append(dd.min()) + return float(np.median(dists)) if dists else 0.0 + + +def _ransac_align_sim3( + pose_ref, pose_est, sub_n=None, inlier_thresh=None, max_iters=10, random_state=None +): + rng = np.random.default_rng(random_state) + N = pose_ref.shape[0] + idx_all = np.arange(N) + if sub_n is None: + sub_n = max(3, (N + 1) // 2) + else: + sub_n = max(3, min(sub_n, N)) + + # Pre-alignment + default threshold + r0, t0, s0, pose_est0 = _umeyama_sim3_from_paths(pose_ref, pose_est) + if inlier_thresh is None: + inlier_thresh = _median_nn_thresh(pose_ref, 
pose_est0) + + P_ref_all = pose_ref[:, :3, 3] + + best_model = (r0, t0, s0) + best_inliers = None + best_score = (-1, np.inf) # (num_inliers, mean_err) + + for _ in range(max_iters): + sample = rng.choice(idx_all, size=sub_n, replace=False) + try: + r, t, s, _ = _umeyama_sim3_from_paths(pose_ref[sample], pose_est[sample]) + except Exception: + continue + pose_h = _apply_sim3_to_poses(pose_est, r, t, s) + P_h = pose_h[:, :3, 3] + errs = np.linalg.norm(P_h - P_ref_all, axis=1) # Match by same index + inliers = errs <= inlier_thresh + k = int(inliers.sum()) + mean_err = float(errs[inliers].mean()) if k > 0 else np.inf + if (k > best_score[0]) or (k == best_score[0] and mean_err < best_score[1]): + best_score = (k, mean_err) + best_model = (r, t, s) + best_inliers = inliers + + # Fit again with best inliers + if best_inliers is not None and best_inliers.sum() >= 3: + r, t, s, _ = _umeyama_sim3_from_paths(pose_ref[best_inliers], pose_est[best_inliers]) + else: + r, t, s = best_model + return r, t, s + + +def align_poses_umeyama( + ext_ref: np.ndarray, + ext_est: np.ndarray, + return_aligned=False, + ransac=False, + sub_n=None, + inlier_thresh=None, + ransac_max_iters=10, + random_state=None, +): + """ + Align estimated trajectory to reference using Umeyama Sim(3). + Default no RANSAC; if ransac=True, use RANSAC (max iterations default 10). + - sub_n defaults to half the number of frames (rounded up, at least 3) + - inlier_thresh defaults to median of "distance from each estimated pose to + nearest reference pose after pre-alignment" + Returns rotation (3x3), translation (3,), scale; optionally returns aligned extrinsics (4x4). + """ + pose_ref, pose_est = _poses_from_ext(ext_ref, ext_est) + + if not ransac: + r, t, s, pose_est_aligned = _umeyama_sim3_from_paths(pose_ref, pose_est) + else: + r, t, s = _ransac_align_sim3( + pose_ref, + pose_est, + sub_n=sub_n, + inlier_thresh=inlier_thresh, + max_iters=ransac_max_iters, + random_state=random_state, + ) + pose_est_aligned = _apply_sim3_to_poses(pose_est, r, t, s) + + if return_aligned: + ext_est_aligned = affine_inverse_np(pose_est_aligned) + return r, t, s, ext_est_aligned + return r, t, s + + +# def align_poses_umeyama(ext_ref: np.ndarray, ext_est: np.ndarray, return_aligned=False): +# """ +# Align estimated trajectory to reference trajectory using Umeyama Sim(3) +# alignment (via evo PosePath3D). # noqa +# Returns rotation, translation, and scale. +# """ +# # If input extrinsics are 3x4, convert to 4x4 by padding +# if ext_ref.shape[1] == 3: +# ext_ref_ = np.eye(4)[None].repeat(len(ext_ref), 0) +# ext_ref_[:, :3] = ext_ref +# ext_ref = ext_ref_ +# if ext_est.shape[1] == 3: +# ext_est_ = np.eye(4)[None].repeat(len(ext_est), 0) +# ext_est_[:, :3] = ext_est +# ext_est = ext_est_ + +# # Convert to camera poses (inverse extrinsics) +# pose_ref = affine_inverse_np(ext_ref) +# pose_est = affine_inverse_np(ext_est) + +# # Create evo PosePath3D objects +# path_ref = PosePath3D(poses_se3=pose_ref) +# path_est = PosePath3D(poses_se3=pose_est) +# r, t, s = path_est.align(path_ref, correct_scale=True) +# if return_aligned: +# return r, t, s, affine_inverse_np(np.stack(path_est.poses_se3)) +# else: +# return r, t, s + + +def apply_umeyama_alignment_to_ext( + rot: np.ndarray, # (3,3) + trans: np.ndarray, # (3,) or (1,3) + scale: float, + ext_est: np.ndarray, # (...,4,4) or (...,3,4) +) -> np.ndarray: + """ + Apply Sim(3) (R, t, s) to a batch of world-to-camera extrinsics ext_est. + Returns the aligned extrinsics, with the same shape as input. 
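    A hypothetical sketch, pairing it with ``align_poses_umeyama``:

        rot, trans, scale = align_poses_umeyama(ext_ref, ext_est)
        ext_est_aligned = apply_umeyama_alignment_to_ext(rot, trans, scale, ext_est)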
+ """ + + # Allow 3x4 extrinsics: pad to 4x4 + if ext_est.shape[-2:] == (3, 4): + pad = np.zeros((*ext_est.shape[:-2], 4, 4), dtype=ext_est.dtype) + pad[..., :3, :4] = ext_est + pad[..., 3, 3] = 1.0 + ext_est = pad + + # Convert world-to-camera to camera-to-world + pose_est = affine_inverse_np(ext_est) # (...,4,4) + R_e = pose_est[..., :3, :3] # (...,3,3) + t_e = pose_est[..., :3, 3] # (...,3) + + # Apply Sim(3) transformation + R_a = np.einsum("ij,...jk->...ik", rot, R_e) # (...,3,3) + t_a = scale * np.einsum("ij,...j->...i", rot, t_e) + trans # (...,3) + + # Assemble the transformed pose + pose_a = np.zeros_like(pose_est) + pose_a[..., :3, :3] = R_a + pose_a[..., :3, 3] = t_a + pose_a[..., 3, 3] = 1.0 + + # Convert back to world-to-camera + return affine_inverse_np(pose_a) + + +def transform_points_sim3(points, rot, trans, scale, inverse=False): + """ + Sim(3) transform point cloud + points: (N, 3) + rot: (3, 3) + trans: (3,) or (1, 3) + scale: float + inverse: Whether to do inverse transform (ref->est) + Returns: (N, 3) + """ + if not inverse: + # Forward: est -> ref + return scale * (points @ rot.T) + trans + else: + # Inverse: ref -> est + return ((points - trans) @ rot) / scale + + +def _rand_rot(): + u1, u2, u3 = np.random.rand(3) + q = np.array( + [ + np.sqrt(1 - u1) * np.sin(2 * np.math.pi * u2), + np.sqrt(1 - u1) * np.cos(2 * np.math.pi * u2), + np.sqrt(u1) * np.sin(2 * np.math.pi * u3), + np.sqrt(u1) * np.cos(2 * np.math.pi * u3), + ] + ) + w, x, y, z = q + return np.array( + [ + [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], + [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], + [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], + ] + ) + + +def _rand_pose(): + R, t = _rand_rot(), np.random.randn(3) + P = np.eye(4) + P[:3, :3] = R + P[:3, 3] = t + return P + + +if __name__ == "__main__": + np.random.seed(42) + # 1. Randomly generate reference trajectory and Sim(3) + N = 8 + pose_ref = np.stack([_rand_pose() for _ in range(N)]) # (N,4,4) camβ†’world + rot_gt = _rand_rot() + scale_gt = 2.3 + trans_gt = np.random.randn(3) + # 2. Generate estimated trajectory (apply Sim(3)) + pose_est = np.zeros_like(pose_ref) + for i in range(N): + R = pose_ref[i][:3, :3] + t = pose_ref[i][:3, 3] + pose_est[i][:3, :3] = rot_gt @ R + pose_est[i][:3, 3] = scale_gt * (rot_gt @ t) + trans_gt + pose_est[i][3, 3] = 1.0 + # 3. Get extrinsics (world->cam) + ext_ref = affine_inverse_np(pose_ref) + ext_est = affine_inverse_np(pose_est) + # 4. Use umeyama alignment, estimate Sim(3) + r_est, t_est, s_est = align_poses_umeyama(ext_ref, ext_est) + print("GT scale:", scale_gt, "Estimated:", s_est) + print("GT trans:", trans_gt, "Estimated:", t_est) + print("GT rot:\n", rot_gt, "\nEstimated:\n", r_est) + # 5. Random point cloud, in ref frame + num_points = 100 + points_ref = np.random.randn(num_points, 3) + # 6. Use GT Sim(3) inverse transform to est frame + points_est = transform_points_sim3(points_ref, rot_gt, trans_gt, scale_gt, inverse=True) + # 7. Use estimated Sim(3) forward transform back to ref frame + points_ref_recovered = transform_points_sim3(points_est, r_est, t_est, s_est, inverse=False) + # 8. Check error + err = np.abs(points_ref_recovered - points_ref) + print("Point cloud sim3 transform error (mean abs):", err.mean()) + print("Point cloud sim3 transform error (max abs):", err.max()) + assert err.mean() < 1e-6, "Mean sim3 transform error too large!" + assert err.max() < 1e-5, "Max sim3 transform error too large!" 
+ print("Sim(3) point cloud transform & alignment test passed!") diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/read_write_model.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/read_write_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4cf197f609665e46f14e6b3e8b6657efc5660c --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/read_write_model.py @@ -0,0 +1,585 @@ +# Copyright (c), ETH Zurich and UNC Chapel Hill. +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# All rights reserved. +# +# This file has been modified by ByteDance Ltd. and/or its affiliates. on 11/05/2025 +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
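"""
Readers and writers for COLMAP sparse reconstructions (cameras, images, points3D)
in binary (.bin) and text (.txt) form.

A hypothetical usage sketch; the model folder paths are placeholders:

    cameras, images, points3D = read_model("sparse/0", ext=".bin")
    write_model(cameras, images, points3D, "sparse_txt", ext=".txt")
"""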
+ + +import argparse +import collections +import os +import struct +import numpy as np + +CameraModel = collections.namedtuple("CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] +) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] +) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), +} +CAMERA_MODEL_IDS = {camera_model.model_id: camera_model for camera_model in CAMERA_MODELS} +CAMERA_MODEL_NAMES = {camera_model.model_name: camera_model for camera_model in CAMERA_MODELS} + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. + :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
+ should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path) as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera( + id=camera_id, + model=model, + width=width, + height=height, + params=params, + ) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes(fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes( + fid, + num_bytes=8 * num_params, + format_char_sequence="d" * num_params, + ) + cameras[camera_id] = Camera( + id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params), + ) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = ( + "# Camera list with one line of data per camera:\n" + + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + + f"# Number of cameras: {len(cameras)}\n" + ) + with open(path, "w") as fid: + fid.write(HEADER) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, model_id, cam.width, cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path) as fid: + while True: + line = fid.readline() + if not line: + break 
+ line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack( + [ + tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3])), + ] + ) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi" + ) + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + binary_image_name = b"" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + binary_image_name += current_char + current_char = read_next_bytes(fid, 1, "c")[0] + image_name = binary_image_name.decode("utf-8") + num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes( + fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D, + ) + xys = np.column_stack( + [ + tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3])), + ] + ) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def write_images_text(images, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum((len(img.point3D_ids) for _, img in images.items())) / len(images) + HEADER = ( + "# Image list with two lines of data per image:\n" + + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + + "# Number of images: {}, mean observations per image: {}\n".format( + len(images), mean_observations + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, img in images.items(): + image_header = [ + img.id, + *img.qvec, + *img.tvec, + img.camera_id, + img.name, + ] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(images), "Q") + for _, img in images.items(): + 
write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path) as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def read_points3D_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd" + ) + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, + num_bytes=8 * track_length, + format_char_sequence="ii" * track_length, + ) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items())) / len(points3D) + HEADER = ( + "# 3D point list with one line of data per point:\n" + + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + + "# Number of points: {}, mean track length: {}\n".format( + len(points3D), mean_track_length + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def write_points3D_binary(points3D, path_to_model_file): + """ + see: 
src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def detect_model_format(path, ext): + if ( + os.path.isfile(os.path.join(path, "cameras" + ext)) + and os.path.isfile(os.path.join(path, "images" + ext)) + and os.path.isfile(os.path.join(path, "points3D" + ext)) + ): + print("Detected model format: '" + ext + "'") + return True + + return False + + +def read_model(path, ext=""): + # try to detect the extension automatically + if ext == "": + if detect_model_format(path, ".bin"): + ext = ".bin" + elif detect_model_format(path, ".txt"): + ext = ".txt" + else: + print("Provide model format: '.bin' or '.txt'") + return + + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext=".bin"): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def main(): + parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models") + parser.add_argument("--input_model", help="path to input model folder") + parser.add_argument( + "--input_format", + choices=[".bin", ".txt"], + help="input model format", + default="", + ) + parser.add_argument("--output_model", help="path to 
output model folder") + parser.add_argument( + "--output_format", + choices=[".bin", ".txt"], + help="output model format", + default=".txt", + ) + args = parser.parse_args() + + cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model( + cameras, + images, + points3D, + path=args.output_model, + ext=args.output_format, + ) + + +if __name__ == "__main__": + main() diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/registry.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..7db16d525e0417110d87aa5e621b792d8bd95596 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/registry.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +from addict import Dict + + +class Registry(Dict[str, Any]): + def __init__(self): + super().__init__() + self._map = Dict({}) + + def register(self, name=None): + def decorator(cls): + key = name or cls.__name__ + self._map[key] = cls + return cls + + return decorator + + def get(self, name): + return self._map[name] + + def all(self): + return self._map diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/sh_helpers.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/sh_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..1a1a4ca8204eb8d858351afb253e41e77dfa2f74 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/sh_helpers.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from math import isqrt +import torch +from einops import einsum + +try: + from e3nn.o3 import matrix_to_angles, wigner_D +except ImportError: + from depth_anything_3.utils.logger import logger + + logger.warn("Dependency 'e3nn' not found. Required for rotating the camera space SH coeff") + + +def project_to_so3_strict(M: torch.Tensor) -> torch.Tensor: + if M.shape[-2:] != (3, 3): + raise ValueError("Input must be a batch of 3x3 matrices (i.e., shape [..., 3, 3]).") + + # 1. Compute SVD + U, S, Vh = torch.linalg.svd(M) + V = Vh.mH + + # 2. 
Handle reflection case (det = -1) + det_U = torch.det(U) + det_V = torch.det(V) + is_reflection = (det_U * det_V) < 0 + correction_sign = torch.where( + is_reflection[..., None], + torch.tensor([1, 1, -1.0], device=M.device, dtype=M.dtype), + torch.tensor([1, 1, 1.0], device=M.device, dtype=M.dtype), + ) + correction_matrix = torch.diag_embed(correction_sign) + U_corrected = U @ correction_matrix + R_so3_initial = U_corrected @ V.transpose(-2, -1) + + # 3. Explicitly ensure determinant is 1 (or extremely close) + current_det = torch.det(R_so3_initial) + det_correction_factor = torch.pow(current_det, -1 / 3)[..., None, None] + R_so3_final = R_so3_initial * det_correction_factor + + return R_so3_final + + +def rotate_sh( + sh_coefficients: torch.Tensor, # "*#batch n" + rotations: torch.Tensor, # "*#batch 3 3" +) -> torch.Tensor: # "*batch n" + # https://github.com/graphdeco-inria/gaussian-splatting/issues/176#issuecomment-2452412653 + device = sh_coefficients.device + dtype = sh_coefficients.dtype + + *_, n = sh_coefficients.shape + + with torch.autocast(device_type=rotations.device.type, enabled=False): + rotations_float32 = rotations.to(torch.float32) + + # switch axes: yzx -> xyz + P = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]]).unsqueeze(0).to(rotations_float32) + permuted_rotations = torch.linalg.inv(P) @ rotations_float32 @ P + + # ensure rotation has det == 1 in float32 type + permuted_rotations_so3 = project_to_so3_strict(permuted_rotations) + + alpha, beta, gamma = matrix_to_angles(permuted_rotations_so3) + result = [] + for degree in range(isqrt(n)): + with torch.device(device): + sh_rotations = wigner_D(degree, alpha, -beta, gamma).type(dtype) + sh_rotated = einsum( + sh_rotations, + sh_coefficients[..., degree**2 : (degree + 1) ** 2], + "... i j, ... j -> ... i", + ) + result.append(sh_rotated) + + return torch.cat(result, dim=-1) diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/visualize.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/visualize.py new file mode 100644 index 0000000000000000000000000000000000000000..8fd32bddf00e5461f674525e73653409270c0227 --- /dev/null +++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/visualize.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import matplotlib +import numpy as np +import torch +from einops import rearrange + +from depth_anything_3.utils.logger import logger + + +def visualize_depth( + depth: np.ndarray, + depth_min=None, + depth_max=None, + percentile=2, + ret_minmax=False, + ret_type=np.uint8, + cmap="Spectral", +): + """ + Visualize a depth map using a colormap. + + Args: + depth: Input depth map array + depth_min: Minimum depth value for normalization. If None, uses percentile + depth_max: Maximum depth value for normalization. 
If None, uses percentile
+        percentile: Percentile for min/max computation if not provided
+        ret_minmax: Whether to return min/max depth values
+        ret_type: Return array type (uint8 or float)
+        cmap: Matplotlib colormap name to use
+
+    Returns:
+        Colored depth visualization as numpy array
+        If ret_minmax=True, also returns depth_min and depth_max
+    """
+    # Work on a copy so the caller's depth array is not modified in place.
+    depth = depth.copy()
+    valid_mask = depth > 0
+    depth[valid_mask] = 1 / depth[valid_mask]
+    if depth_min is None:
+        if valid_mask.sum() <= 10:
+            depth_min = 0
+        else:
+            depth_min = np.percentile(depth[valid_mask], percentile)
+    if depth_max is None:
+        if valid_mask.sum() <= 10:
+            depth_max = 0
+        else:
+            depth_max = np.percentile(depth[valid_mask], 100 - percentile)
+    if depth_min == depth_max:
+        depth_min = depth_min - 1e-6
+        depth_max = depth_max + 1e-6
+    cm = matplotlib.colormaps[cmap]
+    depth = ((depth - depth_min) / (depth_max - depth_min)).clip(0, 1)
+    depth = 1 - depth
+    img_colored_np = cm(depth[None], bytes=False)[:, :, :, 0:3]  # value from 0 to 1
+    if ret_type == np.uint8:
+        img_colored_np = (img_colored_np[0] * 255.0).astype(np.uint8)
+    elif ret_type == np.float32 or ret_type == np.float64:
+        img_colored_np = img_colored_np[0]
+    else:
+        raise ValueError(f"Invalid return type: {ret_type}")
+    if ret_minmax:
+        return img_colored_np, depth_min, depth_max
+    else:
+        return img_colored_np
+
+
+# GS video rendering visualization helpers; these operate in Tensor space.
+
+
+def vis_depth_map_tensor(
+    result: torch.Tensor,  # "*batch height width"
+    color_map: str = "Spectral",
+) -> torch.Tensor:  # "*batch 3 height width"
+    """
+    Color-map the depth map.
+    """
+    far = result.reshape(-1)[:16_000_000].float().quantile(0.99).log().to(result)
+    try:
+        near = result[result > 0][:16_000_000].float().quantile(0.01).log().to(result)
+    except (RuntimeError, ValueError) as e:
+        logger.error(f"No valid depth values found. Reason: {e}")
+        near = torch.zeros_like(far)
+    result = result.log()
+    result = (result - near) / (far - near)
+    return apply_color_map_to_image(result, color_map)
+
+
+def apply_color_map(
+    x: torch.Tensor,  # "*batch"
+    color_map: str = "inferno",
+) -> torch.Tensor:  # "*batch 3"
+    cmap = matplotlib.colormaps[color_map]
+
+    # Convert to NumPy so that Matplotlib color maps can be used.
+    mapped = cmap(x.float().detach().clip(min=0, max=1).cpu().numpy())[..., :3]
+
+    # Convert back to the original format.
+    return torch.tensor(mapped, device=x.device, dtype=torch.float32)
+
+
+def apply_color_map_to_image(
+    image: torch.Tensor,  # "*batch height width"
+    color_map: str = "inferno",
+) -> torch.Tensor:  # "*batch 3 height width"
+    image = apply_color_map(image, color_map)
+    return rearrange(image, "... h w c -> ... c h w")
diff --git a/README.md b/README.md
index c0bef943eb50179a725be8ffb1666694d21f8849..c0f703a2800bf56b28bb69a60ff2fef4f865e3f0 100644
--- a/README.md
+++ b/README.md
@@ -4,31 +4,34 @@ emoji: πŸ‘€
 colorFrom: indigo
 colorTo: indigo
 sdk: gradio
-sdk_version: 5.49.1
+sdk_version: 6.0.0
 app_file: app.py
 pinned: false
 ---

 # Depth Estimation Comparison Demo

-A ZeroGPU-friendly Gradio interface for comparing **Depth Anything v1**, **Depth Anything v2**, and **Pixel-Perfect Depth (PPD)** on the same image. Switch between side-by-side layouts, a slider overlay, or single-model inspection to understand how different pipelines perceive scene geometry.
+A Gradio interface for comparing **Depth Anything v1**, **Depth Anything v2**, **Depth Anything v3 (AnySize)**, and **Pixel-Perfect Depth (PPD)** on the same image. Switch between side-by-side layouts, a slider overlay, single-model inspection, or a dedicated v3 tab to understand how different pipelines perceive scene geometry. Two entrypoints are provided: + +- `app_local.py` – full-featured local runner with minimal memory constraints. +- `app.py` – ZeroGPU-aware build tuned for HuggingFace Spaces with aggressive cache management. ## πŸš€ Highlights -- **Three interactive views**: draggable slider, labeled side-by-side comparison, and original vs depth for any single model. -- **Multi-family depth models**: run ViT variants from Depth Anything v1/v2 alongside Pixel-Perfect Depth with MoGe metric alignment. -- **ZeroGPU aware**: on-demand loading, model cache clearing, and torch CUDA cleanup keep GPU usage inside HuggingFace Spaces limits. -- **Curated examples**: reusable demo images sourced from each model family plus local assets to quickly validate behaviour. +- **Four interactive experiences**: draggable slider, labeled side-by-side comparison, original-vs-depth slider, and a Depth Anything v3 tab with RGB vs depth visualization + metadata. +- **Multi-family depth models**: run ViT variants from Depth Anything v1/v2/v3 alongside Pixel-Perfect Depth with MoGe metric alignment. +- **ZeroGPU aware**: `app.py` performs on-demand loading, cache clearing, and CUDA cleanup to stay within HuggingFace Spaces limits, while `app_local.py` keeps models warm for faster iteration. +- **Curated examples**: reusable demo images sourced from each model family (`assets/examples`, `Depth-Anything*/assets/examples`, `Depth-Anything-3-anysize/assets/examples`, `Pixel-Perfect-Depth/assets/examples`). ## πŸ” Supported Pipelines - **Depth Anything v1** (`LiheYoung/depth_anything_*`): ViT-S/B/L with fast transformer backbones and colorized outputs via `Spectral_r` colormap. -- **Depth Anything v2** (`Depth-Anything-V2/checkpoints/*.pth`): ViT-Small/Base/Large with HF Hub fallback, configurable feature channels, and improved edge handling. +- **Depth Anything v2** (`Depth-Anything-V2/checkpoints/*.pth` or HF Hub mirrors): ViT-Small/Base/Large with configurable feature channels and improved edge handling. +- **Depth Anything v3 (AnySize)** (`depth-anything/DA3*` via bundled AnySize fork): Nested, giant, large, base, small, mono, and metric variants with native-resolution inference and automatic padding/cropping. - **Pixel-Perfect Depth**: Diffusion-based relative depth refined by the **MoGe** metric surface model and RANSAC alignment to recover metric depth; customizable denoising steps. ## πŸ–₯️ App Experience -- **Slider Comparison**: drag between two predictions with automatically labeled overlays. +- **Slider Comparison**: drag between any two predictions with automatically labeled overlays. - **Method Comparison**: view models side-by-side with synchronized layout and captions rendered in OpenCV. - **Single Model**: inspect the RGB input versus one model output using the Gradio `ImageSlider` component. -- **Example Gallery**: natural-number sorting across `assets/examples`, `Depth-Anything/assets/examples`, `Depth-Anything-V2/assets/examples`, and `Pixel-Perfect-Depth/assets/examples`. ## πŸ“¦ Installation & Setup @@ -42,44 +45,57 @@ A ZeroGPU-friendly Gradio interface for comparing **Depth Anything v1**, **Depth ```bash pip install -r requirements.txt ``` -3. **Model assets**: +3. 
**Install the AnySize fork** (required for Depth Anything v3 tab): + ```bash + pip install -e Depth-Anything-3-anysize/.[all] + ``` +4. **Model assets**: - Depth Anything v1 checkpoints stream automatically from the HuggingFace Hub. - Download Depth Anything v2 weights into `Depth-Anything-V2/checkpoints/` if they are not already present (`depth_anything_v2_vits.pth`, `depth_anything_v2_vitb.pth`, `depth_anything_v2_vitl.pth`). + - Depth Anything v3 models download via the bundled AnySize API from `depth-anything/*` repositories at inference time; no manual checkpoints required. - Pixel-Perfect Depth pulls the diffusion checkpoint (`ppd.pth`) from `gangweix/Pixel-Perfect-Depth` on first use and loads MoGe weights (`Ruicheng/moge-2-vitl-normal`). -4. **Run the app**: +5. **Run the app**: ```bash - python app_local.py # Local UI with live reload tweaks - python app.py # ZeroGPU-ready launch script + python app_local.py # Local UI with v3 tab and warm caches + python app.py # ZeroGPU-ready launch script (loads models on demand) ``` ### HuggingFace Spaces (ZeroGPU) 1. Push the repository contents to a Gradio Space. 2. Select the **ZeroGPU** hardware preset. -3. The app will download required checkpoints on demand and aggressively free memory after each inference via `clear_model_cache()`. +3. The app downloads required checkpoints (Depth Anything v1/v2/v3, PPD, MoGe) on demand and aggressively frees memory via `clear_model_cache()` between requests. ## πŸ“ Project Structure ``` Depth-Estimation-Compare-demo/ -β”œβ”€β”€ app.py # ZeroGPU deployment entrypoint -β”œβ”€β”€ app_local.py # Local-friendly launch script -β”œβ”€β”€ requirements.txt # Python dependencies (Gradio, Torch, PPD stack) +β”œβ”€β”€ app.py # ZeroGPU deployment entrypoint (includes v3 tab) +β”œβ”€β”€ app_local.py # Local-friendly launch script (full feature set) +β”œβ”€β”€ requirements.txt # Python dependencies (Gradio, Torch, PPD stack) β”œβ”€β”€ assets/ -β”‚ └── examples/ # Shared demo imagery -β”œβ”€β”€ Depth-Anything/ # Depth Anything v1 implementation + utilities -β”œβ”€β”€ Depth-Anything-V2/ # Depth Anything v2 implementation & checkpoints -β”œβ”€β”€ Pixel-Perfect-Depth/ # Pixel-Perfect Depth diffusion + MoGe helpers -└── README.md # You are here +β”‚ └── examples/ # Shared demo imagery +β”œβ”€β”€ Depth-Anything/ # Depth Anything v1 implementation + utilities +β”œβ”€β”€ Depth-Anything-V2/ # Depth Anything v2 implementation & checkpoints +β”œβ”€β”€ Depth-Anything-3-anysize/ # Bundled AnySize fork powering Depth Anything v3 tab +β”‚ β”œβ”€β”€ app.py # Standalone AnySize Gradio demo (optional) +β”‚ β”œβ”€β”€ depth3_anysize.py # Scripted inference example +β”‚ β”œβ”€β”€ pyproject.toml # Editable install metadata +β”‚ β”œβ”€β”€ requirements.txt # AnySize-specific dependencies +β”‚ └── src/depth_anything_3/ # AnySize API, configs, and model code +β”œβ”€β”€ Pixel-Perfect-Depth/ # Pixel-Perfect Depth diffusion + MoGe helpers +└── README.md # You are here ``` ## βš™οΈ Configuration Notes -- Model dropdown labels come from `V1_MODEL_CONFIGS`, `V2_MODEL_CONFIGS`, and the PPD entry in `app.py`. -- `clear_model_cache()` resets every model and flushes CUDA to respect ZeroGPU constraints. +- Model dropdown labels come from `V1_MODEL_CONFIGS`, `V2_MODEL_CONFIGS`, and `DA3_MODEL_SOURCES` plus the PPD entry in both apps. +- `clear_model_cache()` resets every model family (v1/v2/v3/PPD) and flushes CUDA to respect ZeroGPU constraints in `app.py`. 
+- Depth Anything v3 inference leverages the AnySize API (`process_res=None`, `process_res_method="keep"`) to preserve native resolution and returns processed RGB/depth pairs. - Pixel-Perfect Depth inference aligns relative depth to metric scale through `recover_metric_depth_ransac()` for consistent visualization. - Depth visualizations use a normalized `Spectral_r` colormap; PPD uses a dedicated matplotlib colormap for metric maps. ## πŸ“Š Performance Expectations - **Depth Anything v1**: ViT-S ~1–2 s, ViT-B ~2–4 s, ViT-L ~4–8 s (image dependent). - **Depth Anything v2**: similar to v1 with improved sharpness; HF downloads add one-time startup overhead. +- **Depth Anything v3**: nested/giant models are heavier (expect longer cold starts), while base/small options are close to v2 latency when running at native resolution. - **Pixel-Perfect Depth**: diffusion + metric refinement typically takes longer (10–20 denoise steps) but returns metrically-aligned depth suitable for downstream 3D tasks. ## 🎯 Usage Tips @@ -95,6 +111,7 @@ Enhancements are welcomeβ€”new model backends, visualization modes, or memory op - [Depth Anything v2](https://github.com/DepthAnything/Depth-Anything-V2) - [Pixel-Perfect Depth](https://github.com/gangweix/pixel-perfect-depth) - [MoGe](https://huggingface.co/Ruicheng/moge-2-vitl-normal) +- [Depth Anything 3 AnySize Fork](https://github.com/ByteDance-Seed/Depth-Anything-3) (see bundled `Depth-Anything-3-anysize` directory) ## πŸ“„ License - Depth Anything v1: MIT License @@ -104,4 +121,4 @@ Enhancements are welcomeβ€”new model backends, visualization modes, or memory op --- -Built as a hands-on playground for exploring modern monocular depth estimators. Adjust tabs, compare outputs, and plug results into your 3D workflows. +Built as a hands-on playground for exploring modern monocular depth estimators. Adjust tabs, compare outputs, and plug results into your 3D workflows. \ No newline at end of file diff --git a/app.py b/app.py index 2b70e7e5372d9cf7cbcbbc707dee4f90bcd017f0..116603faaf24f25d6b9c0fe202f94fcc45d6d66b 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,7 @@ """ Depth Estimation Comparison Demo (ZeroGPU) -Compare Depth Anything v1, Depth Anything v2, and Pixel-Perfect Depth side-by-side or with a slider using Gradio. +Compare Depth Anything v1, Depth Anything v2, Depth Anything v3, and Pixel-Perfect Depth side-by-side or with a slider using Gradio. Optimized for HuggingFace Spaces with ZeroGPU support. 
""" @@ -9,16 +9,19 @@ import os import sys import logging import gc -from typing import Optional, Tuple, List +import inspect +from typing import Optional, Tuple, List, Dict import numpy as np import cv2 import gradio as gr from huggingface_hub import hf_hub_download import spaces +from PIL import Image # Import v1 and v2 model code sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything-V2")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything-3-anysize", "src")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Pixel-Perfect-Depth")) # v1 imports @@ -33,6 +36,10 @@ from depth_anything_v2.dpt import DepthAnythingV2 import matplotlib +# Depth Anything v3 imports +from depth_anything_3.api import DepthAnything3 +from depth_anything_3.utils.visualize import visualize_depth + # Pixel-Perfect Depth imports from ppd.utils.set_seed import set_seed from ppd.utils.align_depth_func import recover_metric_depth_ransac @@ -82,9 +89,41 @@ V2_MODEL_CONFIGS = { } } +DA3_MODEL_SOURCES: Dict[str, Dict[str, str]] = { + "nested_giant_large": { + "display_name": "Depth Anything v3 Nested Giant Large", + "repo_id": "depth-anything/DA3NESTED-GIANT-LARGE", + }, + "giant": { + "display_name": "Depth Anything v3 Giant", + "repo_id": "depth-anything/DA3-GIANT", + }, + "large": { + "display_name": "Depth Anything v3 Large", + "repo_id": "depth-anything/DA3-LARGE", + }, + "base": { + "display_name": "Depth Anything v3 Base", + "repo_id": "depth-anything/DA3-BASE", + }, + "small": { + "display_name": "Depth Anything v3 Small", + "repo_id": "depth-anything/DA3-SMALL", + }, + "metric_large": { + "display_name": "Depth Anything v3 Metric Large", + "repo_id": "depth-anything/DA3METRIC-LARGE", + }, + "mono_large": { + "display_name": "Depth Anything v3 Mono Large", + "repo_id": "depth-anything/DA3MONO-LARGE", + }, +} + # Model cache - cleared after each inference for ZeroGPU _v1_models = {} _v2_models = {} +_da3_models: Dict[str, DepthAnything3] = {} _ppd_model: Optional[PixelPerfectDepth] = None _moge_model: Optional[MoGeModel] = None @@ -160,15 +199,83 @@ def load_v2_model(key: str): _v2_models[key] = model return model + +def load_da3_model(key: str) -> DepthAnything3: + if key in _da3_models: + return _da3_models[key] + + clear_model_cache() + + repo_id = DA3_MODEL_SOURCES[key]["repo_id"] + model = DepthAnything3.from_pretrained(repo_id) + model = model.to(device=TORCH_DEVICE) + model.eval() + _da3_models[key] = model + return model + + +def _prep_da3_image(image: np.ndarray) -> np.ndarray: + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + if image.dtype != np.uint8: + image = np.clip(image, 0, 255).astype(np.uint8) + return image + + +def run_da3_inference(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray, str, str]: + model = load_da3_model(model_key) + if image.ndim == 2: + rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + else: + rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + rgb = _prep_da3_image(rgb) + prediction = model.inference( + image=[Image.fromarray(rgb)], + process_res=None, + process_res_method="keep", + ) + + depth_map = prediction.depth[0] + depth_vis = visualize_depth(depth_map, cmap="Spectral") + processed_rgb = ( + prediction.processed_images[0] + if getattr(prediction, "processed_images", None) is not None + else rgb + ) + processed_rgb = np.clip(processed_rgb, 0, 255).astype(np.uint8) + + target_h, target_w = image.shape[:2] + if 
depth_vis.shape[:2] != (target_h, target_w): + depth_vis = cv2.resize(depth_vis, (target_w, target_h), interpolation=cv2.INTER_LINEAR) + if processed_rgb.shape[:2] != (target_h, target_w): + processed_rgb = cv2.resize(processed_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR) + + label = DA3_MODEL_SOURCES[model_key]["display_name"] + info_lines = [ + f"**Model:** `{label}`", + f"**Repo:** `{DA3_MODEL_SOURCES[model_key]['repo_id']}`", + f"**Device:** `{str(TORCH_DEVICE)}`", + f"**Depth shape:** `{tuple(prediction.depth.shape)}`", + ] + if getattr(prediction, "extrinsics", None) is not None: + info_lines.append(f"**Extrinsics shape:** `{prediction.extrinsics.shape}`") + if getattr(prediction, "intrinsics", None) is not None: + info_lines.append(f"**Intrinsics shape:** `{prediction.intrinsics.shape}`") + + return depth_vis, processed_rgb, "\n".join(info_lines), label + def clear_model_cache(): """Clear model cache to free GPU memory for ZeroGPU""" - global _v1_models, _v2_models, _ppd_model, _moge_model + global _v1_models, _v2_models, _da3_models, _ppd_model, _moge_model for model in _v1_models.values(): del model for model in _v2_models.values(): del model + for model in _da3_models.values(): + del model _v1_models.clear() _v2_models.clear() + _da3_models.clear() _ppd_model = None _moge_model = None gc.collect() @@ -266,6 +373,8 @@ def get_model_choices() -> List[Tuple[str, str]]: choices.append((v['display_name'], f'v1_{k}')) for k, v in V2_MODEL_CONFIGS.items(): choices.append((v['display_name'], f'v2_{k}')) + for k, v in DA3_MODEL_SOURCES.items(): + choices.append((v['display_name'], f'da3_{k}')) choices.append(("Pixel-Perfect Depth", "ppd")) return choices @@ -287,6 +396,10 @@ def run_model(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, str]: label = V2_MODEL_CONFIGS[key]['display_name'] colored = colorize_depth(depth) return colored, label + elif model_key.startswith('da3_'): + key = model_key[4:] + depth_vis, _, _, label = run_da3_inference(key, image) + return depth_vis, label elif model_key == 'ppd': clear_model_cache() _, colored = pixel_perfect_depth_inference(image) @@ -429,6 +542,37 @@ def single_inference(image, model: str, progress=gr.Progress()): # Clean up GPU memory after inference clear_model_cache() + +@spaces.GPU +def da3_single_inference(image, model: str, progress=gr.Progress()): + if image is None: + return None, "❌ Please upload an image." 
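+    # `image` can arrive as a file path, a PIL.Image, or a numpy array depending on
+    # the Gradio input component; the branches below normalize all three cases to an
+    # OpenCV-style BGR array (run_da3_inference converts it back to RGB internally).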
+ + try: + if isinstance(image, str): + np_image = cv2.imread(image) + elif hasattr(image, "save"): + np_image = np.array(image) + if len(np_image.shape) == 3 and np_image.shape[2] == 3: + np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR) + else: + np_image = np.array(image) + if len(np_image.shape) == 3 and np_image.shape[2] == 3: + np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR) + + if np_image is None: + raise gr.Error("Invalid image input.") + + key = model[4:] if model.startswith("da3_") else model + + progress(0.1, desc=f"Running {model}") + depth_vis, processed_rgb, info_text, _ = run_da3_inference(key, np_image) + progress(1.0, desc="Done") + return (processed_rgb, depth_vis), info_text + + finally: + clear_model_cache() + def get_example_images() -> List[str]: import re @@ -443,6 +587,7 @@ def get_example_images() -> List[str]: "assets/examples", "Depth-Anything/assets/examples", "Depth-Anything-V2/assets/examples", + "Depth-Anything-3-anysize/assets/examples", "Pixel-Perfect-Depth/assets/examples", ]: ex_path = os.path.join(os.path.dirname(__file__), ex_dir) @@ -474,8 +619,19 @@ def create_app(): default2 = next((value for _, value in model_choices if value.startswith('v2_') and value != default1), model_choices[min(1, len(model_choices) - 1)][1]) example_images = get_example_images() + da3_choices = [(cfg['display_name'], f"da3_{key}") for key, cfg in DA3_MODEL_SOURCES.items()] + if not da3_choices: + raise ValueError("Depth Anything v3 models are not configured.") + da3_default = next((value for name, value in da3_choices if "Large" in name), da3_choices[0][1]) + + blocks_kwargs = {"title": "Depth Estimation Comparison"} + try: + if "theme" in inspect.signature(gr.Blocks.__init__).parameters and hasattr(gr, "themes"): + blocks_kwargs["theme"] = gr.themes.Soft() + except (ValueError, TypeError): + pass - with gr.Blocks(title="Depth Estimation Comparison", theme=gr.themes.Soft()) as app: + with gr.Blocks(**blocks_kwargs) as app: gr.Markdown(""" # Depth Estimation Comparison Compare Depth Anything v1, Depth Anything v2, and Pixel-Perfect Depth side-by-side or with a slider. @@ -539,6 +695,7 @@ def create_app(): **References:** - **v1**: [Depth Anything v1](https://github.com/LiheYoung/Depth-Anything) - **v2**: [Depth Anything v2](https://github.com/DepthAnything/Depth-Anything-V2) + - **v3**: [Depth Anything v3](https://github.com/ByteDance-Seed/Depth-Anything-3) & [Depth-Anything-3-anysize](https://github.com/shriarul5273/Depth-Anything-3-anysize) - **PPD**: [Pixel-Perfect Depth](https://github.com/gangweix/pixel-perfect-depth) **Note**: This app uses ZeroGPU for efficient GPU resource management. Models are loaded on-demand and GPU memory is automatically cleaned up after each inference. diff --git a/app_local.py b/app_local.py index 574564cf83376700e819f6a61b4af18a63af6561..444af91a886c704efb2427b8b1a0b6c3b8e08291 100644 --- a/app_local.py +++ b/app_local.py @@ -5,11 +5,14 @@ Compare Depth Anything models (v1 and v2) and Pixel-Perfect Depth side-by-side o Inspired by the Stereo Matching Methods Comparison Demo. 
""" +from __future__ import annotations + import os import sys import logging import tempfile import shutil +import inspect from pathlib import Path from typing import Optional, Tuple, Dict, List import numpy as np @@ -18,10 +21,12 @@ import gradio as gr from huggingface_hub import hf_hub_download import open3d as o3d import trimesh +from PIL import Image # Import v1 and v2 model code sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything-V2")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything-3-anysize", "src")) # v1 imports from depth_anything.dpt import DepthAnything as DepthAnythingV1 @@ -35,6 +40,10 @@ from depth_anything_v2.dpt import DepthAnythingV2 import matplotlib +# Depth Anything v3 imports +from depth_anything_3.api import DepthAnything3 +from depth_anything_3.utils.visualize import visualize_depth + # Pixel-Perfect Depth imports sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Pixel-Perfect-Depth")) from ppd.utils.set_seed import set_seed @@ -48,6 +57,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %( # Device selection DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' +TORCH_DEVICE = torch.device(DEVICE) # Model configs V1_MODEL_CONFIGS = { @@ -83,9 +93,41 @@ V2_MODEL_CONFIGS = { } } +DA3_MODEL_SOURCES = { + "nested_giant_large": { + "display_name": "Depth Anything v3 Nested Giant Large", + "repo_id": "depth-anything/DA3NESTED-GIANT-LARGE", + }, + "giant": { + "display_name": "Depth Anything v3 Giant", + "repo_id": "depth-anything/DA3-GIANT", + }, + "large": { + "display_name": "Depth Anything v3 Large", + "repo_id": "depth-anything/DA3-LARGE", + }, + "base": { + "display_name": "Depth Anything v3 Base", + "repo_id": "depth-anything/DA3-BASE", + }, + "small": { + "display_name": "Depth Anything v3 Small", + "repo_id": "depth-anything/DA3-SMALL", + }, + "metric_large": { + "display_name": "Depth Anything v3 Metric Large", + "repo_id": "depth-anything/DA3METRIC-LARGE", + }, + "mono_large": { + "display_name": "Depth Anything v3 Mono Large", + "repo_id": "depth-anything/DA3MONO-LARGE", + }, +} + # Model cache _v1_models = {} _v2_models = {} +_da3_models: Dict[str, DepthAnything3] = {} # v1 transform v1_transform = Compose([ @@ -146,6 +188,91 @@ def load_v2_model(key: str): _v2_models[key] = model return model + +def load_da3_model(key: str) -> DepthAnything3: + if key in _da3_models: + return _da3_models[key] + repo_id = DA3_MODEL_SOURCES[key]["repo_id"] + model = DepthAnything3.from_pretrained(repo_id) + model = model.to(device=TORCH_DEVICE) + model.eval() + _da3_models[key] = model + return model + + +def _prep_da3_image(image: np.ndarray) -> np.ndarray: + if image.ndim == 2: + image = np.stack([image] * 3, axis=-1) + if image.dtype != np.uint8: + image = np.clip(image, 0, 255).astype(np.uint8) + return image + + +def run_da3_inference(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray, str, str]: + model = load_da3_model(model_key) + if image.ndim == 2: + rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) + else: + rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + rgb = _prep_da3_image(rgb) + prediction = model.inference( + image=[Image.fromarray(rgb)], + process_res=None, + process_res_method="keep", + ) + depth_map = prediction.depth[0] + depth_vis = visualize_depth(depth_map, cmap="Spectral") + processed_rgb = ( + prediction.processed_images[0] + if getattr(prediction, "processed_images", 
None) is not None + else rgb + ) + processed_rgb = np.clip(processed_rgb, 0, 255).astype(np.uint8) + target_h, target_w = image.shape[:2] + if depth_vis.shape[:2] != (target_h, target_w): + depth_vis = cv2.resize(depth_vis, (target_w, target_h), interpolation=cv2.INTER_LINEAR) + if processed_rgb.shape[:2] != (target_h, target_w): + processed_rgb = cv2.resize(processed_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR) + label = DA3_MODEL_SOURCES[model_key]["display_name"] + info_lines = [ + f"**Model:** `{label}`", + f"**Repo:** `{DA3_MODEL_SOURCES[model_key]['repo_id']}`", + f"**Device:** `{str(TORCH_DEVICE)}`", + f"**Depth shape:** `{tuple(prediction.depth.shape)}`", + ] + if getattr(prediction, "extrinsics", None) is not None: + info_lines.append(f"**Extrinsics shape:** `{prediction.extrinsics.shape}`") + if getattr(prediction, "intrinsics", None) is not None: + info_lines.append(f"**Intrinsics shape:** `{prediction.intrinsics.shape}`") + info_text = "\n".join(info_lines) + return depth_vis, processed_rgb, info_text, label + + +def da3_single_inference(image, model: str, progress=gr.Progress()): + if image is None: + return None, "❌ Please upload an image." + + if isinstance(image, str): + np_image = cv2.imread(image) + elif hasattr(image, "save"): + np_image = np.array(image) + if len(np_image.shape) == 3 and np_image.shape[2] == 3: + np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR) + else: + np_image = np.array(image) + if len(np_image.shape) == 3 and np_image.shape[2] == 3: + np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR) + + if np_image is None: + raise gr.Error("Invalid image input.") + + key = model[4:] if model.startswith("da3_") else model + + progress(0.1, desc=f"Running {model}") + depth_vis, processed_rgb, info_text, label = run_da3_inference(key, np_image) + progress(1.0, desc="Done") + return (processed_rgb, depth_vis), info_text + def predict_v1(model, image: np.ndarray) -> np.ndarray: h, w = image.shape[:2] image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0 @@ -171,8 +298,6 @@ def colorize_depth(depth: np.ndarray) -> np.ndarray: # Pixel-Perfect Depth setup ------------------------------------------------- set_seed(666) - -TORCH_DEVICE = torch.device(DEVICE) PPD_DEFAULT_STEPS = 20 PPD_TEMP_ROOT = Path(tempfile.gettempdir()) / "ppd" @@ -308,6 +433,8 @@ def get_model_choices() -> List[Tuple[str, str]]: choices.append((v['display_name'], f'v1_{k}')) for k, v in V2_MODEL_CONFIGS.items(): choices.append((v['display_name'], f'v2_{k}')) + for k, v in DA3_MODEL_SOURCES.items(): + choices.append((v['display_name'], f'da3_{k}')) choices.append(("Pixel-Perfect Depth", "ppd")) return choices @@ -322,6 +449,10 @@ def run_model(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, str]: model = load_v2_model(key) depth = predict_v2(model, image) label = V2_MODEL_CONFIGS[key]['display_name'] + elif model_key.startswith('da3_'): + key = model_key[4:] + depth_vis, _, _, label = run_da3_inference(key, image) + return depth_vis, label elif model_key == 'ppd': slider_data, _, _ = pixel_perfect_depth_inference( image, @@ -449,6 +580,7 @@ def get_example_images() -> List[str]: "assets/examples", "Depth-Anything/assets/examples", "Depth-Anything-V2/assets/examples", + "Depth-Anything-3-anysize/assets/examples", "Pixel-Perfect-Depth/assets/examples", ]: ex_path = os.path.join(os.path.dirname(__file__), ex_dir) @@ -476,8 +608,19 @@ def create_app(): model_choices = get_model_choices() default1 = model_choices[0][1] default2 = model_choices[1][1] + da3_choices = 
[(cfg['display_name'], f"da3_{key}") for key, cfg in DA3_MODEL_SOURCES.items()] + if not da3_choices: + raise ValueError("Depth Anything v3 models are not configured.") + da3_default = da3_choices[2][1] if len(da3_choices) > 2 else da3_choices[0][1] example_images = get_example_images() - with gr.Blocks(title="Depth Anything v1 vs v2 Comparison", theme=gr.themes.Soft()) as app: + blocks_kwargs = {"title": "Depth Anything v1 vs v2 Comparison"} + try: + if "theme" in inspect.signature(gr.Blocks.__init__).parameters and hasattr(gr, "themes"): + # Use theme only when the installed gradio version accepts it. + blocks_kwargs["theme"] = gr.themes.Soft() + except (ValueError, TypeError): + pass + with gr.Blocks(**blocks_kwargs) as app: gr.Markdown(""" # Depth Estimation Comparison Compare Depth Anything v1, Depth Anything v2, and Pixel-Perfect Depth side-by-side or with a slider. @@ -530,6 +673,7 @@ def create_app(): --- - **v1**: [Depth Anything v1](https://github.com/LiheYoung/Depth-Anything) - **v2**: [Depth Anything v2](https://github.com/DepthAnything/Depth-Anything-V2) + - **v3**: [Depth Anything v3](https://github.com/ByteDance-Seed/Depth-Anything-3) & [Depth-Anything-3-anysize](https://github.com/shriarul5273/Depth-Anything-3-anysize) - **PPD**: [Pixel-Perfect Depth](https://github.com/gangweix/pixel-perfect-depth) """) return app diff --git a/requirements.txt b/requirements.txt index 4f97e9328543fb854492fd94d81039ecc620f543..0582f8557515a6efa836fdd753061e9161c0aad6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ open3d scikit-learn git+https://github.com/EasternJournalist/utils3d.git@c5daf6f6c244d251f252102d09e9b7bcef791a38 click # ==8.1.7 -trimesh # ==4.5.1 \ No newline at end of file +trimesh # ==4.5.1 +-e Depth-Anything-3-anysize/.[all] \ No newline at end of file